1//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <utility>
10
11#include "Detail/DimLvlMapParser.h"
12
13#include "mlir/Dialect/SparseTensor/IR/Enums.h"
14#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
15#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
16#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
17
18#include "mlir/Dialect/Arith/IR/Arith.h"
19#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
20#include "mlir/Dialect/Complex/IR/Complex.h"
21#include "mlir/Dialect/Utils/StaticValueUtils.h"
22#include "mlir/IR/Builders.h"
23#include "mlir/IR/DialectImplementation.h"
24#include "mlir/IR/Matchers.h"
25#include "mlir/IR/OpImplementation.h"
26#include "mlir/IR/PatternMatch.h"
27#include "llvm/ADT/Bitset.h"
28#include "llvm/ADT/TypeSwitch.h"
29#include "llvm/Support/FormatVariadic.h"
30
31#define GET_ATTRDEF_CLASSES
32#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
33#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
34
// Forward declarations; the following custom print/parse methods are
// referenced by the generated code for SparseTensorTypes.td.
37static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
38 mlir::sparse_tensor::Level &,
39 mlir::sparse_tensor::Level &);
40static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level,
41 mlir::sparse_tensor::Level);
42
43#define GET_TYPEDEF_CLASSES
44#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
45
46using namespace mlir;
47using namespace mlir::sparse_tensor;
48
49// Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as
50// well.
51namespace mlir::sparse_tensor {
52llvm::hash_code hash_value(LevelType lt) {
53 return llvm::hash_value(static_cast<uint64_t>(lt));
54}
55} // namespace mlir::sparse_tensor
56
57//===----------------------------------------------------------------------===//
58// Local Convenience Methods.
59//===----------------------------------------------------------------------===//
60
61static constexpr bool acceptBitWidth(unsigned bitWidth) {
62 switch (bitWidth) {
63 case 0:
64 case 8:
65 case 16:
66 case 32:
67 case 64:
68 return true;
69 default:
70 return false;
71 }
72}
73
74static SmallVector<Size>
75getSparseFieldShape(const SparseTensorEncodingAttr enc,
76 std::optional<ArrayRef<int64_t>> dimShape) {
77 assert(enc);
  // With only the encoding, we cannot determine the static shape of the
  // leading batch levels; we therefore return a dynamically shaped memref.
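  // Illustrative example (assuming one leading batch level mapped from a
  // static dimension of size 8): without a dimension shape the result is
  // {?, ?}, i.e. memref<?x?xT>; with the tensor shape provided, it can be
  // refined to {8, ?}, i.e. memref<8x?xT>.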
80 SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic);
81 if (dimShape.has_value()) {
82 // If the actual tensor shape is provided, we can then refine the leading
83 // batch dimension.
84 SmallVector<int64_t> lvlShape =
85 enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
86 memrefShape.assign(lvlShape.begin(),
87 lvlShape.begin() + enc.getBatchLvlRank());
88 }
89 // Another dynamic dimension to store the sparse level.
90 memrefShape.push_back(ShapedType::kDynamic);
91 return memrefShape;
92}
93
94//===----------------------------------------------------------------------===//
95// SparseTensorDialect StorageLayout.
96//===----------------------------------------------------------------------===//
97
98static constexpr Level kInvalidLevel = -1u;
99static constexpr Level kInvalidFieldIndex = -1u;
100static constexpr FieldIndex kDataFieldStartingIdx = 0;
101
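// Illustrative example (not normative): for a 2-D CSR-like encoding with
// level types (dense, compressed), foreachField visits, in order,
//   0: positions[1], 1: coordinates[1], 2: values, 3: storage specifier,
// since a dense level contributes no positions/coordinates memrefs.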
102void StorageLayout::foreachField(
103 llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
104 LevelType)>
105 callback) const {
106 const auto lvlTypes = enc.getLvlTypes();
107 const Level lvlRank = enc.getLvlRank();
108 SmallVector<COOSegment> cooSegs = enc.getCOOSegments();
109 FieldIndex fieldIdx = kDataFieldStartingIdx;
110
111 ArrayRef cooSegsRef = cooSegs;
112 // Per-level storage.
113 for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
114 const auto lt = lvlTypes[l];
115 if (isWithPosLT(lt)) {
116 if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
117 return;
118 }
119 if (isWithCrdLT(lt)) {
120 if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
121 return;
122 }
123 if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
124 if (!cooSegsRef.front().isSoA) {
        // AoS COO: all singletons are fused into one memref. Skip the entire
        // COO segment.
127 l = cooSegsRef.front().lvlRange.second;
128 } else {
129 // SoA COO, each singleton level has one memref.
130 l++;
131 }
132 // Expire handled COO segment.
133 cooSegsRef = cooSegsRef.drop_front();
134 } else {
135 // Non COO levels.
136 l++;
137 }
138 }
139 // The values array.
140 if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
141 LevelFormat::Undef)))
142 return;
143 // Put metadata at the end.
144 if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
145 LevelFormat::Undef)))
146 return;
147}
148
149void sparse_tensor::foreachFieldAndTypeInSparseTensor(
150 SparseTensorType stt,
151 llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
152 LevelType)>
153 callback) {
154 assert(stt.hasEncoding());
155
156 SmallVector<int64_t> memrefShape =
157 getSparseFieldShape(stt.getEncoding(), stt.getDimShape());
158
159 const Type specType = StorageSpecifierType::get(stt.getEncoding());
160 // memref<[batch] x ? x pos> positions
161 const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
162 // memref<[batch] x ? x crd> coordinates
163 const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
164 // memref<[batch] x ? x eltType> values
165 const Type valMemType = MemRefType::get(memrefShape, stt.getElementType());
166
  StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
                                   callback](FieldIndex fieldIdx,
                                             SparseTensorFieldKind fieldKind,
                                             Level lvl, LevelType lt) -> bool {
171 switch (fieldKind) {
172 case SparseTensorFieldKind::StorageSpec:
173 return callback(specType, fieldIdx, fieldKind, lvl, lt);
174 case SparseTensorFieldKind::PosMemRef:
175 return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
176 case SparseTensorFieldKind::CrdMemRef:
177 return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
178 case SparseTensorFieldKind::ValMemRef:
179 return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
180 };
181 llvm_unreachable("unrecognized field kind");
182 });
183}
184
185unsigned StorageLayout::getNumFields() const {
186 unsigned numFields = 0;
  foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level,
                            LevelType) -> bool {
189 numFields++;
190 return true;
191 });
192 return numFields;
193}
194
195unsigned StorageLayout::getNumDataFields() const {
196 unsigned numFields = 0; // one value memref
  foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level,
                            LevelType) -> bool {
199 if (fidx >= kDataFieldStartingIdx)
200 numFields++;
201 return true;
202 });
203 numFields -= 1; // the last field is StorageSpecifier
204 assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
205 return numFields;
206}
207
208std::pair<FieldIndex, unsigned>
209StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
210 std::optional<Level> lvl) const {
211 FieldIndex fieldIdx = kInvalidFieldIndex;
212 unsigned stride = 1;
213 if (kind == SparseTensorFieldKind::CrdMemRef) {
214 assert(lvl.has_value());
215 const Level cooStart = enc.getAoSCOOStart();
216 const Level lvlRank = enc.getLvlRank();
217 if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
218 lvl = cooStart;
219 stride = lvlRank - cooStart;
220 }
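    // For example (illustrative): with a rank-3 AoS COO starting at level 1,
    // a query for the coordinates of level 2 resolves to the fused coordinate
    // memref of level 1 with stride 2.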
221 }
  foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
                                      SparseTensorFieldKind fKind, Level fLvl,
                                      LevelType lt) -> bool {
225 if ((lvl && fLvl == lvl.value() && kind == fKind) ||
226 (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
227 fieldIdx = fIdx;
228 // Returns false to break the iteration.
229 return false;
230 }
231 return true;
232 });
233 assert(fieldIdx != kInvalidFieldIndex);
234 return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
235}
236
237//===----------------------------------------------------------------------===//
238// SparseTensorDialect Attribute Methods.
239//===----------------------------------------------------------------------===//
240
241std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
242 return isDynamic(v) ? std::nullopt
243 : std::make_optional(static_cast<uint64_t>(v));
244}
245
246std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
247 return getStatic(getOffset());
248}
249
250std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
251 return getStatic(getStride());
252}
253
254std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
255 return getStatic(getSize());
256}
257
258bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
259 return isDynamic(getOffset()) && isDynamic(getStride()) &&
260 isDynamic(getSize());
261}
262
263std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
264 return isDynamic(v) ? "?" : std::to_string(v);
265}
266
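// For example, a slice with offset 0, dynamic size, and stride 2 prints as
// "(0, ?, 2)", and a fully dynamic slice prints as "(?, ?, ?)".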
267void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
268 assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
269 os << '(';
270 os << getStaticString(getOffset());
271 os << ", ";
272 os << getStaticString(getSize());
273 os << ", ";
274 os << getStaticString(getStride());
275 os << ')';
276}
277
278void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
279 print(printer.getStream());
280}
281
282static ParseResult parseOptionalStaticSlice(int64_t &result,
283 AsmParser &parser) {
284 auto parseResult = parser.parseOptionalInteger(result);
285 if (parseResult.has_value()) {
286 if (parseResult.value().succeeded() && result < 0) {
      parser.emitError(
          parser.getCurrentLocation(),
          "expect positive value or ? for slice offset/size/stride");
290 return failure();
291 }
292 return parseResult.value();
293 }
294
  // Else, expect a '?', which represents a dynamic slice value.
296 result = SparseTensorDimSliceAttr::kDynamic;
297 return parser.parseQuestion();
298}
299
300Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
301 int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;
302
303 if (failed(parser.parseLParen()) ||
304 failed(parseOptionalStaticSlice(offset, parser)) ||
305 failed(parser.parseComma()) ||
306 failed(parseOptionalStaticSlice(size, parser)) ||
307 failed(parser.parseComma()) ||
308 failed(parseOptionalStaticSlice(stride, parser)) ||
309 failed(parser.parseRParen()))
310 return {};
311
312 return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
313 offset, size, stride);
314}
315
316LogicalResult
317SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
318 int64_t offset, int64_t size, int64_t stride) {
319 if (!isDynamic(offset) && offset < 0)
320 return emitError() << "expect non-negative value or ? for slice offset";
321 if (!isDynamic(size) && size <= 0)
322 return emitError() << "expect positive value or ? for slice size";
323 if (!isDynamic(stride) && stride <= 0)
324 return emitError() << "expect positive value or ? for slice stride";
325 return success();
326}
327
328SparseTensorEncodingAttr
329SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
330 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
331 return SparseTensorEncodingAttr::get(
332 getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(),
333 getCrdWidth(), getExplicitVal(), getImplicitVal());
334}
335
336SparseTensorEncodingAttr
337SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
338 return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
339}
340
341SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
342 return withDimToLvl(AffineMap());
343}
344
345SparseTensorEncodingAttr
346SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
347 unsigned crdWidth) const {
348 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
349 return SparseTensorEncodingAttr::get(
350 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth,
351 crdWidth, getExplicitVal(), getImplicitVal());
352}
353
354SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
355 return withBitWidths(0, 0);
356}
357
358SparseTensorEncodingAttr
359SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const {
360 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
361 return SparseTensorEncodingAttr::get(
362 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
363 getCrdWidth(), explicitVal, getImplicitVal());
364}
365
366SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const {
367 return withExplicitVal(Attribute());
368}
369
370SparseTensorEncodingAttr
371SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const {
372 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
373 return SparseTensorEncodingAttr::get(
374 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
375 getCrdWidth(), getExplicitVal(), implicitVal);
376}
377
378SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const {
379 return withImplicitVal(Attribute());
380}
381
382SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
383 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
384 return SparseTensorEncodingAttr::get(
385 getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(),
386 getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices);
387}
388
389SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
390 return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
391}
392
393uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
394 ArrayRef<LevelType> lvlTypes = getLvlTypes();
395 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
396 return std::distance(lastBatch, lvlTypes.rend());
397}
398
399bool SparseTensorEncodingAttr::isAllDense() const {
400 return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
401}
402
403bool SparseTensorEncodingAttr::isAllOrdered() const {
404 return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
405}
406
407Type SparseTensorEncodingAttr::getCrdElemType() const {
408 if (!getImpl())
409 return nullptr;
410 if (getCrdWidth())
411 return IntegerType::get(getContext(), getCrdWidth());
412 return IndexType::get(getContext());
413}
414
415Type SparseTensorEncodingAttr::getPosElemType() const {
416 if (!getImpl())
417 return nullptr;
418 if (getPosWidth())
419 return IntegerType::get(getContext(), getPosWidth());
420 return IndexType::get(getContext());
421}
422
423MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
424 std::optional<ArrayRef<int64_t>> dimShape) const {
425 SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
426 return MemRefType::get(shape, getCrdElemType());
427}
428
429MemRefType SparseTensorEncodingAttr::getPosMemRefType(
430 std::optional<ArrayRef<int64_t>> dimShape) const {
431 SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
432 return MemRefType::get(shape, getPosElemType());
433}
434
435bool SparseTensorEncodingAttr::isIdentity() const {
436 return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
437}
438
439bool SparseTensorEncodingAttr::isPermutation() const {
440 return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
441}
442
443Dimension SparseTensorEncodingAttr::getDimRank() const {
444 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
445 const auto dimToLvl = getDimToLvl();
446 return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
447}
448
449Level SparseTensorEncodingAttr::getLvlRank() const {
450 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
451 return getLvlTypes().size();
452}
453
454LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
455 if (!getImpl())
456 return LevelFormat::Batch;
457 assert(l < getLvlRank() && "Level is out of bounds");
458 return getLvlTypes()[l];
459}
460
461bool SparseTensorEncodingAttr::isSlice() const {
462 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
463 return !getDimSlices().empty();
464}
465
466SparseTensorDimSliceAttr
467SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
468 assert(isSlice() && "Is not a slice");
469 const auto dimSlices = getDimSlices();
470 assert(dim < dimSlices.size() && "Dimension is out of bounds");
471 return dimSlices[dim];
472}
473
474std::optional<uint64_t>
475SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
476 return getDimSlice(dim).getStaticOffset();
477}
478
479std::optional<uint64_t>
480SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
481 return getDimSlice(dim).getStaticStride();
482}
483
484std::optional<uint64_t>
485SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
486 return getStaticDimSliceOffset(toDim(*this, lvl));
487}
488
489std::optional<uint64_t>
490SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
491 return getStaticDimSliceStride(toDim(*this, lvl));
492}
493
494SmallVector<int64_t>
495SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
496 CrdTransDirectionKind dir) const {
497 if (isIdentity())
498 return SmallVector<int64_t>(srcShape);
499
500 SmallVector<int64_t> ret;
501 unsigned rank =
502 dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
503 ret.reserve(rank);
504
505 if (isPermutation()) {
506 for (unsigned r = 0; r < rank; r++) {
507 unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
508 : toLvl(*this, r);
509 ret.push_back(srcShape[trans]);
510 }
511 return ret;
512 }
513
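  // Illustrative example: for a block-sparse dimToLvl map
  //   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3),
  // the dimension shape 10x12 translates to the level shape 5x4x2x3. A
  // dynamic dimension size yields a dynamic level size, except for
  // "d mod constant" results, which are bounded by the constant.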
514 // Handle non-permutation maps.
515 AffineMap transMap =
516 dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();
517
518 SmallVector<AffineExpr> dimRep;
519 dimRep.reserve(srcShape.size());
520 for (int64_t sz : srcShape) {
521 if (!ShapedType::isDynamic(sz)) {
522 // Push back the max coordinate for the given dimension/level size.
523 dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
524 } else {
      // A dynamic size: use an AffineDimExpr to symbolize the value.
526 dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
527 }
528 };
529
530 for (AffineExpr exp : transMap.getResults()) {
531 // Do constant propagation on the affine map.
532 AffineExpr evalExp =
533 simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
534 // use llvm namespace here to avoid ambiguity
535 if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
536 ret.push_back(c.getValue() + 1);
537 } else {
538 if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
539 mod && mod.getKind() == AffineExprKind::Mod) {
        // We can still infer a static bound for expressions of the form
        // "d % constant", since d % constant lies in [0, constant).
542 if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
543 ret.push_back(bound.getValue());
544 continue;
545 }
546 }
547 ret.push_back(ShapedType::kDynamic);
548 }
549 }
550 assert(ret.size() == rank);
551 return ret;
552}
553
554ValueRange
555SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
556 ValueRange crds,
557 CrdTransDirectionKind dir) const {
558 if (!getImpl())
559 return crds;
560
561 SmallVector<Type> retType(
562 dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
563 builder.getIndexType());
564 auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this);
565 return transOp.getOutCrds();
566}
567
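// An example of the textual form handled by parse/print below (illustrative,
// a CSR-like encoding with explicit bitwidths):
//
//   #sparse_tensor.encoding<{
//     map = (d0, d1) -> (d0 : dense, d1 : compressed),
//     posWidth = 64,
//     crdWidth = 32
//   }>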
568Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
569 // Open "<{" part.
570 if (failed(parser.parseLess()))
571 return {};
572 if (failed(parser.parseLBrace()))
573 return {};
574
575 // Process the data from the parsed dictionary value into struct-like data.
576 SmallVector<LevelType> lvlTypes;
577 SmallVector<SparseTensorDimSliceAttr> dimSlices;
578 AffineMap dimToLvl = {};
579 AffineMap lvlToDim = {};
580 unsigned posWidth = 0;
581 unsigned crdWidth = 0;
582 Attribute explicitVal;
583 Attribute implicitVal;
584 StringRef attrName;
585 SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth",
586 "explicitVal", "implicitVal"};
587 while (succeeded(parser.parseOptionalKeyword(&attrName))) {
588 // Detect admissible keyword.
589 auto *it = find(keys, attrName);
590 if (it == keys.end()) {
591 parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
592 return {};
593 }
594 unsigned keyWordIndex = it - keys.begin();
595 // Consume the `=` after keys
596 if (failed(parser.parseEqual()))
597 return {};
598 // Dispatch on keyword.
599 switch (keyWordIndex) {
600 case 0: { // map
601 ir_detail::DimLvlMapParser cParser(parser);
602 auto res = cParser.parseDimLvlMap();
603 if (failed(res))
604 return {};
605 const auto &dlm = *res;
606
607 const Level lvlRank = dlm.getLvlRank();
608 for (Level lvl = 0; lvl < lvlRank; lvl++)
609 lvlTypes.push_back(dlm.getLvlType(lvl));
610
611 const Dimension dimRank = dlm.getDimRank();
612 for (Dimension dim = 0; dim < dimRank; dim++)
613 dimSlices.push_back(dlm.getDimSlice(dim));
614 // NOTE: the old syntax requires an all-or-nothing approach to
615 // `dimSlices`; therefore, if any slice actually exists then we need
616 // to convert null-DSA into default/nop DSA.
617 const auto isDefined = [](SparseTensorDimSliceAttr slice) {
618 return static_cast<bool>(slice.getImpl());
619 };
620 if (llvm::any_of(dimSlices, isDefined)) {
621 const auto defaultSlice =
622 SparseTensorDimSliceAttr::get(parser.getContext());
623 for (Dimension dim = 0; dim < dimRank; dim++)
624 if (!isDefined(dimSlices[dim]))
625 dimSlices[dim] = defaultSlice;
626 } else {
627 dimSlices.clear();
628 }
629
630 dimToLvl = dlm.getDimToLvlMap(parser.getContext());
631 lvlToDim = dlm.getLvlToDimMap(parser.getContext());
632 break;
633 }
634 case 1: { // posWidth
635 Attribute attr;
636 if (failed(parser.parseAttribute(attr)))
637 return {};
638 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
639 if (!intAttr) {
640 parser.emitError(parser.getNameLoc(),
641 "expected an integral position bitwidth");
642 return {};
643 }
644 posWidth = intAttr.getInt();
645 break;
646 }
647 case 2: { // crdWidth
648 Attribute attr;
649 if (failed(parser.parseAttribute(attr)))
650 return {};
651 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
652 if (!intAttr) {
653 parser.emitError(parser.getNameLoc(),
654 "expected an integral index bitwidth");
655 return {};
656 }
657 crdWidth = intAttr.getInt();
658 break;
659 }
660 case 3: { // explicitVal
661 Attribute attr;
662 if (failed(parser.parseAttribute(attr)))
663 return {};
664 if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
665 explicitVal = result;
666 } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
667 explicitVal = result;
668 } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
669 explicitVal = result;
670 } else {
671 parser.emitError(parser.getNameLoc(),
672 "expected a numeric value for explicitVal");
673 return {};
674 }
675 break;
676 }
677 case 4: { // implicitVal
678 Attribute attr;
679 if (failed(parser.parseAttribute(attr)))
680 return {};
681 if (auto result = llvm::dyn_cast<FloatAttr>(attr)) {
682 implicitVal = result;
683 } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
684 implicitVal = result;
685 } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
686 implicitVal = result;
687 } else {
688 parser.emitError(parser.getNameLoc(),
689 "expected a numeric value for implicitVal");
690 return {};
691 }
692 break;
693 }
694 } // switch
695 // Only last item can omit the comma.
696 if (parser.parseOptionalComma().failed())
697 break;
698 }
699
700 // Close "}>" part.
701 if (failed(parser.parseRBrace()))
702 return {};
703 if (failed(parser.parseGreater()))
704 return {};
705
706 // Construct struct-like storage for attribute.
707 if (!lvlToDim || lvlToDim.isEmpty()) {
708 lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
709 }
710 return parser.getChecked<SparseTensorEncodingAttr>(
711 parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
712 explicitVal, implicitVal, dimSlices);
713}
714
715void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
716 auto map = static_cast<AffineMap>(getDimToLvl());
717 // Empty affine map indicates identity map
718 if (!map)
719 map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
720 printer << "<{ map = ";
721 printSymbols(map, printer);
722 printer << '(';
723 printDimensions(map, printer, getDimSlices());
724 printer << ") -> (";
725 printLevels(map, printer, getLvlTypes());
726 printer << ')';
727 // Print remaining members only for non-default values.
728 if (getPosWidth())
729 printer << ", posWidth = " << getPosWidth();
730 if (getCrdWidth())
731 printer << ", crdWidth = " << getCrdWidth();
732 if (getExplicitVal()) {
733 printer << ", explicitVal = " << getExplicitVal();
734 }
735 if (getImplicitVal())
736 printer << ", implicitVal = " << getImplicitVal();
737 printer << " }>";
738}
739
740void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
741 AsmPrinter &printer) const {
742 if (map.getNumSymbols() == 0)
743 return;
744 printer << '[';
745 for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
746 printer << 's' << i << ", ";
747 if (map.getNumSymbols() >= 1)
748 printer << 's' << map.getNumSymbols() - 1;
749 printer << ']';
750}
751
752void SparseTensorEncodingAttr::printDimensions(
753 AffineMap &map, AsmPrinter &printer,
754 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
755 if (!dimSlices.empty()) {
756 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
757 printer << 'd' << i << " : " << dimSlices[i] << ", ";
758 if (map.getNumDims() >= 1) {
759 printer << 'd' << map.getNumDims() - 1 << " : "
760 << dimSlices[map.getNumDims() - 1];
761 }
762 } else {
763 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
764 printer << 'd' << i << ", ";
765 if (map.getNumDims() >= 1)
766 printer << 'd' << map.getNumDims() - 1;
767 }
768}
769
770void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
771 ArrayRef<LevelType> lvlTypes) const {
772 for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
773 map.getResult(i).print(printer.getStream());
774 printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
775 }
776 if (map.getNumResults() >= 1) {
777 auto lastIndex = map.getNumResults() - 1;
778 map.getResult(lastIndex).print(printer.getStream());
779 printer << " : " << toMLIRString(lvlTypes[lastIndex]);
780 }
781}
782
783LogicalResult SparseTensorEncodingAttr::verify(
784 function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
785 AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
786 unsigned crdWidth, Attribute explicitVal, Attribute implicitVal,
787 ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
788 if (!acceptBitWidth(posWidth))
789 return emitError() << "unexpected position bitwidth: " << posWidth;
790 if (!acceptBitWidth(crdWidth))
791 return emitError() << "unexpected coordinate bitwidth: " << crdWidth;
792
793 // Verify every COO segment.
794 auto *it = llvm::find_if(lvlTypes, isSingletonLT);
795 while (it != lvlTypes.end()) {
796 if (it == lvlTypes.begin() ||
797 !(it - 1)->isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>())
798 return emitError() << "expected compressed or loose_compressed level "
799 "before singleton level";
800
801 auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT);
802 if (!std::all_of(it, curCOOEnd, isSingletonLT))
803 return emitError() << "expected all singleton lvlTypes "
804 "following a singleton level";
805 // We can potentially support mixed SoA/AoS singleton levels.
806 if (!std::all_of(it, curCOOEnd, [it](LevelType i) {
807 return it->isa<LevelPropNonDefault::SoA>() ==
808 i.isa<LevelPropNonDefault::SoA>();
809 })) {
810 return emitError() << "expected all singleton lvlTypes stored in the "
811 "same memory layout (SoA vs AoS).";
812 }
813 it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT);
814 }
815
816 auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
817 if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
818 return emitError() << "Batch lvlType can only be leading levels.";
819
820 // SoA property can only be applied on singleton level.
821 auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
822 return lt.isa<LevelPropNonDefault::SoA>();
823 });
824 if (llvm::any_of(soaLvls, [](LevelType lt) {
825 return !lt.isa<LevelFormat::Singleton>();
826 })) {
827 return emitError() << "SoA is only applicable to singleton lvlTypes.";
828 }
829
830 // TODO: audit formats that actually are supported by backend.
831 if (auto it = llvm::find_if(lvlTypes, isNOutOfMLT);
832 it != std::end(lvlTypes)) {
833 if (it != lvlTypes.end() - 1)
834 return emitError() << "expected n_out_of_m to be the last level type";
835 if (!std::all_of(lvlTypes.begin(), it, isDenseLT))
836 return emitError() << "expected all dense lvlTypes "
837 "before a n_out_of_m level";
838 if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
839 if (!isBlockSparsity(dimToLvl)) {
840 return emitError()
841 << "expected 1xm block structure for n_out_of_m level";
842 }
843 auto sizes = getBlockSize(dimToLvl);
844 unsigned coefficient = 0;
845 for (const auto &elem : sizes) {
846 if (elem != 0) {
847 if (elem != coefficient && coefficient != 0) {
848 return emitError() << "expected only one blocked level "
849 "with the same coefficients";
850 }
851 coefficient = elem;
852 }
853 }
854 if (coefficient != getM(*it)) {
        return emitError() << "expected coefficients of affine expressions "
                              "to be equal to m of n_out_of_m level";
857 }
858 }
859 }
860 // Before we can check that the level-rank is consistent/coherent
861 // across all fields, we need to define it. The source-of-truth for
862 // the `getLvlRank` method is the length of the level-types array,
863 // since it must always be provided and have full rank; therefore we
864 // use that same source-of-truth here.
865 const Level lvlRank = lvlTypes.size();
866 if (lvlRank == 0)
867 return emitError() << "expected a non-empty array for lvlTypes";
868 // We save `dimRank` here because we'll also need it to verify `dimSlices`.
869 const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
870 if (dimToLvl) {
871 if (dimToLvl.getNumResults() != lvlRank)
872 return emitError()
873 << "level-rank mismatch between dimToLvl and lvlTypes: "
874 << dimToLvl.getNumResults() << " != " << lvlRank;
875 auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
876 // Symbols can't be inferred but are acceptable.
877 if (!inferRes && dimToLvl.getNumSymbols() == 0)
878 return emitError() << "failed to infer lvlToDim from dimToLvl";
879 if (lvlToDim && (inferRes != lvlToDim))
880 return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
881 if (dimRank > lvlRank)
882 return emitError() << "unexpected dimToLvl mapping from " << dimRank
883 << " to " << lvlRank;
884 }
885 if (!dimSlices.empty()) {
886 if (dimSlices.size() != dimRank)
887 return emitError()
888 << "dimension-rank mismatch between dimSlices and dimToLvl: "
889 << dimSlices.size() << " != " << dimRank;
890 // Compiler support for `dimSlices` currently requires that the two
891 // ranks agree. (However, it does allow `dimToLvl` to be a permutation.)
892 if (dimRank != lvlRank)
893 return emitError()
894 << "dimSlices expected dimension-rank to match level-rank: "
895 << dimRank << " != " << lvlRank;
896 }
897 return success();
898}
899
900LogicalResult SparseTensorEncodingAttr::verifyEncoding(
901 ArrayRef<Size> dimShape, Type elementType,
902 function_ref<InFlightDiagnostic()> emitError) const {
903 // Check structural integrity. In particular, this ensures that the
904 // level-rank is coherent across all the fields.
905 if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
906 getPosWidth(), getCrdWidth(), getExplicitVal(),
907 getImplicitVal(), getDimSlices())))
908 return failure();
909 // Check integrity with tensor type specifics. In particular, we
910 // need only check that the dimension-rank of the tensor agrees with
911 // the dimension-rank of the encoding.
912 const Dimension dimRank = dimShape.size();
913 if (dimRank == 0)
914 return emitError() << "expected non-scalar sparse tensor";
915 if (getDimRank() != dimRank)
916 return emitError()
917 << "dimension-rank mismatch between encoding and tensor shape: "
918 << getDimRank() << " != " << dimRank;
919 if (auto expVal = getExplicitVal()) {
920 Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType();
921 if (attrType != elementType) {
922 return emitError() << "explicit value type mismatch between encoding and "
923 << "tensor element type: " << attrType
924 << " != " << elementType;
925 }
926 }
927 if (auto impVal = getImplicitVal()) {
928 Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType();
929 if (attrType != elementType) {
930 return emitError() << "implicit value type mismatch between encoding and "
931 << "tensor element type: " << attrType
932 << " != " << elementType;
933 }
934 // Currently, we only support zero as the implicit value.
935 auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
936 auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
937 auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal);
938 if ((impFVal && impFVal.getValue().isNonZero()) ||
939 (impIntVal && !impIntVal.getValue().isZero()) ||
940 (impComplexVal && (impComplexVal.getImag().isNonZero() ||
941 impComplexVal.getReal().isNonZero()))) {
942 return emitError() << "implicit value must be zero";
943 }
944 }
945 return success();
946}
947
948Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const {
949 SmallVector<COOSegment> coo = getCOOSegments();
950 assert(coo.size() == 1 || coo.empty());
951 if (!coo.empty() && coo.front().isAoS()) {
952 return coo.front().lvlRange.first;
953 }
954 return getLvlRank();
955}
956
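// For example (illustrative), for the level types
//   [dense, compressed(nonunique), singleton, singleton(unique)]
// there is a single COO segment covering levels [1, 4); it is reported as SoA
// only if the singleton levels carry the soa property.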
957SmallVector<COOSegment>
958mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const {
959 SmallVector<COOSegment> ret;
960 if (getLvlRank() <= 1)
961 return ret;
962
963 ArrayRef<LevelType> lts = getLvlTypes();
964 Level l = 0;
965 while (l < getLvlRank()) {
966 auto lt = lts[l];
967 if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) {
968 auto cur = lts.begin() + l;
969 auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
970 return !lt.isa<LevelFormat::Singleton>();
971 });
972 unsigned cooLen = std::distance(cur, end);
973 if (cooLen > 1) {
        // To support mixed SoA/AoS COO, we should break the segment when the
        // storage scheme changes; for now we faithfully assume that all
        // consecutive singleton levels have the same storage format, as
        // enforced by the encoding verifier.
978 ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
979 lts[l + 1].isa<LevelPropNonDefault::SoA>()});
980 }
981 l += cooLen;
982 } else {
983 l++;
984 }
985 }
986 return ret;
987}
988
989//===----------------------------------------------------------------------===//
990// SparseTensorType Methods.
991//===----------------------------------------------------------------------===//
992
993bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
994 bool isUnique) const {
995 if (!hasEncoding())
996 return false;
  if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
998 return false;
999 for (Level l = startLvl + 1; l < lvlRank; ++l)
1000 if (!isSingletonLvl(l))
1001 return false;
1002 // If isUnique is true, then make sure that the last level is unique,
1003 // that is, when lvlRank == 1, the only compressed level is unique,
1004 // and when lvlRank > 1, the last singleton is unique.
  return !isUnique || isUniqueLvl(lvlRank - 1);
1006}
1007
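// For example (illustrative), for a 3-D tensor getCOOType(/*ordered=*/true)
// uses the level types
//   [compressed(nonunique), singleton(nonunique), singleton],
// i.e. a classic AoS COO format whose last level is unique.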
1008RankedTensorType
1009mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
1010 SmallVector<LevelType> lvlTypes;
  lvlTypes.reserve(lvlRank);
  // A non-unique compressed level at the beginning (unless this is also the
  // last level, in which case it is unique).
  lvlTypes.push_back(
      *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
  if (lvlRank > 1) {
    // Followed by n-2 non-unique singleton levels.
    std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
                *buildLevelType(LevelFormat::Singleton, ordered, false));
    // Ends with a unique singleton level.
    lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
1022 }
1023 auto enc = SparseTensorEncodingAttr::get(
1024 getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(),
1025 getCrdWidth(), getExplicitVal(), getImplicitVal());
1026 return RankedTensorType::get(getDimShape(), getElementType(), enc);
1027}
1028
1029//===----------------------------------------------------------------------===//
1030// Convenience Methods.
1031//===----------------------------------------------------------------------===//
1032
1033SparseTensorEncodingAttr
1034mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
1035 if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
1036 return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
1037 if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
1038 return mdtp.getEncoding();
1039 return nullptr;
1040}
1041
1042AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
1043 MLIRContext *context) {
1044 auto map = static_cast<AffineMap>(dimToLvl);
1045 AffineMap lvlToDim;
1046 // Return an empty lvlToDim when inference is not successful.
1047 if (!map || map.getNumSymbols() != 0) {
1048 lvlToDim = AffineMap();
1049 } else if (map.isPermutation()) {
1050 lvlToDim = inversePermutation(map);
  } else if (isBlockSparsity(map)) {
    lvlToDim = inverseBlockSparsity(map, context);
1053 }
1054 return lvlToDim;
1055}
1056
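// For example (illustrative), the 2x3 block-sparse dimToLvl map
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// is inverted into the lvlToDim map
//   (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3).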
1057AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
1058 MLIRContext *context) {
1059 SmallVector<AffineExpr> lvlExprs;
1060 auto numLvls = dimToLvl.getNumResults();
  lvlExprs.reserve(numLvls);
1062 // lvlExprComponents stores information of the floordiv and mod operations
1063 // applied to the same dimension, so as to build the lvlToDim map.
1064 std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
1065 for (unsigned i = 0, n = numLvls; i < n; i++) {
    auto result = dimToLvl.getResult(i);
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      if (result.getKind() == AffineExprKind::FloorDiv) {
        // Position of the dimension in dimToLvl.
        auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
               "expected only one floordiv for each dimension");
        SmallVector<AffineExpr, 3> components;
        // Level variable for floordiv.
        components.push_back(getAffineDimExpr(i, context));
        // Multiplier.
        components.push_back(binOp.getRHS());
        // Map key is the position of the dimension.
        lvlExprComponents[pos] = components;
      } else if (result.getKind() == AffineExprKind::Mod) {
        auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
               "expected floordiv before mod");
        // Add level variable for mod to the same vector
        // of the corresponding floordiv.
        lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
      } else {
        assert(false && "expected floordiv or mod");
      }
    } else {
      lvlExprs.push_back(getAffineDimExpr(i, context));
1092 }
1093 }
1094 // Build lvlExprs from lvlExprComponents.
1095 // For example, for il = i floordiv 2 and ii = i mod 2, the components
1096 // would be [il, 2, ii]. It could be used to build the AffineExpr
1097 // i = il * 2 + ii in lvlToDim.
1098 for (auto &components : lvlExprComponents) {
1099 assert(components.second.size() == 3 &&
1100 "expected 3 components to build lvlExprs");
    auto mulOp = getAffineBinaryOpExpr(
        AffineExprKind::Mul, components.second[0], components.second[1]);
    auto addOp =
        getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
    lvlExprs.push_back(addOp);
  }
  return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
1108}
1109
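// For the 2x3 block-sparse map above, getBlockSize returns {2, 3}: each
// "d mod constant" result contributes its constant, each plain dimension
// contributes a 0, and floordiv results contribute nothing.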
1110SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) {
1111 assert(isBlockSparsity(dimToLvl) &&
1112 "expected dimToLvl to be block sparsity for calling getBlockSize");
1113 SmallVector<unsigned> blockSize;
1114 for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      if (result.getKind() == AffineExprKind::Mod) {
        blockSize.push_back(
            dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
      }
    } else {
      blockSize.push_back(0);
1122 }
1123 }
1124 return blockSize;
1125}
1126
1127bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
1128 if (!dimToLvl)
1129 return false;
  std::map<unsigned, int64_t> coefficientMap;
  bool hasBlock = false;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      // Check for "dim op const".
      auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
      auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
      if (!dimOp || !conOp || conOp.getValue() <= 0)
        return false;
      // Inspect "dim / const" or "dim % const".
      auto pos = dimOp.getPosition();
      if (binOp.getKind() == AffineExprKind::FloorDiv) {
        // Expect only one floordiv for each dimension.
        auto [it, inserted] = coefficientMap.try_emplace(pos);
        if (!inserted)
          return false;
        // Record coefficient of the floordiv.
        it->second = conOp.getValue();
      } else if (binOp.getKind() == AffineExprKind::Mod) {
        // Expect floordiv before mod.
        auto it = coefficientMap.find(pos);
        if (it == coefficientMap.end())
          return false;
        // Expect mod to have the same coefficient as floordiv.
        if (conOp.getValue() != it->second)
          return false;
        hasBlock = true;
      } else {
        return false;
      }
    } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) {
      auto pos = dimOp.getPosition();
      // Expect dim to be unset.
      if (!coefficientMap.try_emplace(pos, 0).second)
1164 return false;
1165 } else {
1166 return false;
1167 }
1168 }
1169 return hasBlock;
1170}
1171
1172bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
1173 auto hasNonIdentityMap = [](Value v) {
1174 auto stt = tryGetSparseTensorType(v);
1175 return stt && !stt->isIdentity();
1176 };
1177
  return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
         llvm::any_of(op->getResults(), hasNonIdentityMap);
1180}
1181
1182Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
1183 if (enc) {
1184 assert(enc.isPermutation() && "Non permutation map not supported");
1185 if (const auto dimToLvl = enc.getDimToLvl())
1186 return dimToLvl.getDimPosition(l);
1187 }
1188 return l;
1189}
1190
1191Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
1192 if (enc) {
1193 assert(enc.isPermutation() && "Non permutation map not supported");
1194 if (const auto lvlToDim = enc.getLvlToDim())
1195 return lvlToDim.getDimPosition(d);
1196 }
1197 return d;
1198}
1199
/// We normalize the sparse tensor encoding attribute by always using
/// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well
/// as other variants) lead to the same storage specifier type, and by
/// stripping irrelevant fields that do not alter the sparse tensor memory
/// layout.
1204static SparseTensorEncodingAttr
1205getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
1206 SmallVector<LevelType> lts;
1207 for (auto lt : enc.getLvlTypes())
1208 lts.push_back(lt.stripStorageIrrelevantProperties());
1209
1210 return SparseTensorEncodingAttr::get(
1211 enc.getContext(), lts,
1212 AffineMap(), // dimToLvl (irrelevant to storage specifier)
1213 AffineMap(), // lvlToDim (irrelevant to storage specifier)
      // Always use `index` for memSize and lvlSize instead of reusing
      // `getPosWidth` and `getCrdWidth`. This allows us to reuse the same SSA
      // value for different bitwidths, and it avoids casting between index and
      // integer (returned by DimOp).
1218 0, 0,
1219 Attribute(), // explicitVal (irrelevant to storage specifier)
1220 Attribute(), // implicitVal (irrelevant to storage specifier)
1221 enc.getDimSlices());
1222}
1223
1224StorageSpecifierType
1225StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
1226 return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
1227}
1228
1229StorageSpecifierType
1230StorageSpecifierType::getChecked(function_ref<InFlightDiagnostic()> emitError,
1231 MLIRContext *ctx,
1232 SparseTensorEncodingAttr encoding) {
1233 return Base::getChecked(emitError, ctx,
1234 getNormalizedEncodingForSpecifier(encoding));
1235}
1236
1237//===----------------------------------------------------------------------===//
1238// SparseTensorDialect Operations.
1239//===----------------------------------------------------------------------===//
1240
1241static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
  return success(lvl < getSparseTensorType(tensor).getLvlRank());
1243}
1244
1245static LogicalResult isMatchingWidth(Value mem, unsigned width) {
1246 const Type etp = getMemRefType(mem).getElementType();
  return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
1248}
1249
1250static LogicalResult verifySparsifierGetterSetter(
1251 StorageSpecifierKind mdKind, std::optional<Level> lvl,
1252 TypedValue<StorageSpecifierType> md, Operation *op) {
1253 if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
    return op->emitError(
        "redundant level argument for querying value memory size");
1256 }
1257
1258 const auto enc = md.getType().getEncoding();
1259 const Level lvlRank = enc.getLvlRank();
1260
1261 if (mdKind == StorageSpecifierKind::DimOffset ||
1262 mdKind == StorageSpecifierKind::DimStride)
1263 if (!enc.isSlice())
      return op->emitError("requested slice data on non-slice tensor");
1265
1266 if (mdKind != StorageSpecifierKind::ValMemSize) {
1267 if (!lvl)
      return op->emitError("missing level argument");
1269
1270 const Level l = lvl.value();
1271 if (l >= lvlRank)
      return op->emitError("requested level is out of bounds");
1273
1274 if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
      return op->emitError(
          "requested position memory size on a singleton level");
1277 }
1278 return success();
1279}
1280
1281static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
1282 switch (kind) {
1283 case SparseTensorFieldKind::CrdMemRef:
1284 return stt.getCrdType();
1285 case SparseTensorFieldKind::PosMemRef:
1286 return stt.getPosType();
1287 case SparseTensorFieldKind::ValMemRef:
1288 return stt.getElementType();
1289 case SparseTensorFieldKind::StorageSpec:
1290 return nullptr;
1291 }
1292 llvm_unreachable("Unrecognizable FieldKind");
1293}
1294
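// For example (illustrative), assembling a 2-D AoS COO tensor takes a values
// tensor plus two level tensors: a positions tensor and a single fused
// coordinates tensor of shape tensor<?x2xindex> (rank 2, with a trailing size
// equal to the number of COO levels).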
1295static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
1296 SparseTensorType stt,
1297 RankedTensorType valTp,
1298 TypeRange lvlTps) {
1299 if (requiresStaticShape && !stt.hasStaticDimShape())
    return op->emitError("the sparse-tensor must have static shape");
1301 if (!stt.hasEncoding())
    return op->emitError("the sparse-tensor must have an encoding attribute");
1303
1304 // Verifies the trailing COO.
1305 Level cooStartLvl = stt.getAoSCOOStart();
1306 if (cooStartLvl < stt.getLvlRank()) {
    // We only support a trailing COO region for now; it must be the last input.
    auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
    // The coordinates should have the shape <? x rank>.
    unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
    if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
      return op->emitError("input/output trailing COO level-ranks don't match");
1313 }
1314 }
1315
1316 // Verifies that all types match.
1317 StorageLayout layout(stt.getEncoding());
1318 if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
    return op->emitError("inconsistent number of fields between input/output");
1320
1321 unsigned idx = 0;
1322 bool misMatch = false;
  layout.foreachField([&idx, &misMatch, stt, valTp,
                       lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
                               Level lvl, LevelType lt) -> bool {
1326 if (fKind == SparseTensorFieldKind::StorageSpec)
1327 return true;
1328
1329 Type inputTp = nullptr;
1330 if (fKind == SparseTensorFieldKind::ValMemRef) {
1331 inputTp = valTp;
1332 } else {
1333 assert(fid == idx && stt.getLvlType(lvl) == lt);
1334 inputTp = lvlTps[idx++];
1335 }
1336 // The input element type and expected element type should match.
    Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
    Type expElemTp = getFieldElemType(stt, fKind);
1339 if (inpElemTp != expElemTp) {
1340 misMatch = true;
1341 return false; // to terminate the iteration
1342 }
1343 return true;
1344 });
1345
1346 if (misMatch)
    return op->emitError("input/output element-types don't match");
1348 return success();
1349}
1350
1351LogicalResult AssembleOp::verify() {
1352 RankedTensorType valuesTp = getValues().getType();
1353 const auto lvlsTp = getLevels().getTypes();
1354 const auto resTp = getSparseTensorType(getResult());
1355 return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
1356}
1357
1358LogicalResult DisassembleOp::verify() {
1359 if (getOutValues().getType() != getRetValues().getType())
1360 return emitError("output values and return value type mismatch");
1361
1362 for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
1363 if (ot.getType() != rt.getType())
1364 return emitError("output levels and return levels type mismatch");
1365
1366 RankedTensorType valuesTp = getRetValues().getType();
1367 const auto lvlsTp = getRetLevels().getTypes();
1368 const auto srcTp = getSparseTensorType(getTensor());
1369 return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
1370}
1371
1372LogicalResult ConvertOp::verify() {
1373 RankedTensorType tp1 = getSource().getType();
1374 RankedTensorType tp2 = getDest().getType();
1375 if (tp1.getRank() != tp2.getRank())
1376 return emitError("unexpected conversion mismatch in rank");
1377 auto dstEnc =
1378 llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
1379 if (dstEnc && dstEnc.isSlice())
1380 return emitError("cannot convert to a sparse tensor slice");
1381
1382 auto shape1 = tp1.getShape();
1383 auto shape2 = tp2.getShape();
1384 // Accept size matches between the source and the destination type
1385 // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
1386 // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
1387 for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
1388 if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
1389 return emitError("unexpected conversion mismatch in dimension ") << d;
1390 return success();
1391}
1392
1393OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
1394 if (getType() == getSource().getType())
1395 return getSource();
1396 return {};
1397}
1398
1399bool ConvertOp::needsExtraSort() {
1400 SparseTensorType srcStt = getSparseTensorType(getSource());
1401 SparseTensorType dstStt = getSparseTensorType(getDest());
1402
  // We do not need an extra sort when returning unordered sparse tensors or
  // dense tensors, since dense tensors support random access.
1405 if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1406 return false;
1407
1408 if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
1409 srcStt.hasSameDimToLvl(dstStt)) {
1410 return false;
1411 }
1412
1413 // Source and dest tensors are ordered in different ways. We only do direct
1414 // dense to sparse conversion when the dense input is defined by a sparse
  // constant. Note that we could theoretically always convert directly from
  // dense inputs by rotating dense loops, but that leads to bad cache locality
  // and hurts performance.
1418 if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
1419 if (isa<SparseElementsAttr>(constOp.getValue()))
1420 return false;
1421
1422 return true;
1423}
1424
1425LogicalResult CrdTranslateOp::verify() {
1426 uint64_t inRank = getEncoder().getLvlRank();
1427 uint64_t outRank = getEncoder().getDimRank();
1428
1429 if (getDirection() == CrdTransDirectionKind::dim2lvl)
1430 std::swap(inRank, outRank);
1431
1432 if (inRank != getInCrds().size() || outRank != getOutCrds().size())
1433 return emitError("Coordinate rank mismatch with encoding");
1434
1435 return success();
1436}
1437
1438LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
1439 SmallVectorImpl<OpFoldResult> &results) {
1440 if (getEncoder().isIdentity()) {
1441 results.assign(getInCrds().begin(), getInCrds().end());
1442 return success();
1443 }
1444 if (getEncoder().isPermutation()) {
1445 AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
1446 ? getEncoder().getDimToLvl()
1447 : getEncoder().getLvlToDim();
1448 for (AffineExpr exp : perm.getResults())
1449 results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
1450 return success();
1451 }
1452
1453 // Fuse dim2lvl/lvl2dim pairs.
1454 auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
1455 bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
1456 return v.getDefiningOp() == def;
1457 });
1458 if (!sameDef)
1459 return failure();
1460
1461 bool oppositeDir = def.getDirection() != getDirection();
1462 bool sameOracle =
1463 def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
1464 bool sameCount = def.getNumResults() == getInCrds().size();
1465 if (!oppositeDir || !sameOracle || !sameCount)
1466 return failure();
1467
1468 // The definition produces the coordinates in the same order as the input
1469 // coordinates.
1470 bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
1471 [](auto valuePair) {
1472 auto [lhs, rhs] = valuePair;
1473 return lhs == rhs;
1474 });
1475
1476 if (!sameOrder)
1477 return failure();
1478 // l1 = dim2lvl (lvl2dim l0)
1479 // ==> l0
1480 results.append(def.getInCrds().begin(), def.getInCrds().end());
1481 return success();
1482}
1483
1484void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
1485 int64_t index) {
1486 Value val = builder.create<arith::ConstantIndexOp>(state.location, index);
1487 return build(builder, state, source, val);
1488}
1489
1490LogicalResult LvlOp::verify() {
1491 if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
1492 auto stt = getSparseTensorType(getSource());
1493 if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
1494 return emitError(
1495 "Level index exceeds the rank of the input sparse tensor");
1496 }
1497 return success();
1498}
1499
1500std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
1501 return getConstantIntValue(getIndex());
1502}
1503
1504Speculation::Speculatability LvlOp::getSpeculatability() {
1505 auto constantIndex = getConstantLvlIndex();
1506 if (!constantIndex)
1507 return Speculation::NotSpeculatable;
1508
1509 assert(constantIndex <
1510 cast<RankedTensorType>(getSource().getType()).getRank());
1511 return Speculation::Speculatable;
1512}
1513
1514OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
1515 auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1516 if (!lvlIndex)
1517 return {};
1518
1519 Level lvl = lvlIndex.getAPSInt().getZExtValue();
1520 auto stt = getSparseTensorType(getSource());
1521 if (lvl >= stt.getLvlRank()) {
    // Follows the same convention used by the tensor.dim operation: out-of-bounds
    // indices produce undefined behavior but are still valid IR, so don't choke
    // on them.
1525 return {};
1526 }
1527
1528 // Helper lambda to build an IndexAttr.
1529 auto getIndexAttr = [this](int64_t lvlSz) {
1530 return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
1531 };
1532
1533 SmallVector<Size> lvlShape = stt.getLvlShape();
1534 if (!ShapedType::isDynamic(lvlShape[lvl]))
1535 return getIndexAttr(lvlShape[lvl]);
1536
1537 return {};
1538}
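
// For example (schematic), assuming a statically shaped 10x20 CSR tensor %t:
//
//   %c1 = arith.constant 1 : index
//   %sz = sparse_tensor.lvl %t, %c1 : tensor<10x20xf64, #CSR>
//
// folds to the constant 20, since level 1 of a CSR tensor corresponds to the
// static column dimension; a dynamic level size is simply left unfolded.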
1539
1540void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1541 SparseTensorEncodingAttr dstEnc, Value source) {
1542 auto srcStt = getSparseTensorType(source);
1543 SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
1544 SmallVector<int64_t> dstDimShape =
1545 dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
1546 auto dstTp =
1547 RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
1548 return build(odsBuilder, odsState, dstTp, source);
1549}
1550
1551LogicalResult ReinterpretMapOp::verify() {
1552 auto srcStt = getSparseTensorType(getSource());
1553 auto dstStt = getSparseTensorType(getDest());
1554 ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes();
1555 ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes();
1556
1557 if (srcLvlTps.size() != dstLvlTps.size())
1558 return emitError("Level rank mismatch between source/dest tensors");
1559
1560 for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
1561 if (srcLvlTp != dstLvlTp)
1562 return emitError("Level type mismatch between source/dest tensors");
1563
1564 if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
1565 srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
1566 return emitError("Crd/Pos width mismatch between source/dest tensors");
1567 }
1568
1569 if (srcStt.getElementType() != dstStt.getElementType())
1570 return emitError("Element type mismatch between source/dest tensors");
1571
1572 SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
1573 SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
1574 for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
1575 if (srcLvlSz != dstLvlSz) {
      // Should we allow one side to have a dynamic size, e.g., should <?x?> be
      // compatible with <3x4>? For now, we require all level sizes to match
      // *exactly* for simplicity.
1579 return emitError("Level size mismatch between source/dest tensors");
1580 }
1581 }
1582
1583 return success();
1584}
1585
1586OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1587 if (getSource().getType() == getDest().getType())
1588 return getSource();
1589
1590 if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1591 // A -> B, B -> A ==> A
1592 if (def.getSource().getType() == getDest().getType())
1593 return def.getSource();
1594 }
1595 return {};
1596}
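
// For example (schematic), the round trip
//
//   %b = sparse_tensor.reinterpret_map %a : tensor<..., #A> to tensor<..., #B>
//   %c = sparse_tensor.reinterpret_map %b : tensor<..., #B> to tensor<..., #A>
//
// folds %c directly to %a, and a remapping onto the identical type folds to
// its own source.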
1597
1598template <typename ToBufferOp>
1599static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
1600 OpaqueProperties prop,
1601 RegionRange region,
1602 SmallVectorImpl<mlir::Type> &ret) {
1603 typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
1604 SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
1605 Type elemTp = nullptr;
1606 bool withStride = false;
1607 if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
1608 elemTp = stt.getPosType();
1609 } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
1610 std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
1611 elemTp = stt.getCrdType();
1612 if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
1613 withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
1614 } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
1615 elemTp = stt.getElementType();
1616 }
1617
1618 assert(elemTp && "unhandled operation.");
1619 SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
1620 bufShape.push_back(ShapedType::kDynamic);
1621
  auto layout = withStride ? StridedLayoutAttr::get(stt.getContext(),
                                                    ShapedType::kDynamic,
                                                    {ShapedType::kDynamic})
                           : StridedLayoutAttr();
1626 ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
1627 return success();
1628}
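
// As an illustration of the inferred buffer types (schematic; encodings and
// element types are assumptions for the example): for a 2-D AoS COO tensor,
//
//   sparse_tensor.positions   ... : memref<?xindex>
//   sparse_tensor.coordinates ... : memref<?xindex, strided<[?], offset: ?>>
//   sparse_tensor.values      ... : memref<?xf64>
//
// where the strided layout reflects the interleaved (AoS) coordinate storage
// starting at the COO level.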
1629
1630LogicalResult ToPositionsOp::verify() {
1631 auto stt = getSparseTensorType(getTensor());
1632 if (failed(lvlIsInBounds(getLevel(), getTensor())))
1633 return emitError("requested level is out of bounds");
1634 if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
1635 return emitError("unexpected type for positions");
1636 return success();
1637}
1638
1639LogicalResult
1640ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1641 ValueRange ops, DictionaryAttr attr,
1642 OpaqueProperties prop, RegionRange region,
1643 SmallVectorImpl<mlir::Type> &ret) {
1644 return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
1645}
1646
1647LogicalResult ToCoordinatesOp::verify() {
1648 auto stt = getSparseTensorType(getTensor());
1649 if (failed(lvlIsInBounds(getLevel(), getTensor())))
1650 return emitError("requested level is out of bounds");
1651 if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
1652 return emitError("unexpected type for coordinates");
1653 return success();
1654}
1655
1656LogicalResult
1657ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1658 ValueRange ops, DictionaryAttr attr,
1659 OpaqueProperties prop, RegionRange region,
1660 SmallVectorImpl<mlir::Type> &ret) {
1661 return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
1662}
1663
1664LogicalResult ToCoordinatesBufferOp::verify() {
1665 auto stt = getSparseTensorType(getTensor());
1666 if (stt.getAoSCOOStart() >= stt.getLvlRank())
1667 return emitError("expected sparse tensor with a COO region");
1668 return success();
1669}
1670
1671LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
1672 MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
1673 DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
1674 SmallVectorImpl<mlir::Type> &ret) {
1675 return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
1676 ret);
1677}
1678
1679LogicalResult ToValuesOp::verify() {
1680 auto stt = getSparseTensorType(getTensor());
1681 auto mtp = getMemRefType(getResult());
1682 if (stt.getElementType() != mtp.getElementType())
1683 return emitError("unexpected mismatch in element types");
1684 return success();
1685}
1686
1687LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
1688 std::optional<Location> loc,
1689 ValueRange ops, DictionaryAttr attr,
1690 OpaqueProperties prop,
1691 RegionRange region,
1692 SmallVectorImpl<mlir::Type> &ret) {
1693 return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
1694}
1695
1696LogicalResult ToSliceOffsetOp::verify() {
1697 auto rank = getSlice().getType().getRank();
1698 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1699 return emitError("requested dimension out of bound");
1700 return success();
1701}
1702
1703LogicalResult ToSliceStrideOp::verify() {
1704 auto rank = getSlice().getType().getRank();
1705 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1706 return emitError("requested dimension out of bound");
1707 return success();
1708}
1709
1710LogicalResult GetStorageSpecifierOp::verify() {
1711 return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1712 getSpecifier(), getOperation());
1713}
1714
1715template <typename SpecifierOp>
1716static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
1717 return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
1718}
1719
1720OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1721 const StorageSpecifierKind kind = getSpecifierKind();
1722 const auto lvl = getLevel();
1723 for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
1724 if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1725 return op.getValue();
1726 return {};
1727}
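
// For example (schematic syntax), reading a field that was just written folds
// straight to the written value:
//
//   %s1 = sparse_tensor.storage_specifier.set %s0 lvl_sz at 0 with %v
//   %x  = sparse_tensor.storage_specifier.get %s1 lvl_sz at 0
//
// folds %x to %v, walking back through the chain of set ops until a matching
// kind/level pair is found.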
1728
1729LogicalResult SetStorageSpecifierOp::verify() {
1730 return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1731 getSpecifier(), getOperation());
1732}
1733
1734template <class T>
1735static LogicalResult verifyNumBlockArgs(T *op, Region &region,
1736 const char *regionName,
1737 TypeRange inputTypes, Type outputType) {
1738 unsigned numArgs = region.getNumArguments();
1739 unsigned expectedNum = inputTypes.size();
1740 if (numArgs != expectedNum)
1741 return op->emitError() << regionName << " region must have exactly "
1742 << expectedNum << " arguments";
1743
1744 for (unsigned i = 0; i < numArgs; i++) {
1745 Type typ = region.getArgument(i).getType();
1746 if (typ != inputTypes[i])
1747 return op->emitError() << regionName << " region argument " << (i + 1)
1748 << " type mismatch";
1749 }
1750 Operation *term = region.front().getTerminator();
1751 YieldOp yield = dyn_cast<YieldOp>(term);
1752 if (!yield)
1753 return op->emitError() << regionName
1754 << " region must end with sparse_tensor.yield";
1755 if (!yield.hasSingleResult() ||
1756 yield.getSingleResult().getType() != outputType)
1757 return op->emitError() << regionName << " region yield type mismatch";
1758
1759 return success();
1760}
1761
1762LogicalResult BinaryOp::verify() {
1763 NamedAttrList attrs = (*this)->getAttrs();
1764 Type leftType = getX().getType();
1765 Type rightType = getY().getType();
1766 Type outputType = getOutput().getType();
1767 Region &overlap = getOverlapRegion();
1768 Region &left = getLeftRegion();
1769 Region &right = getRightRegion();
1770
1771 // Check correct number of block arguments and return type for each
1772 // non-empty region.
1773 if (!overlap.empty()) {
1774 if (failed(verifyNumBlockArgs(this, overlap, "overlap",
1775 TypeRange{leftType, rightType}, outputType)))
1776 return failure();
1777 }
1778 if (!left.empty()) {
1779 if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
1780 outputType)))
1781 return failure();
1782 } else if (getLeftIdentity()) {
1783 if (leftType != outputType)
1784 return emitError("left=identity requires first argument to have the same "
1785 "type as the output");
1786 }
1787 if (!right.empty()) {
1788 if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
1789 outputType)))
1790 return failure();
1791 } else if (getRightIdentity()) {
1792 if (rightType != outputType)
1793 return emitError("right=identity requires second argument to have the "
1794 "same type as the output");
1795 }
1796 return success();
1797}
1798
1799LogicalResult UnaryOp::verify() {
1800 Type inputType = getX().getType();
1801 Type outputType = getOutput().getType();
1802
1803 // Check correct number of block arguments and return type for each
1804 // non-empty region.
1805 Region &present = getPresentRegion();
1806 if (!present.empty()) {
1807 if (failed(verifyNumBlockArgs(this, present, "present",
1808 TypeRange{inputType}, outputType)))
1809 return failure();
1810 }
1811 Region &absent = getAbsentRegion();
1812 if (!absent.empty()) {
1813 if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
1814 outputType)))
1815 return failure();
1816 // Absent branch can only yield invariant values.
1817 Block *absentBlock = &absent.front();
1818 Block *parent = getOperation()->getBlock();
1819 Value absentVal =
1820 cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
1821 if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
1822 if (arg.getOwner() == parent)
1823 return emitError("absent region cannot yield linalg argument");
1824 } else if (Operation *def = absentVal.getDefiningOp()) {
1825 if (!isa<arith::ConstantOp>(def) &&
1826 (def->getBlock() == absentBlock || def->getBlock() == parent))
1827 return emitError("absent region cannot yield locally computed value");
1828 }
1829 }
1830 return success();
1831}
1832
1833bool ConcatenateOp::needsExtraSort() {
1834 SparseTensorType dstStt = getSparseTensorType(*this);
1835 if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1836 return false;
1837
1838 bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1839 return getSparseTensorType(op).hasSameDimToLvl(dstStt);
1840 });
  // TODO: When conDim != 0, as long as conDim corresponds to the first level
  // in all input/output buffers, and all input/output buffers share the same
  // dimToLvl, the temporary COO buffer is still unnecessary (e.g., concatenating
  // CSC matrices along the column).
1845 bool directLowerable =
1846 allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
1847 return !directLowerable;
1848}
1849
1850LogicalResult ConcatenateOp::verify() {
1851 const auto dstTp = getSparseTensorType(*this);
1852 const Dimension concatDim = getDimension();
1853 const Dimension dimRank = dstTp.getDimRank();
1854
1855 if (getInputs().size() <= 1)
1856 return emitError("Need at least two tensors to concatenate.");
1857
1858 if (concatDim >= dimRank)
1859 return emitError(llvm::formatv(
1860 "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1861 concatDim, dimRank));
1862
1863 for (const auto &it : llvm::enumerate(getInputs())) {
1864 const auto i = it.index();
1865 const auto srcTp = getSparseTensorType(it.value());
1866 if (srcTp.hasDynamicDimShape())
1867 return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
1868 const Dimension srcDimRank = srcTp.getDimRank();
1869 if (srcDimRank != dimRank)
1870 return emitError(
1871 llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
1872 "from the output tensor (rank={2}).",
1873 i, srcDimRank, dimRank));
1874 }
1875
1876 for (Dimension d = 0; d < dimRank; d++) {
1877 const Size dstSh = dstTp.getDimShape()[d];
1878 if (d == concatDim) {
1879 if (!ShapedType::isDynamic(dstSh)) {
1880 // If we reach here, then all inputs have static shapes. So we
1881 // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
1882 // to avoid redundant assertions in the loop.
1883 Size sumSz = 0;
1884 for (const auto src : getInputs())
1885 sumSz += getSparseTensorType(src).getDimShape()[d];
        // If all dimensions are statically known, the sum of all the input
        // dimensions should be equal to the output dimension.
1888 if (sumSz != dstSh)
1889 return emitError(
1890 "The concatenation dimension of the output tensor should be the "
1891 "sum of all the concatenation dimensions of the input tensors.");
1892 }
1893 } else {
1894 Size prev = dstSh;
1895 for (const auto src : getInputs()) {
1896 const auto sh = getSparseTensorType(src).getDimShape()[d];
1897 if (!ShapedType::isDynamic(prev) && sh != prev)
1898 return emitError("All dimensions (expect for the concatenating one) "
1899 "should be equal.");
1900 prev = sh;
1901 }
1902 }
1903 }
1904
1905 return success();
1906}
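
// A small worked example of the shape rules checked above: concatenating
// tensor<3x4xf64, #CSR> and tensor<5x4xf64, #CSR> along dimension 0 requires
// an output of size 8 (= 3 + 5) in dimension 0, while dimension 1 must be 4
// in all operands and in the output.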
1907
1908void PushBackOp::build(OpBuilder &builder, OperationState &result,
1909 Value curSize, Value inBuffer, Value value) {
1910 build(builder, result, curSize, inBuffer, value, Value());
1911}
1912
1913LogicalResult PushBackOp::verify() {
1914 if (Value n = getN()) {
1915 std::optional<int64_t> nValue = getConstantIntValue(n);
1916 if (nValue && nValue.value() < 1)
1917 return emitOpError("n must be not less than 1");
1918 }
1919 return success();
1920}
1921
1922LogicalResult CompressOp::verify() {
1923 const auto stt = getSparseTensorType(getTensor());
1924 if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
1925 return emitOpError("incorrect number of coordinates");
1926 return success();
1927}
1928
1929void ForeachOp::build(
1930 OpBuilder &builder, OperationState &result, Value tensor,
1931 ValueRange initArgs, AffineMapAttr order,
1932 function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
1933 bodyBuilder) {
1934 build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
1935 // Builds foreach body.
1936 if (!bodyBuilder)
1937 return;
1938 const auto stt = getSparseTensorType(tensor);
1939 const Dimension dimRank = stt.getDimRank();
1940
1941 // Starts with `dimRank`-many coordinates.
1942 SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
1943 // Followed by one value.
1944 blockArgTypes.push_back(stt.getElementType());
1945 // Followed by the reduction variables.
1946 blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());
1947
1948 SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());
1949
1950 OpBuilder::InsertionGuard guard(builder);
1951 auto &region = *result.regions.front();
1952 Block *bodyBlock =
1953 builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
1954 bodyBuilder(builder, result.location,
1955 bodyBlock->getArguments().slice(0, dimRank),
1956 bodyBlock->getArguments()[dimRank],
1957 bodyBlock->getArguments().drop_front(dimRank + 1));
1958}
1959
1960LogicalResult ForeachOp::verify() {
1961 const auto t = getSparseTensorType(getTensor());
1962 const Dimension dimRank = t.getDimRank();
1963 const auto args = getBody()->getArguments();
1964
1965 if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
1966 return emitError("Level traverse order does not match tensor's level rank");
1967
1968 if (dimRank + 1 + getInitArgs().size() != args.size())
1969 return emitError("Unmatched number of arguments in the block");
1970
1971 if (getNumResults() != getInitArgs().size())
1972 return emitError("Mismatch in number of init arguments and results");
1973
1974 if (getResultTypes() != getInitArgs().getTypes())
1975 return emitError("Mismatch in types of init arguments and results");
1976
1977 // Cannot mark this const, because the getters aren't.
1978 auto yield = cast<YieldOp>(getBody()->getTerminator());
1979 if (yield.getNumOperands() != getNumResults() ||
1980 yield.getOperands().getTypes() != getResultTypes())
1981 return emitError("Mismatch in types of yield values and results");
1982
1983 const auto iTp = IndexType::get(getContext());
1984 for (Dimension d = 0; d < dimRank; d++)
1985 if (args[d].getType() != iTp)
1986 return emitError(
1987 llvm::formatv("Expecting Index type for argument at index {0}", d));
1988
1989 const auto elemTp = t.getElementType();
1990 const auto valueTp = args[dimRank].getType();
1991 if (elemTp != valueTp)
1992 return emitError(
1993 llvm::formatv("Unmatched element type between input tensor and "
1994 "block argument, expected:{0}, got: {1}",
1995 elemTp, valueTp));
1996 return success();
1997}
1998
1999OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
2000 if (getSparseTensorEncoding(getInputCoo().getType()) ==
2001 getSparseTensorEncoding(getResultCoo().getType()))
2002 return getInputCoo();
2003
2004 return {};
2005}
2006
2007LogicalResult ReorderCOOOp::verify() {
2008 SparseTensorType srcStt = getSparseTensorType(getInputCoo());
2009 SparseTensorType dstStt = getSparseTensorType(getResultCoo());
2010
2011 if (!srcStt.isCOOType() || !dstStt.isCOOType())
2012 return emitError("Expected COO sparse tensors only");
2013
2014 if (!srcStt.hasSameDimToLvl(dstStt))
2015 return emitError("Unmatched dim2lvl map between input and result COO");
2016
2017 if (srcStt.getPosType() != dstStt.getPosType() ||
2018 srcStt.getCrdType() != dstStt.getCrdType() ||
2019 srcStt.getElementType() != dstStt.getElementType())
2020 return emitError("Unmatched storage format between input and result COO");
2021
2022 return success();
2023}
2024
2025LogicalResult ReduceOp::verify() {
2026 Type inputType = getX().getType();
2027 Region &formula = getRegion();
2028 return verifyNumBlockArgs(this, formula, "reduce",
2029 TypeRange{inputType, inputType}, inputType);
2030}
2031
2032LogicalResult SelectOp::verify() {
2033 Builder b(getContext());
2034 Type inputType = getX().getType();
2035 Type boolType = b.getI1Type();
2036 Region &formula = getRegion();
2037 return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
2038 boolType);
2039}
2040
2041LogicalResult SortOp::verify() {
2042 AffineMap xPerm = getPermMap();
2043 uint64_t nx = xPerm.getNumDims();
2044 if (nx < 1)
2045 return emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
2046
2047 if (!xPerm.isPermutation())
2048 return emitError(
2049 llvm::formatv("Expected a permutation map, got {0}", xPerm));
2050
2051 // We can't check the size of the buffers when n or buffer dimensions aren't
2052 // compile-time constants.
2053 std::optional<int64_t> cn = getConstantIntValue(getN());
2054 if (!cn)
2055 return success();
2056
2057 // Verify dimensions.
2058 const auto checkDim = [&](Value v, Size minSize,
2059 const char *message) -> LogicalResult {
2060 const Size sh = getMemRefType(v).getShape()[0];
2061 if (!ShapedType::isDynamic(sh) && sh < minSize)
2062 return emitError(
2063 llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
2064 return success();
2065 };
2066 uint64_t n = cn.value();
2067 uint64_t ny = 0;
2068 if (auto nyAttr = getNyAttr())
2069 ny = nyAttr.getInt();
2070 if (failed(checkDim(getXy(), n * (nx + ny),
2071 "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
2072 return failure();
2073 for (Value opnd : getYs())
2074 if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
2075 return failure();
2076
2077 return success();
2078}
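
// A concrete instance of the buffer-size checks above: with n = 100, a 2-D
// permutation map (nx = 2), and ny = 1, the xy buffer must provide at least
// 100 * (2 + 1) = 300 elements, and every y buffer at least 100 elements.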
2079
2080//===----------------------------------------------------------------------===//
2081// Sparse Tensor Iteration Operations.
2082//===----------------------------------------------------------------------===//
2083
2084IterSpaceType IteratorType::getIterSpaceType() const {
2085 return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
2086 getHiLvl());
2087}
2088
2089IteratorType IterSpaceType::getIteratorType() const {
2090 return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
2091}
2092
2093/// Parses a level range in the form "$lo `to` $hi"
2094/// or simply "$lo" if $hi - $lo = 1
2095static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
2096 Level &lvlHi) {
  if (parser.parseInteger(lvlLo))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("to"))) {
    if (parser.parseInteger(lvlHi))
      return failure();
  } else {
    lvlHi = lvlLo + 1;
  }

  if (lvlHi <= lvlLo)
    return parser.emitError(parser.getNameLoc(),
                            "expect larger level upper bound than lower bound");
2110
2111 return success();
2112}
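
// Examples of accepted input for the range above: "2 to 4" parses as lo = 2,
// hi = 4, while a single "3" is shorthand for the range [3, 4). Input such as
// "4 to 2" is rejected because the upper bound must be strictly larger.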
2113
2114/// Parses a level range in the form "$lo `to` $hi"
2115/// or simply "$lo" if $hi - $lo = 1
2116static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
2117 IntegerAttr &lvlHiAttr) {
2118 Level lvlLo, lvlHi;
2119 if (parseLevelRange(parser, lvlLo, lvlHi))
2120 return failure();
2121
2122 lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
2123 lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
2124 return success();
2125}
2126
2127/// Prints a level range in the form "$lo `to` $hi"
2128/// or simply "$lo" if $hi - $lo = 1
2129static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
2130
2131 if (lo + 1 == hi)
2132 p << lo;
2133 else
2134 p << lo << " to " << hi;
2135}
2136
2137/// Prints a level range in the form "$lo `to` $hi"
2138/// or simply "$lo" if $hi - $lo = 1
2139static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
2140 IntegerAttr lvlHi) {
2141 unsigned lo = lvlLo.getValue().getZExtValue();
2142 unsigned hi = lvlHi.getValue().getZExtValue();
2143 printLevelRange(p, lo, hi);
2144}
2145
/// Parses an optionally defined list in the form of
/// "(%val0, _, %val1, ...)", where `_` is used to annotate that the
/// corresponding value is not defined (e.g., to represent an undefined
/// coordinate in the sparse iteration space).
2150static ParseResult parseOptionalDefinedList(
2151 OpAsmParser &parser, OperationState &state, I64BitSet &definedSet,
2152 SmallVectorImpl<OpAsmParser::Argument> &definedArgs,
2153 unsigned maxCnt = std::numeric_limits<unsigned>::max(),
2154 OpAsmParser::Delimiter delimiter = OpAsmParser::Delimiter::Paren) {
2155 unsigned cnt = 0;
  ParseResult crdList =
      parser.parseCommaSeparatedList(delimiter, [&]() -> ParseResult {
        if (parser.parseOptionalKeyword("_")) {
          if (parser.parseArgument(definedArgs.emplace_back()))
            return failure();
          definedSet.set(cnt);
        }
        cnt += 1;
        return success();
      });

  if (cnt > maxCnt)
    return parser.emitError(parser.getNameLoc(),
                            "parsed more values than expected.");

  if (failed(crdList)) {
    return parser.emitError(
        parser.getNameLoc(),
        "expecting SSA value or \"_\" for level coordinates");
  }
2176 assert(definedArgs.size() == definedSet.count());
2177 return success();
2178}
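
// For instance, parsing "(%a, _, %b)" yields the two arguments %a and %b and
// the defined-set bitmask 0b101: bits 0 and 2 are set, while the "_" at
// position 1 marks an undefined (unused) entry.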
2179
2180static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size,
2181 Block::BlockArgListType blocksArgs,
2182 I64BitSet definedSet) {
2183 if (definedSet.empty())
2184 return;
2185
2186 for (unsigned i = 0; i < size; i++) {
2187 if (definedSet[i]) {
2188 p << blocksArgs.front();
2189 blocksArgs = blocksArgs.drop_front();
2190 } else {
2191 p << "_";
2192 }
2193 if (i != size - 1)
2194 p << ", ";
2195 }
2196 assert(blocksArgs.empty());
2197}
2198
2199static ParseResult
2200parseUsedCoordList(OpAsmParser &parser, OperationState &state,
2201 SmallVectorImpl<OpAsmParser::Argument> &coords) {
2202 // Parse "at(%crd0, _, ...)"
2203 I64BitSet crdUsedLvlSet;
  if (succeeded(parser.parseOptionalKeyword("at")) &&
      failed(parseOptionalDefinedList(parser, state, crdUsedLvlSet, coords)))
2206 return failure();
2207
2208 // Always use IndexType for the coordinate.
2209 for (auto &coord : coords)
2210 coord.type = parser.getBuilder().getIndexType();
2211
2212 // Set the CrdUsedLvl bitset.
2213 state.addAttribute("crdUsedLvls",
2214 parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
2215 return success();
2216}
2217
2218static ParseResult
2219parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
2220 SmallVectorImpl<OpAsmParser::Argument> &iterators,
2221 SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {
2222 SmallVector<OpAsmParser::UnresolvedOperand> spaces;
2223 SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
2224
2225 // Parse "%iters, ... in %spaces, ..."
  if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
      parser.parseOperandList(spaces))
    return failure();

  if (iterators.size() != spaces.size())
    return parser.emitError(
        parser.getNameLoc(),
        "mismatch in number of sparse iterators and sparse spaces");
2234
2235 SmallVector<OpAsmParser::Argument> coords;
  if (failed(parseUsedCoordList(parser, state, coords)))
2237 return failure();
2238 size_t numCrds = coords.size();
2239
2240 // Parse "iter_args(%arg = %init, ...)"
  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
  if (hasIterArgs)
    if (parser.parseAssignmentList(blockArgs, initArgs))
      return failure();

  blockArgs.append(coords);
2247
2248 SmallVector<Type> iterSpaceTps;
2249 // parse ": sparse_tensor.iter_space -> ret"
  if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
    return failure();
  if (iterSpaceTps.size() != spaces.size())
    return parser.emitError(parser.getNameLoc(),
                            "mismatch in number of iteration space operands "
                            "and iteration space types");

  for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
    IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
    if (!spaceTp)
      return parser.emitError(parser.getNameLoc(),
                              "expected sparse_tensor.iter_space type for "
                              "iteration space operands");
    it.type = spaceTp.getIteratorType();
2264 }
2265
2266 if (hasIterArgs)
    if (parser.parseArrowTypeList(state.types))
2268 return failure();
2269
2270 // Resolves input operands.
  if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
                             state.operands))
2273 return failure();
2274
2275 if (hasIterArgs) {
    // Strip off the trailing block args that are used for coordinates.
    MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
    if (args.size() != initArgs.size() || args.size() != state.types.size()) {
      return parser.emitError(
          parser.getNameLoc(),
          "mismatch in number of iteration arguments and return values");
    }

    for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
      it.type = tp;
      if (parser.resolveOperand(init, tp, state.operands))
        return failure();
2288 }
2289 }
2290 return success();
2291}
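
// Roughly, the custom syntax handled above looks like (schematic only, the
// exact types are placeholders):
//
//   %r = sparse_tensor.iterate %it in %space at(%crd0, _)
//          iter_args(%acc = %init)
//        : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2> -> f64 { ... }
//
// i.e. iterators bound to iteration spaces, an optional "at(...)" coordinate
// list, an optional "iter_args(...)" list, and the space/result types.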
2292
2293static ParseResult
2294parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state,
2295 SmallVectorImpl<Value> &spacesVals,
2296 SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {
2297
2298 // Parse "(%spaces, ...)"
2299 SmallVector<OpAsmParser::UnresolvedOperand> spaces;
  if (parser.parseOperandList(spaces, OpAsmParser::Delimiter::Paren))
2301 return failure();
2302
2303 SmallVector<OpAsmParser::Argument> coords;
  if (failed(parseUsedCoordList(parser, state, coords)))
2305 return failure();
2306 size_t numCrds = coords.size();
2307
2308 // Parse "iter_args(%arg = %init, ...)"
2309 SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
  if (hasIterArgs)
    if (parser.parseAssignmentList(blockArgs, initArgs))
      return failure();
  blockArgs.append(coords);
2315
2316 SmallVector<Type> iterSpaceTps;
2317 // parse ": (sparse_tensor.iter_space, ...) -> ret"
  if (parser.parseColon() || parser.parseLParen() ||
      parser.parseTypeList(iterSpaceTps) || parser.parseRParen())
    return failure();

  if (iterSpaceTps.size() != spaces.size())
    return parser.emitError(parser.getNameLoc(),
                            "mismatch in number of iteration space operands "
                            "and iteration space types");
2326
2327 if (hasIterArgs)
    if (parser.parseArrowTypeList(state.types))
2329 return failure();
2330
2331 // Resolves input sparse iteration spaces.
  if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
                             spacesVals))
    return failure();
  state.operands.append(spacesVals);
2336
2337 if (hasIterArgs) {
    // Strip off the trailing block args that are used for coordinates.
    MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
    if (args.size() != initArgs.size() || args.size() != state.types.size()) {
      return parser.emitError(
          parser.getNameLoc(),
          "mismatch in number of iteration arguments and return values");
    }

    for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
      it.type = tp;
      if (parser.resolveOperand(init, tp, state.operands))
        return failure();
2350 }
2351 }
2352 return success();
2353}
2354
2355LogicalResult ExtractIterSpaceOp::inferReturnTypes(
2356 MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
2357 DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
2358 SmallVectorImpl<mlir::Type> &ret) {
2359
2360 ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
2361 SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
2362 ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
2363 adaptor.getHiLvl()));
2364 return success();
2365}
2366
2367LogicalResult ExtractIterSpaceOp::verify() {
2368 if (getLoLvl() >= getHiLvl())
2369 return emitOpError("expected smaller level low than level high");
2370
2371 TypedValue<IteratorType> pIter = getParentIter();
2372 if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
2373 return emitOpError(
2374 "parent iterator should be specified iff level lower bound equals 0");
2375 }
2376
2377 if (pIter) {
2378 IterSpaceType spaceTp = getExtractedSpace().getType();
2379 if (pIter.getType().getEncoding() != spaceTp.getEncoding())
2380 return emitOpError(
2381 "mismatch in parent iterator encoding and iteration space encoding.");
2382
2383 if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
2384 return emitOpError("parent iterator should be used to extract an "
2385 "iteration space from a consecutive level.");
2386 }
2387
2388 return success();
2389}
2390
2391LogicalResult ExtractValOp::verify() {
2392 auto stt = getSparseTensorType(getTensor());
2393 auto itTp = getIterator().getType();
2394
2395 if (stt.getEncoding() != itTp.getEncoding())
2396 return emitOpError("mismatch in tensor encoding and iterator encoding.");
2397
2398 if (stt.getLvlRank() != itTp.getHiLvl())
2399 return emitOpError("must use last-level iterator to extract values. ");
2400
2401 return success();
2402}
2403
2404struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
2405 using OpRewritePattern::OpRewritePattern;
2406
2407 LogicalResult matchAndRewrite(IterateOp iterateOp,
2408 PatternRewriter &rewriter) const override {
2409 I64BitSet newUsedLvls(0);
2410 llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
2411 for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
2412 if (auto crd = iterateOp.getLvlCrd(i)) {
2413 if (crd->getUsers().empty())
2414 toRemove.set(crd->getArgNumber());
2415 else
2416 newUsedLvls.set(i);
2417 }
2418 }
2419
2420 // All coordinates are used.
2421 if (toRemove.none())
2422 return failure();
2423
    rewriter.startOpModification(iterateOp);
    iterateOp.setCrdUsedLvls(newUsedLvls);
    iterateOp.getBody()->eraseArguments(toRemove);
    rewriter.finalizeOpModification(iterateOp);
2428 return success();
2429 }
2430};
2431
2432void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
2433 mlir::MLIRContext *context) {
2434 results.add<RemoveUnusedLvlCrds>(context);
2435}
2436
2437void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2438 Value iterSpace, ValueRange initArgs) {
2439 unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim();
2440 // All ones.
2441 I64BitSet set((1 << rank) - 1);
2442 return build(builder, odsState, iterSpace, initArgs, set);
2443}
2444
2445void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2446 Value iterSpace, ValueRange initArgs,
2447 I64BitSet crdUsedLvls) {
2448 OpBuilder::InsertionGuard guard(builder);
2449
2450 odsState.addOperands(iterSpace);
2451 odsState.addOperands(initArgs);
2452 odsState.getOrAddProperties<Properties>().crdUsedLvls =
2453 builder.getIntegerAttr(builder.getIntegerType(64), crdUsedLvls);
2454 Region *bodyRegion = odsState.addRegion();
2455 odsState.addTypes(initArgs.getTypes());
2456 Block *bodyBlock = builder.createBlock(bodyRegion);
2457
2458 // Starts with a list of user-provided loop arguments.
2459 for (Value v : initArgs)
2460 bodyBlock->addArgument(v.getType(), v.getLoc());
2461
  // Followed by a list of the used coordinates.
2463 for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
2464 bodyBlock->addArgument(builder.getIndexType(), odsState.location);
2465
  // Ends with the sparse iterator.
2467 bodyBlock->addArgument(
2468 llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
2469 odsState.location);
2470}
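
// Hence, for a 2-D space with both coordinates used and one loop-carried
// value, the body block built above has the argument layout (schematic):
//
//   ^bb0(%carried : T, %crd0 : index, %crd1 : index, %it : !iterator)
//
// i.e. the user init args first, then the used level coordinates, and finally
// the sparse iterator itself.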
2471
2472ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
2473 OpAsmParser::Argument iterator;
2474 OpAsmParser::UnresolvedOperand iterSpace;
2475
2476 SmallVector<OpAsmParser::Argument> iters, iterArgs;
2477 if (parseSparseIterateLoop(parser, result, iters, iterArgs))
2478 return failure();
2479 if (iters.size() != 1)
2480 return parser.emitError(parser.getNameLoc(),
2481 "expected only one iterator/iteration space");
2482
2483 iterArgs.append(iters);
2484 Region *body = result.addRegion();
2485 if (parser.parseRegion(*body, iterArgs))
2486 return failure();
2487
2488 IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2489
2490 // Parse the optional attribute list.
2491 if (parser.parseOptionalAttrDict(result.attributes))
2492 return failure();
2493
2494 return success();
2495}
2496
2497/// Prints the initialization list in the form of
2498/// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
2499/// where 'inner' values are assumed to be region arguments and 'outer' values
2500/// are regular SSA values.
2501static void printInitializationList(OpAsmPrinter &p,
2502 Block::BlockArgListType blocksArgs,
2503 ValueRange initializers,
2504 StringRef prefix = "") {
2505 assert(blocksArgs.size() == initializers.size() &&
2506 "expected same length of arguments and initializers");
2507 if (initializers.empty())
2508 return;
2509
2510 p << prefix << '(';
  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
2512 p << std::get<0>(it) << " = " << std::get<1>(it);
2513 });
2514 p << ")";
2515}
2516
2517template <typename SparseLoopOp>
2518static LogicalResult verifySparseLoopOp(SparseLoopOp op) {
2519 if (op.getInitArgs().size() != op.getNumResults()) {
2520 return op.emitOpError(
2521 "mismatch in number of loop-carried values and defined values");
2522 }
2523 if (op.getCrdUsedLvls().max() > op.getSpaceDim())
2524 return op.emitOpError("required out-of-bound coordinates");
2525
2526 return success();
2527}
2528
2529LogicalResult IterateOp::verify() { return verifySparseLoopOp(*this); }
2530LogicalResult CoIterateOp::verify() { return verifySparseLoopOp(*this); }
2531
2532void IterateOp::print(OpAsmPrinter &p) {
2533 p << " " << getIterator() << " in " << getIterSpace();
2534 if (!getCrdUsedLvls().empty()) {
2535 p << " at(";
2536 printOptionalDefinedList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
2537 p << ")";
2538 }
2539 printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
2540
2541 p << " : " << getIterSpace().getType() << " ";
2542 if (!getInitArgs().empty())
2543 p.printArrowTypeList(getInitArgs().getTypes());
2544
2545 p << " ";
2546 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2547 /*printBlockTerminators=*/!getInitArgs().empty());
2548}
2549
2550LogicalResult IterateOp::verifyRegions() {
2551 if (getIterator().getType() != getIterSpace().getType().getIteratorType())
2552 return emitOpError("mismatch in iterator and iteration space type");
2553 if (getNumRegionIterArgs() != getNumResults())
2554 return emitOpError(
2555 "mismatch in number of basic block args and defined values");
2556
2557 auto initArgs = getInitArgs();
2558 auto iterArgs = getRegionIterArgs();
2559 auto yieldVals = getYieldedValues();
2560 auto opResults = getResults();
2561 if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2562 opResults.size()})) {
2563 return emitOpError() << "number mismatch between iter args and results.";
2564 }
2565
2566 for (auto [i, init, iter, yield, ret] :
2567 llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2568 if (init.getType() != ret.getType())
2569 return emitOpError() << "types mismatch between " << i
2570 << "th iter operand and defined value";
2571 if (iter.getType() != ret.getType())
2572 return emitOpError() << "types mismatch between " << i
2573 << "th iter region arg and defined value";
2574 if (yield.getType() != ret.getType())
2575 return emitOpError() << "types mismatch between " << i
2576 << "th yield value and defined value";
2577 }
2578
2579 return success();
2580}
2581
2582/// OpInterfaces' methods implemented by IterateOp.
2583SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
2584
2585MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
2586 return getInitArgsMutable();
2587}
2588
2589Block::BlockArgListType IterateOp::getRegionIterArgs() {
2590 return getRegion().getArguments().take_front(getNumRegionIterArgs());
2591}
2592
2593std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
2594 return cast<sparse_tensor::YieldOp>(
2595 getRegion().getBlocks().front().getTerminator())
2596 .getResultsMutable();
2597}
2598
2599std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
2600
2601OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2602 return getInitArgs();
2603}
2604
2605void IterateOp::getSuccessorRegions(RegionBranchPoint point,
2606 SmallVectorImpl<RegionSuccessor> &regions) {
2607 // Both the operation itself and the region may be branching into the body
2608 // or back into the operation itself.
2609 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2610 // It is possible for loop not to enter the body.
2611 regions.push_back(RegionSuccessor(getResults()));
2612}
2613
2614void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
2615 ValueRange iterSpaces, ValueRange initArgs,
2616 unsigned numCases) {
2617 unsigned rank =
2618 cast<IterSpaceType>(iterSpaces.front().getType()).getSpaceDim();
2619 // All ones.
2620 I64BitSet set((1 << rank) - 1);
  // Generates all-zero case bits (they only serve as placeholders), which are
  // supposed to be overridden later. We need to preallocate all the regions as
2623 // mlir::Region cannot be dynamically added later after the operation is
2624 // created.
2625 SmallVector<int64_t> caseBits(numCases, 0);
2626 ArrayAttr cases = builder.getI64ArrayAttr(caseBits);
2627 return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces,
2628 initArgs, set, cases,
2629 /*caseRegionsCount=*/numCases);
2630}
2631
2632ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) {
2633
2634 SmallVector<Value> spaces;
  // The block argument list of each region, arranged in the order of
  // ([used coordinate list], [loop iteration args], [sparse iterator list]).
2637 SmallVector<OpAsmParser::Argument> blockArgs;
2638 if (parseSparseCoIterateLoop(parser, result, spaces, blockArgs))
2639 return failure();
2640
2641 result.addAttribute("operandSegmentSizes",
2642 parser.getBuilder().getDenseI32ArrayAttr(
2643 {static_cast<int32_t>(spaces.size()),
2644 static_cast<int32_t>(result.types.size())}));
2645
2646 SmallVector<Attribute> cases;
2647 while (succeeded(parser.parseOptionalKeyword("case"))) {
2648 // Parse one region per case.
2649 I64BitSet definedItSet;
2650 SmallVector<OpAsmParser::Argument> definedIts;
2651 if (parseOptionalDefinedList(parser, result, definedItSet, definedIts,
2652 spaces.size(), OpAsmParser::Delimiter::None))
2653 return failure();
2654
2655 cases.push_back(parser.getBuilder().getI64IntegerAttr(definedItSet));
2656
2657 for (auto [i, definedIdx] : llvm::enumerate(definedItSet.bits())) {
2658 // Resolve the iterator type based on the iteration space type.
2659 auto spaceTp = llvm::cast<IterSpaceType>(spaces[definedIdx].getType());
2660 definedIts[i].type = spaceTp.getIteratorType();
2661 }
2662 definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end());
2663 Region *body = result.addRegion();
2664 if (parser.parseRegion(*body, definedIts))
2665 return failure();
2666
2667 CoIterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
2668 }
2669
2670 result.addAttribute("cases", ArrayAttr::get(parser.getContext(), cases));
2671
2672 // Parse the optional attribute list.
2673 if (parser.parseOptionalAttrDict(result.attributes))
2674 return failure();
2675
2676 return success();
2677}
2678
2679void CoIterateOp::print(OpAsmPrinter &p) {
2680 p << " (";
2681 llvm::interleaveComma(getIterSpaces(), p, [&](auto s) { p << s; });
2682 p << ")";
2683
2684 if (!getCrdUsedLvls().empty()) {
2685 p << " at(";
2686 printOptionalDefinedList(p, getSpaceDim(), getCrds(0), getCrdUsedLvls());
2687 p << ")";
2688 }
2689
2690 printInitializationList(p, getRegionIterArgs(0), getInitArgs(), " iter_args");
2691
2692 p << " : (" << getIterSpaces().getTypes() << ")";
2693 if (!getInitArgs().empty())
2694 p.printArrowTypeList(getInitArgs().getTypes());
2695
2696 for (unsigned idx = 0, e = getRegions().size(); idx < e; idx++) {
2697 p.printNewline();
2698 p << "case ";
2699 printOptionalDefinedList(p, getIterSpaces().size(), getRegionIterators(idx),
2700 getRegionDefinedSpace(idx));
2701 p << " ";
2702 p.printRegion(getRegion(idx), /*printEntryBlockArgs=*/false,
2703 /*printBlockTerminators=*/!getInitArgs().empty());
2704 }
2705}
2706
2707ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) {
2708 return cast<sparse_tensor::YieldOp>(
2709 getRegion(regionIdx).getBlocks().front().getTerminator())
2710 .getResults();
2711}
2712
2713LogicalResult CoIterateOp::verifyRegions() {
2714 for (unsigned r = 0, e = getNumRegions(); r < e; r++) {
2715 if (getNumRegionIterArgs() != getNumResults())
2716 return emitOpError(
2717 "mismatch in number of basic block args and defined values");
2718
2719 auto initArgs = getInitArgs();
2720 auto iterArgs = getRegionIterArgs(r);
2721 auto yieldVals = getYieldedValues(r);
2722 auto opResults = getResults();
2723 if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2724 opResults.size()})) {
2725 return emitOpError()
2726 << "number mismatch between iter args and results on " << r
2727 << "th region";
2728 }
2729
2730 for (auto [i, init, iter, yield, ret] :
2731 llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2732 if (init.getType() != ret.getType())
2733 return emitOpError()
2734 << "types mismatch between " << i
2735 << "th iter operand and defined value on " << r << "th region";
2736 if (iter.getType() != ret.getType())
2737 return emitOpError() << "types mismatch between " << i
2738 << "th iter region arg and defined value on " << r
2739 << "th region";
2740 if (yield.getType() != ret.getType())
2741 return emitOpError()
2742 << "types mismatch between " << i
2743 << "th yield value and defined value on " << r << "th region";
2744 }
2745 }
2746
2747 auto cases = getRegionDefinedSpaces();
2748 llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end());
2749 if (set.size() != getNumRegions())
2750 return emitOpError("contains duplicated cases.");
2751
2752 return success();
2753}
2754
2755SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) {
2756 SmallVector<Region *> ret;
2757 I64BitSet caseBit = getRegionDefinedSpace(regionIdx);
2758 for (Region &r : getCaseRegions())
2759 if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit))
2760 ret.push_back(&r);
2761
2762 return ret;
2763}
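
// For example, with three iteration spaces, the case defined over spaces
// {0, 1, 2} (bitset 0b111) has the cases over {0, 1} (0b011) and {1} (0b010)
// among its sub-cases, since their defined-space bitsets are subsets of 0b111
// (the queried case itself is trivially included as well).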
2764
2765//===----------------------------------------------------------------------===//
2766// Sparse Tensor Dialect Setups.
2767//===----------------------------------------------------------------------===//
2768
2769/// Materialize a single constant operation from a given attribute value with
2770/// the desired resultant type.
2771Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
2772 Attribute value, Type type,
2773 Location loc) {
2774 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
2775 return op;
2776 return nullptr;
2777}
2778
2779void SparseTensorDialect::initialize() {
2780 addAttributes<
2781#define GET_ATTRDEF_LIST
2782#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
2783 >();
2784 addTypes<
2785#define GET_TYPEDEF_LIST
2786#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
2787 >();
2788 addOperations<
2789#define GET_OP_LIST
2790#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2791 >();
2792 declarePromisedInterfaces<
2793 bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
2794 NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
2795 ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
2796}
2797
2798#define GET_OP_CLASSES
2799#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2800
2801#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
2802
