1//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <utility>
10
11#include "Detail/DimLvlMapParser.h"
12
13#include "mlir/Dialect/SparseTensor/IR/Enums.h"
14#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
15#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
16#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
17
18#include "mlir/Dialect/Arith/IR/Arith.h"
19#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
20#include "mlir/Dialect/Complex/IR/Complex.h"
21#include "mlir/Dialect/Utils/StaticValueUtils.h"
22#include "mlir/IR/Builders.h"
23#include "mlir/IR/DialectImplementation.h"
24#include "mlir/IR/OpImplementation.h"
25#include "mlir/IR/PatternMatch.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/FormatVariadic.h"
28
29#define GET_ATTRDEF_CLASSES
30#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
31#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
32
33// Forward declarations, following custom print/parsing methods are referenced
34// by the generated code for SparseTensorTypes.td.
35static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
36 mlir::sparse_tensor::Level &,
37 mlir::sparse_tensor::Level &);
38static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level,
39 mlir::sparse_tensor::Level);
40
41#define GET_TYPEDEF_CLASSES
42#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
43
44using namespace mlir;
45using namespace mlir::sparse_tensor;
46
47// Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as
48// well.
49namespace mlir::sparse_tensor {
50llvm::hash_code hash_value(LevelType lt) {
51 return llvm::hash_value(value: static_cast<uint64_t>(lt));
52}
53} // namespace mlir::sparse_tensor
54
55//===----------------------------------------------------------------------===//
56// Local Convenience Methods.
57//===----------------------------------------------------------------------===//
58
/// Returns true iff `bitWidth` is an admissible overhead bit-width:
/// 0 (meaning the native `index` width) or one of 8/16/32/64.
static constexpr bool acceptBitWidth(unsigned bitWidth) {
  return bitWidth == 0 || bitWidth == 8 || bitWidth == 16 ||
         bitWidth == 32 || bitWidth == 64;
}
71
72static SmallVector<Size>
73getSparseFieldShape(const SparseTensorEncodingAttr enc,
74 std::optional<ArrayRef<int64_t>> dimShape) {
75 assert(enc);
76 // With only encoding, we can not determine the static shape for leading
77 // batch levels, we therefore return a dynamic shape memref instead.
78 SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic);
79 if (dimShape.has_value()) {
80 // If the actual tensor shape is provided, we can then refine the leading
81 // batch dimension.
82 SmallVector<int64_t> lvlShape =
83 enc.translateShape(srcShape: *dimShape, CrdTransDirectionKind::dim2lvl);
84 memrefShape.assign(in_start: lvlShape.begin(),
85 in_end: lvlShape.begin() + enc.getBatchLvlRank());
86 }
87 // Another dynamic dimension to store the sparse level.
88 memrefShape.push_back(Elt: ShapedType::kDynamic);
89 return memrefShape;
90}
91
92//===----------------------------------------------------------------------===//
93// SparseTensorDialect StorageLayout.
94//===----------------------------------------------------------------------===//
95
// Sentinel level used for storage fields (values array, metadata) that are
// not attached to any particular level.
static constexpr Level kInvalidLevel = -1u;
// Sentinel used while searching for a field index. NOTE(review): declared
// as `Level` though used as a FieldIndex; the unsigned conversions appear
// to line up, but the type looks unintentional — confirm.
static constexpr Level kInvalidFieldIndex = -1u;
// Index at which the data fields (positions/coordinates/values) start.
static constexpr FieldIndex kDataFieldStartingIdx = 0;
99
/// Invokes `callback` for every storage field of this layout, in field
/// order: per-level position/coordinate memrefs first, then the values
/// memref, and finally the storage-specifier metadata. Iteration stops
/// early as soon as `callback` returns false.
void StorageLayout::foreachField(
    llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) const {
  const auto lvlTypes = enc.getLvlTypes();
  const Level lvlRank = enc.getLvlRank();
  SmallVector<COOSegment> cooSegs = enc.getCOOSegments();
  FieldIndex fieldIdx = kDataFieldStartingIdx;

  ArrayRef cooSegsRef = cooSegs;
  // Per-level storage.
  for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
    const auto lt = lvlTypes[l];
    // A level may contribute a positions memref and/or a coordinates
    // memref, depending on its level type.
    if (isWithPosLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
        return;
    }
    if (isWithCrdLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
        return;
    }
    if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
      if (!cooSegsRef.front().isSoA) {
        // AoS COO, all singletons are fused into one memrefs. Skips the entire
        // COO segment.
        l = cooSegsRef.front().lvlRange.second;
      } else {
        // SoA COO, each singleton level has one memref.
        l++;
      }
      // Expire handled COO segment.
      cooSegsRef = cooSegsRef.drop_front();
    } else {
      // Non COO levels.
      l++;
    }
  }
  // The values array.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
                 LevelFormat::Undef)))
    return;
  // Put metadata at the end.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
                 LevelFormat::Undef)))
    return;
}
146
147void sparse_tensor::foreachFieldAndTypeInSparseTensor(
148 SparseTensorType stt,
149 llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
150 LevelType)>
151 callback) {
152 assert(stt.hasEncoding());
153
154 SmallVector<int64_t> memrefShape =
155 getSparseFieldShape(enc: stt.getEncoding(), dimShape: stt.getDimShape());
156
157 const Type specType = StorageSpecifierType::get(encoding: stt.getEncoding());
158 // memref<[batch] x ? x pos> positions
159 const Type posMemType = MemRefType::get(shape: memrefShape, elementType: stt.getPosType());
160 // memref<[batch] x ? x crd> coordinates
161 const Type crdMemType = MemRefType::get(shape: memrefShape, elementType: stt.getCrdType());
162 // memref<[batch] x ? x eltType> values
163 const Type valMemType = MemRefType::get(shape: memrefShape, elementType: stt.getElementType());
164
165 StorageLayout(stt).foreachField(callback: [specType, posMemType, crdMemType, valMemType,
166 callback](FieldIndex fieldIdx,
167 SparseTensorFieldKind fieldKind,
168 Level lvl, LevelType lt) -> bool {
169 switch (fieldKind) {
170 case SparseTensorFieldKind::StorageSpec:
171 return callback(specType, fieldIdx, fieldKind, lvl, lt);
172 case SparseTensorFieldKind::PosMemRef:
173 return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
174 case SparseTensorFieldKind::CrdMemRef:
175 return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
176 case SparseTensorFieldKind::ValMemRef:
177 return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
178 };
179 llvm_unreachable("unrecognized field kind");
180 });
181}
182
183unsigned StorageLayout::getNumFields() const {
184 unsigned numFields = 0;
185 foreachField(callback: [&numFields](FieldIndex, SparseTensorFieldKind, Level,
186 LevelType) -> bool {
187 numFields++;
188 return true;
189 });
190 return numFields;
191}
192
193unsigned StorageLayout::getNumDataFields() const {
194 unsigned numFields = 0; // one value memref
195 foreachField(callback: [&numFields](FieldIndex fidx, SparseTensorFieldKind, Level,
196 LevelType) -> bool {
197 if (fidx >= kDataFieldStartingIdx)
198 numFields++;
199 return true;
200 });
201 numFields -= 1; // the last field is StorageSpecifier
202 assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
203 return numFields;
204}
205
206std::pair<FieldIndex, unsigned>
207StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
208 std::optional<Level> lvl) const {
209 FieldIndex fieldIdx = kInvalidFieldIndex;
210 unsigned stride = 1;
211 if (kind == SparseTensorFieldKind::CrdMemRef) {
212 assert(lvl.has_value());
213 const Level cooStart = enc.getAoSCOOStart();
214 const Level lvlRank = enc.getLvlRank();
215 if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
216 lvl = cooStart;
217 stride = lvlRank - cooStart;
218 }
219 }
220 foreachField(callback: [lvl, kind, &fieldIdx](FieldIndex fIdx,
221 SparseTensorFieldKind fKind, Level fLvl,
222 LevelType lt) -> bool {
223 if ((lvl && fLvl == lvl.value() && kind == fKind) ||
224 (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
225 fieldIdx = fIdx;
226 // Returns false to break the iteration.
227 return false;
228 }
229 return true;
230 });
231 assert(fieldIdx != kInvalidFieldIndex);
232 return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
233}
234
235//===----------------------------------------------------------------------===//
236// SparseTensorDialect Attribute Methods.
237//===----------------------------------------------------------------------===//
238
239std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
240 return isDynamic(v) ? std::nullopt
241 : std::make_optional(t: static_cast<uint64_t>(v));
242}
243
244std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
245 return getStatic(v: getOffset());
246}
247
248std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
249 return getStatic(v: getStride());
250}
251
252std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
253 return getStatic(v: getSize());
254}
255
256bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
257 return isDynamic(v: getOffset()) && isDynamic(v: getStride()) &&
258 isDynamic(v: getSize());
259}
260
261std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
262 return isDynamic(v) ? "?" : std::to_string(val: v);
263}
264
265void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
266 assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
267 os << '(';
268 os << getStaticString(v: getOffset());
269 os << ", ";
270 os << getStaticString(v: getSize());
271 os << ", ";
272 os << getStaticString(v: getStride());
273 os << ')';
274}
275
276void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
277 print(os&: printer.getStream());
278}
279
280static ParseResult parseOptionalStaticSlice(int64_t &result,
281 AsmParser &parser) {
282 auto parseResult = parser.parseOptionalInteger(result);
283 if (parseResult.has_value()) {
284 if (parseResult.value().succeeded() && result < 0) {
285 parser.emitError(
286 loc: parser.getCurrentLocation(),
287 message: "expect positive value or ? for slice offset/size/stride");
288 return failure();
289 }
290 return parseResult.value();
291 }
292
293 // Else, and '?' which represented dynamic slice
294 result = SparseTensorDimSliceAttr::kDynamic;
295 return parser.parseQuestion();
296}
297
298Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
299 int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;
300
301 if (failed(Result: parser.parseLParen()) ||
302 failed(Result: parseOptionalStaticSlice(result&: offset, parser)) ||
303 failed(Result: parser.parseComma()) ||
304 failed(Result: parseOptionalStaticSlice(result&: size, parser)) ||
305 failed(Result: parser.parseComma()) ||
306 failed(Result: parseOptionalStaticSlice(result&: stride, parser)) ||
307 failed(Result: parser.parseRParen()))
308 return {};
309
310 return parser.getChecked<SparseTensorDimSliceAttr>(params: parser.getContext(),
311 params&: offset, params&: size, params&: stride);
312}
313
314LogicalResult
315SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
316 int64_t offset, int64_t size, int64_t stride) {
317 if (!isDynamic(v: offset) && offset < 0)
318 return emitError() << "expect non-negative value or ? for slice offset";
319 if (!isDynamic(v: size) && size <= 0)
320 return emitError() << "expect positive value or ? for slice size";
321 if (!isDynamic(v: stride) && stride <= 0)
322 return emitError() << "expect positive value or ? for slice stride";
323 return success();
324}
325
326SparseTensorEncodingAttr
327SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
328 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
329 return SparseTensorEncodingAttr::get(
330 context: getContext(), lvlTypes: getLvlTypes(), dimToLvl, lvlToDim: AffineMap(), posWidth: getPosWidth(),
331 crdWidth: getCrdWidth(), explicitVal: getExplicitVal(), implicitVal: getImplicitVal());
332}
333
334SparseTensorEncodingAttr
335SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
336 return withDimToLvl(dimToLvl: enc ? enc.getDimToLvl() : AffineMap());
337}
338
339SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
340 return withDimToLvl(dimToLvl: AffineMap());
341}
342
343SparseTensorEncodingAttr
344SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
345 unsigned crdWidth) const {
346 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
347 return SparseTensorEncodingAttr::get(
348 context: getContext(), lvlTypes: getLvlTypes(), dimToLvl: getDimToLvl(), lvlToDim: getLvlToDim(), posWidth,
349 crdWidth, explicitVal: getExplicitVal(), implicitVal: getImplicitVal());
350}
351
352SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
353 return withBitWidths(posWidth: 0, crdWidth: 0);
354}
355
356SparseTensorEncodingAttr
357SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const {
358 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
359 return SparseTensorEncodingAttr::get(
360 context: getContext(), lvlTypes: getLvlTypes(), dimToLvl: getDimToLvl(), lvlToDim: getLvlToDim(), posWidth: getPosWidth(),
361 crdWidth: getCrdWidth(), explicitVal, implicitVal: getImplicitVal());
362}
363
364SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const {
365 return withExplicitVal(explicitVal: Attribute());
366}
367
368SparseTensorEncodingAttr
369SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const {
370 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
371 return SparseTensorEncodingAttr::get(
372 context: getContext(), lvlTypes: getLvlTypes(), dimToLvl: getDimToLvl(), lvlToDim: getLvlToDim(), posWidth: getPosWidth(),
373 crdWidth: getCrdWidth(), explicitVal: getExplicitVal(), implicitVal);
374}
375
376SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const {
377 return withImplicitVal(implicitVal: Attribute());
378}
379
380SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
381 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
382 return SparseTensorEncodingAttr::get(
383 context: getContext(), lvlTypes: getLvlTypes(), dimToLvl: getDimToLvl(), lvlToDim: getLvlToDim(), posWidth: getPosWidth(),
384 crdWidth: getCrdWidth(), explicitVal: getExplicitVal(), implicitVal: getImplicitVal(), dimSlices);
385}
386
387SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
388 return withDimSlices(dimSlices: ArrayRef<SparseTensorDimSliceAttr>{});
389}
390
391uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
392 ArrayRef<LevelType> lvlTypes = getLvlTypes();
393 auto lastBatch = std::find_if(first: lvlTypes.rbegin(), last: lvlTypes.rend(), pred: isBatchLT);
394 return std::distance(first: lastBatch, last: lvlTypes.rend());
395}
396
397bool SparseTensorEncodingAttr::isAllDense() const {
398 return !getImpl() || llvm::all_of(Range: getLvlTypes(), P: isDenseLT);
399}
400
401bool SparseTensorEncodingAttr::isAllOrdered() const {
402 return !getImpl() || llvm::all_of(Range: getLvlTypes(), P: isOrderedLT);
403}
404
405Type SparseTensorEncodingAttr::getCrdElemType() const {
406 if (!getImpl())
407 return nullptr;
408 if (getCrdWidth())
409 return IntegerType::get(context: getContext(), width: getCrdWidth());
410 return IndexType::get(context: getContext());
411}
412
413Type SparseTensorEncodingAttr::getPosElemType() const {
414 if (!getImpl())
415 return nullptr;
416 if (getPosWidth())
417 return IntegerType::get(context: getContext(), width: getPosWidth());
418 return IndexType::get(context: getContext());
419}
420
421MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
422 std::optional<ArrayRef<int64_t>> dimShape) const {
423 SmallVector<Size> shape = getSparseFieldShape(enc: *this, dimShape);
424 return MemRefType::get(shape, elementType: getCrdElemType());
425}
426
427MemRefType SparseTensorEncodingAttr::getPosMemRefType(
428 std::optional<ArrayRef<int64_t>> dimShape) const {
429 SmallVector<Size> shape = getSparseFieldShape(enc: *this, dimShape);
430 return MemRefType::get(shape, elementType: getPosElemType());
431}
432
433bool SparseTensorEncodingAttr::isIdentity() const {
434 return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
435}
436
437bool SparseTensorEncodingAttr::isPermutation() const {
438 return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
439}
440
441Dimension SparseTensorEncodingAttr::getDimRank() const {
442 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
443 const auto dimToLvl = getDimToLvl();
444 return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
445}
446
447Level SparseTensorEncodingAttr::getLvlRank() const {
448 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
449 return getLvlTypes().size();
450}
451
452LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
453 if (!getImpl())
454 return LevelFormat::Batch;
455 assert(l < getLvlRank() && "Level is out of bounds");
456 return getLvlTypes()[l];
457}
458
459bool SparseTensorEncodingAttr::isSlice() const {
460 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
461 return !getDimSlices().empty();
462}
463
464SparseTensorDimSliceAttr
465SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
466 assert(isSlice() && "Is not a slice");
467 const auto dimSlices = getDimSlices();
468 assert(dim < dimSlices.size() && "Dimension is out of bounds");
469 return dimSlices[dim];
470}
471
472std::optional<uint64_t>
473SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
474 return getDimSlice(dim).getStaticOffset();
475}
476
477std::optional<uint64_t>
478SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
479 return getDimSlice(dim).getStaticStride();
480}
481
482std::optional<uint64_t>
483SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
484 return getStaticDimSliceOffset(dim: toDim(enc: *this, l: lvl));
485}
486
487std::optional<uint64_t>
488SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
489 return getStaticDimSliceStride(dim: toDim(enc: *this, l: lvl));
490}
491
/// Translates a shape between dimension-space and level-space (direction
/// given by `dir`), following the encoding's affine maps. Static sizes
/// are propagated through the map by constant folding; any result that
/// cannot be resolved statically becomes kDynamic.
SmallVector<int64_t>
SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
                                         CrdTransDirectionKind dir) const {
  // Identity mapping: the shape is unchanged.
  if (isIdentity())
    return SmallVector<int64_t>(srcShape);

  SmallVector<int64_t> ret;
  unsigned rank =
      dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
  ret.reserve(N: rank);

  if (isPermutation()) {
    // Pure permutation: shuffle the source sizes accordingly.
    for (unsigned r = 0; r < rank; r++) {
      unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(enc: *this, l: r)
                                                             : toLvl(enc: *this, d: r);
      ret.push_back(Elt: srcShape[trans]);
    }
    return ret;
  }

  // Handle non-permutation maps.
  AffineMap transMap =
      dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();

  SmallVector<AffineExpr> dimRep;
  dimRep.reserve(N: srcShape.size());
  for (int64_t sz : srcShape) {
    if (ShapedType::isStatic(dValue: sz)) {
      // Push back the max coordinate for the given dimension/level size.
      dimRep.push_back(Elt: getAffineConstantExpr(constant: sz - 1, context: getContext()));
    } else {
      // A dynamic size, use a AffineDimExpr to symbolize the value.
      dimRep.push_back(Elt: getAffineDimExpr(position: dimRep.size(), context: getContext()));
    }
  };

  for (AffineExpr exp : transMap.getResults()) {
    // Do constant propagation on the affine map.
    AffineExpr evalExp =
        simplifyAffineExpr(expr: exp.replaceDims(dimReplacements: dimRep), numDims: srcShape.size(), numSymbols: 0);
    // use llvm namespace here to avoid ambiguity
    if (auto c = llvm::dyn_cast<AffineConstantExpr>(Val&: evalExp)) {
      // The constant is the max coordinate, so the size is max + 1.
      ret.push_back(Elt: c.getValue() + 1);
    } else {
      if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(Val&: evalExp);
          mod && mod.getKind() == AffineExprKind::Mod) {
        // We can still infer a static bound for expressions in form
        // "d % constant" since d % constant \in [0, constant).
        if (auto bound = llvm::dyn_cast<AffineConstantExpr>(Val: mod.getRHS())) {
          ret.push_back(Elt: bound.getValue());
          continue;
        }
      }
      ret.push_back(Elt: ShapedType::kDynamic);
    }
  }
  assert(ret.size() == rank);
  return ret;
}
551
552ValueRange
553SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
554 ValueRange crds,
555 CrdTransDirectionKind dir) const {
556 if (!getImpl())
557 return crds;
558
559 SmallVector<Type> retType(
560 dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
561 builder.getIndexType());
562 auto transOp = builder.create<CrdTranslateOp>(location: loc, args&: retType, args&: crds, args&: dir, args: *this);
563 return transOp.getOutCrds();
564}
565
/// Parses a SparseTensorEncodingAttr of the form
///   `<{ map = ..., posWidth = ..., crdWidth = ...,
///       explicitVal = ..., implicitVal = ... }>`
/// where each `key = value` entry is optional and entries are separated
/// by commas (only the last entry may omit the trailing comma).
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
  // Open "<{" part.
  if (failed(Result: parser.parseLess()))
    return {};
  if (failed(Result: parser.parseLBrace()))
    return {};

  // Process the data from the parsed dictionary value into struct-like data.
  SmallVector<LevelType> lvlTypes;
  SmallVector<SparseTensorDimSliceAttr> dimSlices;
  AffineMap dimToLvl = {};
  AffineMap lvlToDim = {};
  unsigned posWidth = 0;
  unsigned crdWidth = 0;
  Attribute explicitVal;
  Attribute implicitVal;
  StringRef attrName;
  SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth",
                                    "explicitVal", "implicitVal"};
  while (succeeded(Result: parser.parseOptionalKeyword(keyword: &attrName))) {
    // Detect admissible keyword.
    auto *it = find(Range&: keys, Val: attrName);
    if (it == keys.end()) {
      parser.emitError(loc: parser.getNameLoc(), message: "unexpected key: ") << attrName;
      return {};
    }
    unsigned keyWordIndex = it - keys.begin();
    // Consume the `=` after keys
    if (failed(Result: parser.parseEqual()))
      return {};
    // Dispatch on keyword.
    switch (keyWordIndex) {
    case 0: { // map
      // Delegate to the dedicated dim/lvl map parser.
      ir_detail::DimLvlMapParser cParser(parser);
      auto res = cParser.parseDimLvlMap();
      if (failed(Result: res))
        return {};
      const auto &dlm = *res;

      // Level types come straight from the parsed map.
      const Level lvlRank = dlm.getLvlRank();
      for (Level lvl = 0; lvl < lvlRank; lvl++)
        lvlTypes.push_back(Elt: dlm.getLvlType(lvl));

      const Dimension dimRank = dlm.getDimRank();
      for (Dimension dim = 0; dim < dimRank; dim++)
        dimSlices.push_back(Elt: dlm.getDimSlice(dim));
      // NOTE: the old syntax requires an all-or-nothing approach to
      // `dimSlices`; therefore, if any slice actually exists then we need
      // to convert null-DSA into default/nop DSA.
      const auto isDefined = [](SparseTensorDimSliceAttr slice) {
        return static_cast<bool>(slice.getImpl());
      };
      if (llvm::any_of(Range&: dimSlices, P: isDefined)) {
        const auto defaultSlice =
            SparseTensorDimSliceAttr::get(context: parser.getContext());
        for (Dimension dim = 0; dim < dimRank; dim++)
          if (!isDefined(dimSlices[dim]))
            dimSlices[dim] = defaultSlice;
      } else {
        // No real slices at all: store the empty array instead.
        dimSlices.clear();
      }

      dimToLvl = dlm.getDimToLvlMap(context: parser.getContext());
      lvlToDim = dlm.getLvlToDimMap(context: parser.getContext());
      break;
    }
    case 1: { // posWidth
      Attribute attr;
      if (failed(Result: parser.parseAttribute(result&: attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(Val&: attr);
      if (!intAttr) {
        parser.emitError(loc: parser.getNameLoc(),
                         message: "expected an integral position bitwidth");
        return {};
      }
      posWidth = intAttr.getInt();
      break;
    }
    case 2: { // crdWidth
      Attribute attr;
      if (failed(Result: parser.parseAttribute(result&: attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(Val&: attr);
      if (!intAttr) {
        parser.emitError(loc: parser.getNameLoc(),
                         message: "expected an integral index bitwidth");
        return {};
      }
      crdWidth = intAttr.getInt();
      break;
    }
    case 3: { // explicitVal
      // Only numeric (float/integer/complex) values are admissible.
      Attribute attr;
      if (failed(Result: parser.parseAttribute(result&: attr)))
        return {};
      if (auto result = llvm::dyn_cast<FloatAttr>(Val&: attr)) {
        explicitVal = result;
      } else if (auto result = llvm::dyn_cast<IntegerAttr>(Val&: attr)) {
        explicitVal = result;
      } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(Val&: attr)) {
        explicitVal = result;
      } else {
        parser.emitError(loc: parser.getNameLoc(),
                         message: "expected a numeric value for explicitVal");
        return {};
      }
      break;
    }
    case 4: { // implicitVal
      // Only numeric (float/integer/complex) values are admissible.
      Attribute attr;
      if (failed(Result: parser.parseAttribute(result&: attr)))
        return {};
      if (auto result = llvm::dyn_cast<FloatAttr>(Val&: attr)) {
        implicitVal = result;
      } else if (auto result = llvm::dyn_cast<IntegerAttr>(Val&: attr)) {
        implicitVal = result;
      } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(Val&: attr)) {
        implicitVal = result;
      } else {
        parser.emitError(loc: parser.getNameLoc(),
                         message: "expected a numeric value for implicitVal");
        return {};
      }
      break;
    }
    } // switch
    // Only last item can omit the comma.
    if (parser.parseOptionalComma().failed())
      break;
  }

  // Close "}>" part.
  if (failed(Result: parser.parseRBrace()))
    return {};
  if (failed(Result: parser.parseGreater()))
    return {};

  // Construct struct-like storage for attribute.
  if (!lvlToDim || lvlToDim.isEmpty()) {
    // Infer the inverse map when the map parser did not provide one.
    lvlToDim = inferLvlToDim(dimToLvl, context: parser.getContext());
  }
  return parser.getChecked<SparseTensorEncodingAttr>(
      params: parser.getContext(), params&: lvlTypes, params&: dimToLvl, params&: lvlToDim, params&: posWidth, params&: crdWidth,
      params&: explicitVal, params&: implicitVal, params&: dimSlices);
}
712
713void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
714 auto map = static_cast<AffineMap>(getDimToLvl());
715 // Empty affine map indicates identity map
716 if (!map)
717 map = AffineMap::getMultiDimIdentityMap(numDims: getLvlTypes().size(), context: getContext());
718 printer << "<{ map = ";
719 printSymbols(map, printer);
720 printer << '(';
721 printDimensions(map, printer, dimSlices: getDimSlices());
722 printer << ") -> (";
723 printLevels(map, printer, lvlTypes: getLvlTypes());
724 printer << ')';
725 // Print remaining members only for non-default values.
726 if (getPosWidth())
727 printer << ", posWidth = " << getPosWidth();
728 if (getCrdWidth())
729 printer << ", crdWidth = " << getCrdWidth();
730 if (getExplicitVal()) {
731 printer << ", explicitVal = " << getExplicitVal();
732 }
733 if (getImplicitVal())
734 printer << ", implicitVal = " << getImplicitVal();
735 printer << " }>";
736}
737
738void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
739 AsmPrinter &printer) const {
740 if (map.getNumSymbols() == 0)
741 return;
742 printer << '[';
743 for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
744 printer << 's' << i << ", ";
745 if (map.getNumSymbols() >= 1)
746 printer << 's' << map.getNumSymbols() - 1;
747 printer << ']';
748}
749
750void SparseTensorEncodingAttr::printDimensions(
751 AffineMap &map, AsmPrinter &printer,
752 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
753 if (!dimSlices.empty()) {
754 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
755 printer << 'd' << i << " : " << dimSlices[i] << ", ";
756 if (map.getNumDims() >= 1) {
757 printer << 'd' << map.getNumDims() - 1 << " : "
758 << dimSlices[map.getNumDims() - 1];
759 }
760 } else {
761 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
762 printer << 'd' << i << ", ";
763 if (map.getNumDims() >= 1)
764 printer << 'd' << map.getNumDims() - 1;
765 }
766}
767
768void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
769 ArrayRef<LevelType> lvlTypes) const {
770 for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
771 map.getResult(idx: i).print(os&: printer.getStream());
772 printer << " : " << toMLIRString(lt: lvlTypes[i]) << ", ";
773 }
774 if (map.getNumResults() >= 1) {
775 auto lastIndex = map.getNumResults() - 1;
776 map.getResult(idx: lastIndex).print(os&: printer.getStream());
777 printer << " : " << toMLIRString(lt: lvlTypes[lastIndex]);
778 }
779}
780
781LogicalResult SparseTensorEncodingAttr::verify(
782 function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
783 AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
784 unsigned crdWidth, Attribute explicitVal, Attribute implicitVal,
785 ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
786 if (!acceptBitWidth(bitWidth: posWidth))
787 return emitError() << "unexpected position bitwidth: " << posWidth;
788 if (!acceptBitWidth(bitWidth: crdWidth))
789 return emitError() << "unexpected coordinate bitwidth: " << crdWidth;
790
791 // Verify every COO segment.
792 auto *it = llvm::find_if(Range&: lvlTypes, P: isSingletonLT);
793 while (it != lvlTypes.end()) {
794 if (it == lvlTypes.begin() ||
795 !(it - 1)->isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>())
796 return emitError() << "expected compressed or loose_compressed level "
797 "before singleton level";
798
799 auto *curCOOEnd = std::find_if_not(first: it, last: lvlTypes.end(), pred: isSingletonLT);
800 if (!std::all_of(first: it, last: curCOOEnd, pred: isSingletonLT))
801 return emitError() << "expected all singleton lvlTypes "
802 "following a singleton level";
803 // We can potentially support mixed SoA/AoS singleton levels.
804 if (!std::all_of(first: it, last: curCOOEnd, pred: [it](LevelType i) {
805 return it->isa<LevelPropNonDefault::SoA>() ==
806 i.isa<LevelPropNonDefault::SoA>();
807 })) {
808 return emitError() << "expected all singleton lvlTypes stored in the "
809 "same memory layout (SoA vs AoS).";
810 }
811 it = std::find_if(first: curCOOEnd, last: lvlTypes.end(), pred: isSingletonLT);
812 }
813
814 auto lastBatch = std::find_if(first: lvlTypes.rbegin(), last: lvlTypes.rend(), pred: isBatchLT);
815 if (!std::all_of(first: lastBatch, last: lvlTypes.rend(), pred: isBatchLT))
816 return emitError() << "Batch lvlType can only be leading levels.";
817
818 // SoA property can only be applied on singleton level.
819 auto soaLvls = llvm::make_filter_range(Range&: lvlTypes, Pred: [](LevelType lt) {
820 return lt.isa<LevelPropNonDefault::SoA>();
821 });
822 if (llvm::any_of(Range&: soaLvls, P: [](LevelType lt) {
823 return !lt.isa<LevelFormat::Singleton>();
824 })) {
825 return emitError() << "SoA is only applicable to singleton lvlTypes.";
826 }
827
828 // TODO: audit formats that actually are supported by backend.
829 if (auto it = llvm::find_if(Range&: lvlTypes, P: isNOutOfMLT);
830 it != std::end(cont&: lvlTypes)) {
831 if (it != lvlTypes.end() - 1)
832 return emitError() << "expected n_out_of_m to be the last level type";
833 if (!std::all_of(first: lvlTypes.begin(), last: it, pred: isDenseLT))
834 return emitError() << "expected all dense lvlTypes "
835 "before a n_out_of_m level";
836 if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
837 if (!isBlockSparsity(dimToLvl)) {
838 return emitError()
839 << "expected 1xm block structure for n_out_of_m level";
840 }
841 auto sizes = getBlockSize(dimToLvl);
842 unsigned coefficient = 0;
843 for (const auto &elem : sizes) {
844 if (elem != 0) {
845 if (elem != coefficient && coefficient != 0) {
846 return emitError() << "expected only one blocked level "
847 "with the same coefficients";
848 }
849 coefficient = elem;
850 }
851 }
852 if (coefficient != getM(lt: *it)) {
853 return emitError() << "expected coeffiencts of Affine expressions "
854 "to be equal to m of n_out_of_m level";
855 }
856 }
857 }
858 // Before we can check that the level-rank is consistent/coherent
859 // across all fields, we need to define it. The source-of-truth for
860 // the `getLvlRank` method is the length of the level-types array,
861 // since it must always be provided and have full rank; therefore we
862 // use that same source-of-truth here.
863 const Level lvlRank = lvlTypes.size();
864 if (lvlRank == 0)
865 return emitError() << "expected a non-empty array for lvlTypes";
866 // We save `dimRank` here because we'll also need it to verify `dimSlices`.
867 const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
868 if (dimToLvl) {
869 if (dimToLvl.getNumResults() != lvlRank)
870 return emitError()
871 << "level-rank mismatch between dimToLvl and lvlTypes: "
872 << dimToLvl.getNumResults() << " != " << lvlRank;
873 auto inferRes = inferLvlToDim(dimToLvl, context: dimToLvl.getContext());
874 // Symbols can't be inferred but are acceptable.
875 if (!inferRes && dimToLvl.getNumSymbols() == 0)
876 return emitError() << "failed to infer lvlToDim from dimToLvl";
877 if (lvlToDim && (inferRes != lvlToDim))
878 return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
879 if (dimRank > lvlRank)
880 return emitError() << "unexpected dimToLvl mapping from " << dimRank
881 << " to " << lvlRank;
882 }
883 if (!dimSlices.empty()) {
884 if (dimSlices.size() != dimRank)
885 return emitError()
886 << "dimension-rank mismatch between dimSlices and dimToLvl: "
887 << dimSlices.size() << " != " << dimRank;
888 // Compiler support for `dimSlices` currently requires that the two
889 // ranks agree. (However, it does allow `dimToLvl` to be a permutation.)
890 if (dimRank != lvlRank)
891 return emitError()
892 << "dimSlices expected dimension-rank to match level-rank: "
893 << dimRank << " != " << lvlRank;
894 }
895 return success();
896}
897
/// Verifies this encoding against the concrete tensor it annotates:
/// first the encoding's own structural invariants (via `verify`), then
/// agreement between the encoding and the tensor's dimension shape and
/// element type, and finally the typed explicit/implicit value attributes.
LogicalResult SparseTensorEncodingAttr::verifyEncoding(
    ArrayRef<Size> dimShape, Type elementType,
    function_ref<InFlightDiagnostic()> emitError) const {
  // Check structural integrity. In particular, this ensures that the
  // level-rank is coherent across all the fields.
  if (failed(Result: verify(emitError, lvlTypes: getLvlTypes(), dimToLvl: getDimToLvl(), lvlToDim: getLvlToDim(),
                    posWidth: getPosWidth(), crdWidth: getCrdWidth(), explicitVal: getExplicitVal(),
                    implicitVal: getImplicitVal(), dimSlices: getDimSlices())))
    return failure();
  // Check integrity with tensor type specifics. In particular, we
  // need only check that the dimension-rank of the tensor agrees with
  // the dimension-rank of the encoding.
  const Dimension dimRank = dimShape.size();
  if (dimRank == 0)
    return emitError() << "expected non-scalar sparse tensor";
  if (getDimRank() != dimRank)
    return emitError()
           << "dimension-rank mismatch between encoding and tensor shape: "
           << getDimRank() << " != " << dimRank;
  // The explicit value, when present, must be typed with the tensor's
  // element type.
  // NOTE(review): assumes the attribute is always a TypedAttr; the dyn_cast
  // would yield null (and crash on getType) otherwise — confirm invariant.
  if (auto expVal = getExplicitVal()) {
    Type attrType = llvm::dyn_cast<TypedAttr>(Val&: expVal).getType();
    if (attrType != elementType) {
      return emitError() << "explicit value type mismatch between encoding and "
                         << "tensor element type: " << attrType
                         << " != " << elementType;
    }
  }
  // Same type check for the implicit value, plus a value check below.
  if (auto impVal = getImplicitVal()) {
    Type attrType = llvm::dyn_cast<TypedAttr>(Val&: impVal).getType();
    if (attrType != elementType) {
      return emitError() << "implicit value type mismatch between encoding and "
                         << "tensor element type: " << attrType
                         << " != " << elementType;
    }
    // Currently, we only support zero as the implicit value.
    // Checked separately for float, integer, and complex attributes.
    auto impFVal = llvm::dyn_cast<FloatAttr>(Val&: impVal);
    auto impIntVal = llvm::dyn_cast<IntegerAttr>(Val&: impVal);
    auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(Val&: impVal);
    if ((impFVal && impFVal.getValue().isNonZero()) ||
        (impIntVal && !impIntVal.getValue().isZero()) ||
        (impComplexVal && (impComplexVal.getImag().isNonZero() ||
                           impComplexVal.getReal().isNonZero()))) {
      return emitError() << "implicit value must be zero";
    }
  }
  return success();
}
945
946Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const {
947 SmallVector<COOSegment> coo = getCOOSegments();
948 assert(coo.size() == 1 || coo.empty());
949 if (!coo.empty() && coo.front().isAoS()) {
950 return coo.front().lvlRange.first;
951 }
952 return getLvlRank();
953}
954
/// Scans the level types and collects all COO segments, i.e. maximal runs
/// of the form: one (loose)compressed level followed by one or more
/// singleton levels. Returns an empty vector for rank <= 1 tensors.
SmallVector<COOSegment>
mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const {
  SmallVector<COOSegment> ret;
  if (getLvlRank() <= 1)
    return ret;

  ArrayRef<LevelType> lts = getLvlTypes();
  Level l = 0;
  while (l < getLvlRank()) {
    auto lt = lts[l];
    if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) {
      // A potential segment start: extend over trailing singleton levels.
      auto cur = lts.begin() + l;
      auto end = std::find_if(first: cur + 1, last: lts.end(), pred: [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      });
      unsigned cooLen = std::distance(first: cur, last: end);
      // Only runs longer than one level form an actual COO segment.
      if (cooLen > 1) {
        // To support mixed SoA/AoS COO, we should break the segment when the
        // storage scheme changes, for now we faithfully assume that all
        // consecutive singleton levels have the same storage format as verified
        // STEA.
        ret.push_back(Elt: COOSegment{.lvlRange: std::make_pair(x&: l, y: l + cooLen),
                                 .isSoA: lts[l + 1].isa<LevelPropNonDefault::SoA>()});
      }
      l += cooLen;
    } else {
      l++;
    }
  }
  return ret;
}
986
987//===----------------------------------------------------------------------===//
988// SparseTensorType Methods.
989//===----------------------------------------------------------------------===//
990
991bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
992 bool isUnique) const {
993 if (!hasEncoding())
994 return false;
995 if (!isCompressedLvl(l: startLvl) && !isLooseCompressedLvl(l: startLvl))
996 return false;
997 for (Level l = startLvl + 1; l < lvlRank; ++l)
998 if (!isSingletonLvl(l))
999 return false;
1000 // If isUnique is true, then make sure that the last level is unique,
1001 // that is, when lvlRank == 1, the only compressed level is unique,
1002 // and when lvlRank > 1, the last singleton is unique.
1003 return !isUnique || isUniqueLvl(l: lvlRank - 1);
1004}
1005
/// Builds the canonical COO tensor type for this sparse tensor type:
/// same shape, element type, maps, widths and values, but with level types
/// replaced by compressed(+singletons) in COO form.
RankedTensorType
mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
  SmallVector<LevelType> lvlTypes;
  lvlTypes.reserve(N: lvlRank);
  // A non-unique compressed level at beginning (unless this is
  // also the last level, then it is unique).
  lvlTypes.push_back(
      Elt: *buildLevelType(lf: LevelFormat::Compressed, ordered, unique: lvlRank == 1));
  if (lvlRank > 1) {
    // Followed by n-2 non-unique singleton levels.
    std::fill_n(first: std::back_inserter(x&: lvlTypes), n: lvlRank - 2,
                value: *buildLevelType(lf: LevelFormat::Singleton, ordered, unique: false));
    // Ends by a unique singleton level.
    lvlTypes.push_back(Elt: *buildLevelType(lf: LevelFormat::Singleton, ordered, unique: true));
  }
  // Reuse all other fields of the current encoding unchanged.
  auto enc = SparseTensorEncodingAttr::get(
      context: getContext(), lvlTypes, dimToLvl: getDimToLvl(), lvlToDim: getLvlToDim(), posWidth: getPosWidth(),
      crdWidth: getCrdWidth(), explicitVal: getExplicitVal(), implicitVal: getImplicitVal());
  return RankedTensorType::get(shape: getDimShape(), elementType: getElementType(), encoding: enc);
}
1026
1027//===----------------------------------------------------------------------===//
1028// Convenience Methods.
1029//===----------------------------------------------------------------------===//
1030
1031SparseTensorEncodingAttr
1032mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
1033 if (auto ttp = llvm::dyn_cast<RankedTensorType>(Val&: type))
1034 return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(Val: ttp.getEncoding());
1035 if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(Val&: type))
1036 return mdtp.getEncoding();
1037 return nullptr;
1038}
1039
1040AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
1041 MLIRContext *context) {
1042 auto map = static_cast<AffineMap>(dimToLvl);
1043 AffineMap lvlToDim;
1044 // Return an empty lvlToDim when inference is not successful.
1045 if (!map || map.getNumSymbols() != 0) {
1046 lvlToDim = AffineMap();
1047 } else if (map.isPermutation()) {
1048 lvlToDim = inversePermutation(map);
1049 } else if (isBlockSparsity(dimToLvl: map)) {
1050 lvlToDim = inverseBlockSparsity(dimToLvl: map, context);
1051 }
1052 return lvlToDim;
1053}
1054
/// Inverts a block-sparsity dimToLvl map (as recognized by isBlockSparsity)
/// into the corresponding lvlToDim map: each pair (d floordiv c, d mod c)
/// becomes the expression d = lvl_i * c + lvl_j.
AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
                                                    MLIRContext *context) {
  SmallVector<AffineExpr> lvlExprs;
  auto numLvls = dimToLvl.getNumResults();
  lvlExprs.reserve(N: numLvls);
  // lvlExprComponents stores information of the floordiv and mod operations
  // applied to the same dimension, so as to build the lvlToDim map.
  std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
  for (unsigned i = 0, n = numLvls; i < n; i++) {
    auto result = dimToLvl.getResult(idx: i);
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(Val&: result)) {
      if (result.getKind() == AffineExprKind::FloorDiv) {
        // Position of the dimension in dimToLvl.
        auto pos = dyn_cast<AffineDimExpr>(Val: binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
               "expected only one floordiv for each dimension");
        SmallVector<AffineExpr, 3> components;
        // Level variable for floordiv.
        components.push_back(Elt: getAffineDimExpr(position: i, context));
        // Multiplier.
        components.push_back(Elt: binOp.getRHS());
        // Map key is the position of the dimension.
        lvlExprComponents[pos] = components;
      } else if (result.getKind() == AffineExprKind::Mod) {
        auto pos = dyn_cast<AffineDimExpr>(Val: binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
               "expected floordiv before mod");
        // Add level variable for mod to the same vector
        // of the corresponding floordiv.
        lvlExprComponents[pos].push_back(Elt: getAffineDimExpr(position: i, context));
      } else {
        // Precondition: only floordiv/mod binary expressions can appear.
        assert(false && "expected floordiv or mod");
      }
    } else {
      // A plain dimension maps straight through to a level variable.
      lvlExprs.push_back(Elt: getAffineDimExpr(position: i, context));
    }
  }
  // Build lvlExprs from lvlExprComponents.
  // For example, for il = i floordiv 2 and ii = i mod 2, the components
  // would be [il, 2, ii]. It could be used to build the AffineExpr
  // i = il * 2 + ii in lvlToDim.
  for (auto &components : lvlExprComponents) {
    assert(components.second.size() == 3 &&
           "expected 3 components to build lvlExprs");
    auto mulOp = getAffineBinaryOpExpr(
        kind: AffineExprKind::Mul, lhs: components.second[0], rhs: components.second[1]);
    auto addOp =
        getAffineBinaryOpExpr(kind: AffineExprKind::Add, lhs: mulOp, rhs: components.second[2]);
    lvlExprs.push_back(Elt: addOp);
  }
  // The inverse has as many dims as dimToLvl has results (i.e. levels).
  return dimToLvl.get(dimCount: dimToLvl.getNumResults(), symbolCount: 0, results: lvlExprs, context);
}
1107
1108SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) {
1109 assert(isBlockSparsity(dimToLvl) &&
1110 "expected dimToLvl to be block sparsity for calling getBlockSize");
1111 SmallVector<unsigned> blockSize;
1112 for (auto result : dimToLvl.getResults()) {
1113 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(Val&: result)) {
1114 if (result.getKind() == AffineExprKind::Mod) {
1115 blockSize.push_back(
1116 Elt: dyn_cast<AffineConstantExpr>(Val: binOp.getRHS()).getValue());
1117 }
1118 } else {
1119 blockSize.push_back(Elt: 0);
1120 }
1121 }
1122 return blockSize;
1123}
1124
/// Returns true iff `dimToLvl` is a block-sparsity map: every result is
/// either a plain dimension, `dim floordiv c`, or `dim mod c` with a
/// positive constant `c`, where each mod is preceded by a floordiv on the
/// same dimension with the same coefficient, and at least one such
/// floordiv/mod pair exists.
bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
  if (!dimToLvl)
    return false;
  // Maps each dimension position to its floordiv coefficient (0 for a plain
  // dimension result).
  std::map<unsigned, int64_t> coeffientMap;
  bool hasBlock = false;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(Val&: result)) {
      // Check for "dim op const".
      auto dimOp = dyn_cast<AffineDimExpr>(Val: binOp.getLHS());
      auto conOp = dyn_cast<AffineConstantExpr>(Val: binOp.getRHS());
      if (!dimOp || !conOp || conOp.getValue() <= 0)
        return false;
      // Inspect "dim / const" or "dim % const".
      auto pos = dimOp.getPosition();
      if (binOp.getKind() == AffineExprKind::FloorDiv) {
        // Expect only one floordiv for each dimension.
        auto [it, inserted] = coeffientMap.try_emplace(k: pos);
        if (!inserted)
          return false;
        // Record coefficient of the floordiv.
        it->second = conOp.getValue();
      } else if (binOp.getKind() == AffineExprKind::Mod) {
        // Expect floordiv before mod.
        auto it = coeffientMap.find(x: pos);
        if (it == coeffientMap.end())
          return false;
        // Expect mod to have the same coefficient as floordiv.
        if (conOp.getValue() != it->second)
          return false;
        hasBlock = true;
      } else {
        // Any other binary expression disqualifies the map.
        return false;
      }
    } else if (auto dimOp = dyn_cast<AffineDimExpr>(Val&: result)) {
      auto pos = dimOp.getPosition();
      // Expect dim to be unset.
      if (!coeffientMap.try_emplace(k: pos, args: 0).second)
        return false;
    } else {
      // Neither a binary expression nor a dimension: not block sparsity.
      return false;
    }
  }
  // At least one floordiv/mod pair must have been seen.
  return hasBlock;
}
1169
1170bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
1171 auto hasNonIdentityMap = [](Value v) {
1172 auto stt = tryGetSparseTensorType(val: v);
1173 return stt && !stt->isIdentity();
1174 };
1175
1176 return llvm::any_of(Range: op->getOperands(), P: hasNonIdentityMap) ||
1177 llvm::any_of(Range: op->getResults(), P: hasNonIdentityMap);
1178}
1179
1180Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
1181 if (enc) {
1182 assert(enc.isPermutation() && "Non permutation map not supported");
1183 if (const auto dimToLvl = enc.getDimToLvl())
1184 return dimToLvl.getDimPosition(idx: l);
1185 }
1186 return l;
1187}
1188
1189Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
1190 if (enc) {
1191 assert(enc.isPermutation() && "Non permutation map not supported");
1192 if (const auto lvlToDim = enc.getLvlToDim())
1193 return lvlToDim.getDimPosition(idx: d);
1194 }
1195 return d;
1196}
1197
1198/// We normalized sparse tensor encoding attribute by always using
1199/// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well
1200/// as other variants) lead to the same storage specifier type, and stripping
1201/// irrelevant fields that do not alter the sparse tensor memory layout.
1202static SparseTensorEncodingAttr
1203getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
1204 SmallVector<LevelType> lts;
1205 for (auto lt : enc.getLvlTypes())
1206 lts.push_back(Elt: lt.stripStorageIrrelevantProperties());
1207
1208 return SparseTensorEncodingAttr::get(
1209 context: enc.getContext(), lvlTypes: lts,
1210 dimToLvl: AffineMap(), // dimToLvl (irrelevant to storage specifier)
1211 lvlToDim: AffineMap(), // lvlToDim (irrelevant to storage specifier)
1212 // Always use `index` for memSize and lvlSize instead of reusing
1213 // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
1214 // value for different bitwidth, it also avoids casting between index and
1215 // integer (returned by DimOp)
1216 posWidth: 0, crdWidth: 0,
1217 explicitVal: Attribute(), // explicitVal (irrelevant to storage specifier)
1218 implicitVal: Attribute(), // implicitVal (irrelevant to storage specifier)
1219 dimSlices: enc.getDimSlices());
1220}
1221
/// Uniques a storage specifier type from a sparse tensor encoding,
/// normalizing the encoding first so that layout-equivalent encodings
/// map to the same specifier type.
StorageSpecifierType
StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
  return Base::get(ctx, args: getNormalizedEncodingForSpecifier(enc: encoding));
}
1226
/// Checked variant of StorageSpecifierType::get: emits a diagnostic via
/// `emitError` on verification failure instead of asserting.
StorageSpecifierType
StorageSpecifierType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                                 MLIRContext *ctx,
                                 SparseTensorEncodingAttr encoding) {
  return Base::getChecked(emitErrorFn: emitError, ctx,
                          args: getNormalizedEncodingForSpecifier(enc: encoding));
}
1234
1235//===----------------------------------------------------------------------===//
1236// SparseTensorDialect Operations.
1237//===----------------------------------------------------------------------===//
1238
1239static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
1240 return success(IsSuccess: lvl < getSparseTensorType(val: tensor).getLvlRank());
1241}
1242
1243static LogicalResult isMatchingWidth(Value mem, unsigned width) {
1244 const Type etp = getMemRefType(t&: mem).getElementType();
1245 return success(IsSuccess: width == 0 ? etp.isIndex() : etp.isInteger(width));
1246}
1247
/// Shared verifier for storage specifier getter/setter ops: checks that the
/// (kind, level) combination is meaningful for the specifier's encoding.
static LogicalResult verifySparsifierGetterSetter(
    StorageSpecifierKind mdKind, std::optional<Level> lvl,
    TypedValue<StorageSpecifierType> md, Operation *op) {
  // The value memory size is a per-tensor quantity: no level allowed.
  if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
    return op->emitError(
        message: "redundant level argument for querying value memory size");
  }

  const auto enc = md.getType().getEncoding();
  const Level lvlRank = enc.getLvlRank();

  // Slice metadata only exists on slice tensors.
  if (mdKind == StorageSpecifierKind::DimOffset ||
      mdKind == StorageSpecifierKind::DimStride)
    if (!enc.isSlice())
      return op->emitError(message: "requested slice data on non-slice tensor");

  // All other kinds are per-level quantities: a valid level is required.
  if (mdKind != StorageSpecifierKind::ValMemSize) {
    if (!lvl)
      return op->emitError(message: "missing level argument");

    const Level l = lvl.value();
    if (l >= lvlRank)
      return op->emitError(message: "requested level is out of bounds");

    // Singleton levels have no positions array, hence no position size.
    if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
      return op->emitError(
          message: "requested position memory size on a singleton level");
  }
  return success();
}
1278
1279static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
1280 switch (kind) {
1281 case SparseTensorFieldKind::CrdMemRef:
1282 return stt.getCrdType();
1283 case SparseTensorFieldKind::PosMemRef:
1284 return stt.getPosType();
1285 case SparseTensorFieldKind::ValMemRef:
1286 return stt.getElementType();
1287 case SparseTensorFieldKind::StorageSpec:
1288 return nullptr;
1289 }
1290 llvm_unreachable("Unrecognizable FieldKind");
1291}
1292
/// Shared verifier for AssembleOp/DisassembleOp: checks that the flat
/// buffers (`valTp` plus one tensor per storage field in `lvlTps`) are
/// consistent with the sparse tensor type `stt`.
static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
                                      SparseTensorType stt,
                                      RankedTensorType valTp,
                                      TypeRange lvlTps) {
  if (requiresStaticShape && !stt.hasStaticDimShape())
    return op->emitError(message: "the sparse-tensor must have static shape");
  if (!stt.hasEncoding())
    return op->emitError(message: "the sparse-tensor must have an encoding attribute");

  // Verifies the trailing COO.
  Level cooStartLvl = stt.getAoSCOOStart();
  if (cooStartLvl < stt.getLvlRank()) {
    // We only supports trailing COO for now, must be the last input.
    auto cooTp = llvm::cast<ShapedType>(Val: lvlTps.back());
    // The coordinates should be in shape of <? x rank>
    unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
    if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
      return op->emitError(message: "input/output trailing COO level-ranks don't match");
    }
  }

  // Verifies that all types match.
  StorageLayout layout(stt.getEncoding());
  if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
    return op->emitError(message: "inconsistent number of fields between input/output");

  // Walk every storage field and compare the supplied tensor's element type
  // against the expected element type for that field kind.
  unsigned idx = 0;
  bool misMatch = false;
  layout.foreachField(callback: [&idx, &misMatch, stt, valTp,
                       lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
                               Level lvl, LevelType lt) -> bool {
    if (fKind == SparseTensorFieldKind::StorageSpec)
      return true;

    Type inputTp = nullptr;
    if (fKind == SparseTensorFieldKind::ValMemRef) {
      inputTp = valTp;
    } else {
      assert(fid == idx && stt.getLvlType(lvl) == lt);
      inputTp = lvlTps[idx++];
    }
    // The input element type and expected element type should match.
    Type inpElemTp = llvm::cast<TensorType>(Val&: inputTp).getElementType();
    Type expElemTp = getFieldElemType(stt, kind: fKind);
    if (inpElemTp != expElemTp) {
      misMatch = true;
      return false; // to terminate the iteration
    }
    return true;
  });

  if (misMatch)
    return op->emitError(message: "input/output element-types don't match");
  return success();
}
1348
1349LogicalResult AssembleOp::verify() {
1350 RankedTensorType valuesTp = getValues().getType();
1351 const auto lvlsTp = getLevels().getTypes();
1352 const auto resTp = getSparseTensorType(val: getResult());
1353 return verifyPackUnPack(op: *this, requiresStaticShape: true, stt: resTp, valTp: valuesTp, lvlTps: lvlsTp);
1354}
1355
1356LogicalResult DisassembleOp::verify() {
1357 if (getOutValues().getType() != getRetValues().getType())
1358 return emitError(message: "output values and return value type mismatch");
1359
1360 for (auto [ot, rt] : llvm::zip_equal(t: getOutLevels(), u: getRetLevels()))
1361 if (ot.getType() != rt.getType())
1362 return emitError(message: "output levels and return levels type mismatch");
1363
1364 RankedTensorType valuesTp = getRetValues().getType();
1365 const auto lvlsTp = getRetLevels().getTypes();
1366 const auto srcTp = getSparseTensorType(val: getTensor());
1367 return verifyPackUnPack(op: *this, requiresStaticShape: false, stt: srcTp, valTp: valuesTp, lvlTps: lvlsTp);
1368}
1369
1370LogicalResult ConvertOp::verify() {
1371 RankedTensorType tp1 = getSource().getType();
1372 RankedTensorType tp2 = getDest().getType();
1373 if (tp1.getRank() != tp2.getRank())
1374 return emitError(message: "unexpected conversion mismatch in rank");
1375 auto dstEnc =
1376 llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(Val: tp2.getEncoding());
1377 if (dstEnc && dstEnc.isSlice())
1378 return emitError(message: "cannot convert to a sparse tensor slice");
1379
1380 auto shape1 = tp1.getShape();
1381 auto shape2 = tp2.getShape();
1382 // Accept size matches between the source and the destination type
1383 // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
1384 // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
1385 for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
1386 if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
1387 return emitError(message: "unexpected conversion mismatch in dimension ") << d;
1388 return success();
1389}
1390
1391OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
1392 if (getType() == getSource().getType())
1393 return getSource();
1394 return {};
1395}
1396
/// Returns true iff lowering this conversion requires an explicit sort of
/// the coordinates before producing the (ordered) destination tensor.
bool ConvertOp::needsExtraSort() {
  SparseTensorType srcStt = getSparseTensorType(val: getSource());
  SparseTensorType dstStt = getSparseTensorType(val: getDest());

  // We do not need an extra sort when returning unordered sparse tensors or
  // dense tensor since dense tensor support random access.
  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
    return false;

  // Same ordering on both sides with the same dimToLvl needs no sort either.
  if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
      srcStt.hasSameDimToLvl(other: dstStt)) {
    return false;
  }

  // Source and dest tensors are ordered in different ways. We only do direct
  // dense to sparse conversion when the dense input is defined by a sparse
  // constant. Note that we can theoretically always directly convert from dense
  // inputs by rotating dense loops but it leads to bad cache locality and hurt
  // performance.
  if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
    if (isa<SparseElementsAttr>(Val: constOp.getValue()))
      return false;

  return true;
}
1422
1423LogicalResult CrdTranslateOp::verify() {
1424 uint64_t inRank = getEncoder().getLvlRank();
1425 uint64_t outRank = getEncoder().getDimRank();
1426
1427 if (getDirection() == CrdTransDirectionKind::dim2lvl)
1428 std::swap(a&: inRank, b&: outRank);
1429
1430 if (inRank != getInCrds().size() || outRank != getOutCrds().size())
1431 return emitError(message: "Coordinate rank mismatch with encoding");
1432
1433 return success();
1434}
1435
/// Folds coordinate translations: identity maps fold to the inputs,
/// permutations fold to a reordering of the inputs, and a dim2lvl/lvl2dim
/// pair over the same map cancels out.
LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
                                   SmallVectorImpl<OpFoldResult> &results) {
  // Identity mapping: coordinates pass through unchanged.
  if (getEncoder().isIdentity()) {
    results.assign(in_start: getInCrds().begin(), in_end: getInCrds().end());
    return success();
  }
  // Permutation: reorder the input coordinates accordingly.
  if (getEncoder().isPermutation()) {
    AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
                         ? getEncoder().getDimToLvl()
                         : getEncoder().getLvlToDim();
    for (AffineExpr exp : perm.getResults())
      results.push_back(Elt: getInCrds()[cast<AffineDimExpr>(Val&: exp).getPosition()]);
    return success();
  }

  // Fuse dim2lvl/lvl2dim pairs.
  // All input coordinates must come from one and the same CrdTranslateOp.
  auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
  bool sameDef = def && llvm::all_of(Range: getInCrds(), P: [def](Value v) {
                   return v.getDefiningOp() == def;
                 });
  if (!sameDef)
    return failure();

  // The pair must translate in opposite directions over the same dimToLvl
  // map and carry the same number of coordinates.
  bool oppositeDir = def.getDirection() != getDirection();
  bool sameOracle =
      def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
  bool sameCount = def.getNumResults() == getInCrds().size();
  if (!oppositeDir || !sameOracle || !sameCount)
    return failure();

  // The definition produces the coordinates in the same order as the input
  // coordinates.
  bool sameOrder = llvm::all_of(Range: llvm::zip_equal(t: def.getOutCrds(), u: getInCrds()),
                                P: [](auto valuePair) {
                                  auto [lhs, rhs] = valuePair;
                                  return lhs == rhs;
                                });

  if (!sameOrder)
    return failure();
  // l1 = dim2lvl (lvl2dim l0)
  // ==> l0
  results.append(in_start: def.getInCrds().begin(), in_end: def.getInCrds().end());
  return success();
}
1481
1482void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
1483 int64_t index) {
1484 Value val = builder.create<arith::ConstantIndexOp>(location: state.location, args&: index);
1485 return build(odsBuilder&: builder, odsState&: state, source, index: val);
1486}
1487
1488LogicalResult LvlOp::verify() {
1489 if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
1490 auto stt = getSparseTensorType(val: getSource());
1491 if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
1492 return emitError(
1493 message: "Level index exceeds the rank of the input sparse tensor");
1494 }
1495 return success();
1496}
1497
/// Returns the level operand as a compile-time constant, if it is one.
std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
  return getConstantIntValue(ofr: getIndex());
}
1501
1502Speculation::Speculatability LvlOp::getSpeculatability() {
1503 auto constantIndex = getConstantLvlIndex();
1504 if (!constantIndex)
1505 return Speculation::NotSpeculatable;
1506
1507 assert(constantIndex <
1508 cast<RankedTensorType>(getSource().getType()).getRank());
1509 return Speculation::Speculatable;
1510}
1511
/// Folds a level-size query with a constant level index to an index
/// attribute when the corresponding level size is statically known.
OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
  auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(Val: adaptor.getIndex());
  if (!lvlIndex)
    return {};

  Level lvl = lvlIndex.getAPSInt().getZExtValue();
  auto stt = getSparseTensorType(val: getSource());
  if (lvl >= stt.getLvlRank()) {
    // Follows the same convention used by tensor.dim operation. Out of bound
    // indices produce undefined behavior but are still valid IR. Don't choke on
    // them.
    return {};
  }

  // Helper lambda to build an IndexAttr.
  auto getIndexAttr = [this](int64_t lvlSz) {
    return IntegerAttr::get(type: IndexType::get(context: getContext()), value: APInt(64, lvlSz));
  };

  // Only statically-known level sizes can be folded to a constant.
  SmallVector<Size> lvlShape = stt.getLvlShape();
  if (ShapedType::isStatic(dValue: lvlShape[lvl]))
    return getIndexAttr(lvlShape[lvl]);

  return {};
}
1537
1538void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1539 SparseTensorEncodingAttr dstEnc, Value source) {
1540 auto srcStt = getSparseTensorType(val: source);
1541 SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
1542 SmallVector<int64_t> dstDimShape =
1543 dstEnc.translateShape(srcShape: srcLvlShape, dir: CrdTransDirectionKind::lvl2dim);
1544 auto dstTp =
1545 RankedTensorType::get(shape: dstDimShape, elementType: srcStt.getElementType(), encoding: dstEnc);
1546 return build(odsBuilder, odsState, dest: dstTp, source);
1547}
1548
/// Verifies that a reinterpret_map only changes the dim-to-lvl mapping:
/// level types, widths, element type, and level sizes must all agree
/// between source and destination.
LogicalResult ReinterpretMapOp::verify() {
  auto srcStt = getSparseTensorType(val: getSource());
  auto dstStt = getSparseTensorType(val: getDest());
  ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes();
  ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes();

  // Same number of levels on both sides.
  if (srcLvlTps.size() != dstLvlTps.size())
    return emitError(message: "Level rank mismatch between source/dest tensors");

  // Level types must match pairwise.
  for (auto [srcLvlTp, dstLvlTp] : llvm::zip(t&: srcLvlTps, u&: dstLvlTps))
    if (srcLvlTp != dstLvlTp)
      return emitError(message: "Level type mismatch between source/dest tensors");

  // Position/coordinate bitwidths must match.
  if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
      srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
    return emitError(message: "Crd/Pos width mismatch between source/dest tensors");
  }

  if (srcStt.getElementType() != dstStt.getElementType())
    return emitError(message: "Element type mismatch between source/dest tensors");

  SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
  SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
  for (auto [srcLvlSz, dstLvlSz] : llvm::zip(t&: srcLvlShape, u&: dstLvlShape)) {
    if (srcLvlSz != dstLvlSz) {
      // Should we allow one side to be dynamic size, e.g., <?x?> should be
      // compatible to <3x4>? For now, we require all the level sizes to be
      // *exactly* matched for simplicity.
      return emitError(message: "Level size mismatch between source/dest tensors");
    }
  }

  return success();
}
1583
1584OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1585 if (getSource().getType() == getDest().getType())
1586 return getSource();
1587
1588 if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1589 // A -> B, B -> A ==> A
1590 if (def.getSource().getType() == getDest().getType())
1591 return def.getSource();
1592 }
1593 return {};
1594}
1595
/// Shared return-type inference for the to_positions/to_coordinates/
/// to_coordinates_buffer/to_values ops: the result is a memref of the
/// batch level shape extended by one dynamic dimension, with the element
/// type selected by the op kind; to_coordinates on an AoS COO level gets a
/// dynamic strided layout.
template <typename ToBufferOp>
static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
                                           OpaqueProperties prop,
                                           RegionRange region,
                                           SmallVectorImpl<mlir::Type> &ret) {
  typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
  Type elemTp = nullptr;
  bool withStride = false;
  // Select the element type (and stride requirement) per op kind.
  if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
    elemTp = stt.getPosType();
  } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
                       std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
    elemTp = stt.getCrdType();
    // Coordinates of a level inside an AoS COO region are interleaved,
    // hence the strided layout.
    if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
      withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
  } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
    elemTp = stt.getElementType();
  }

  assert(elemTp && "unhandled operation.");
  SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
  bufShape.push_back(Elt: ShapedType::kDynamic);

  auto layout = withStride ? StridedLayoutAttr::StridedLayoutAttr::get(
                                 context: stt.getContext(), offset: ShapedType::kDynamic,
                                 strides: {ShapedType::kDynamic})
                           : StridedLayoutAttr();
  ret.emplace_back(Args: MemRefType::get(shape: bufShape, elementType: elemTp, layout));
  return success();
}
1627
// Verifies that the requested level exists and that the result memref's
// element width matches the tensor's position bit-width.
LogicalResult ToPositionsOp::verify() {
  auto stt = getSparseTensorType(val: getTensor());
  if (failed(Result: lvlIsInBounds(lvl: getLevel(), tensor: getTensor())))
    return emitError(message: "requested level is out of bounds");
  if (failed(Result: isMatchingWidth(mem: getResult(), width: stt.getPosWidth())))
    return emitError(message: "unexpected type for positions");
  return success();
}
1636
// Delegates to the shared sparse-buffer type inference helper.
LogicalResult
ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                                ValueRange ops, DictionaryAttr attr,
                                OpaqueProperties prop, RegionRange region,
                                SmallVectorImpl<mlir::Type> &ret) {
  return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
}
1644
// Verifies that the requested level exists and that the result memref's
// element width matches the tensor's coordinate bit-width.
LogicalResult ToCoordinatesOp::verify() {
  auto stt = getSparseTensorType(val: getTensor());
  if (failed(Result: lvlIsInBounds(lvl: getLevel(), tensor: getTensor())))
    return emitError(message: "requested level is out of bounds");
  if (failed(Result: isMatchingWidth(mem: getResult(), width: stt.getCrdWidth())))
    return emitError(message: "unexpected type for coordinates");
  return success();
}
1653
// Delegates to the shared sparse-buffer type inference helper.
LogicalResult
ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                                  ValueRange ops, DictionaryAttr attr,
                                  OpaqueProperties prop, RegionRange region,
                                  SmallVectorImpl<mlir::Type> &ret) {
  return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
}
1661
// Verifies the tensor actually has an AoS COO region; otherwise there is no
// combined coordinates buffer to extract.
LogicalResult ToCoordinatesBufferOp::verify() {
  auto stt = getSparseTensorType(val: getTensor());
  if (stt.getAoSCOOStart() >= stt.getLvlRank())
    return emitError(message: "expected sparse tensor with a COO region");
  return success();
}
1668
// Delegates to the shared sparse-buffer type inference helper.
LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
    MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
    DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
    SmallVectorImpl<mlir::Type> &ret) {
  return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
                                                      ret);
}
1676
// Verifies that the result memref's element type matches the sparse tensor's
// element type.
LogicalResult ToValuesOp::verify() {
  auto stt = getSparseTensorType(val: getTensor());
  auto mtp = getMemRefType(t: getResult());
  if (stt.getElementType() != mtp.getElementType())
    return emitError(message: "unexpected mismatch in element types");
  return success();
}
1684
// Delegates to the shared sparse-buffer type inference helper.
LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
                                           std::optional<Location> loc,
                                           ValueRange ops, DictionaryAttr attr,
                                           OpaqueProperties prop,
                                           RegionRange region,
                                           SmallVectorImpl<mlir::Type> &ret) {
  return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
}
1693
1694LogicalResult ToSliceOffsetOp::verify() {
1695 auto rank = getSlice().getType().getRank();
1696 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1697 return emitError(message: "requested dimension out of bound");
1698 return success();
1699}
1700
1701LogicalResult ToSliceStrideOp::verify() {
1702 auto rank = getSlice().getType().getRank();
1703 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1704 return emitError(message: "requested dimension out of bound");
1705 return success();
1706}
1707
// Verifies the (kind, level) pair is valid for the specifier operand.
LogicalResult GetStorageSpecifierOp::verify() {
  return verifySparsifierGetterSetter(mdKind: getSpecifierKind(), lvl: getLevel(),
                                      md: getSpecifier(), op: getOperation());
}
1712
// Returns the set_storage_specifier op defining `op`'s specifier operand,
// or null if the specifier is not produced by one.
template <typename SpecifierOp>
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
  return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
}
1717
1718OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1719 const StorageSpecifierKind kind = getSpecifierKind();
1720 const auto lvl = getLevel();
1721 for (auto op = getSpecifierSetDef(op: *this); op; op = getSpecifierSetDef(op))
1722 if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1723 return op.getValue();
1724 return {};
1725}
1726
// Verifies the (kind, level) pair is valid for the specifier operand.
LogicalResult SetStorageSpecifierOp::verify() {
  return verifySparsifierGetterSetter(mdKind: getSpecifierKind(), lvl: getLevel(),
                                      md: getSpecifier(), op: getOperation());
}
1731
1732template <class T>
1733static LogicalResult verifyNumBlockArgs(T *op, Region &region,
1734 const char *regionName,
1735 TypeRange inputTypes, Type outputType) {
1736 unsigned numArgs = region.getNumArguments();
1737 unsigned expectedNum = inputTypes.size();
1738 if (numArgs != expectedNum)
1739 return op->emitError() << regionName << " region must have exactly "
1740 << expectedNum << " arguments";
1741
1742 for (unsigned i = 0; i < numArgs; i++) {
1743 Type typ = region.getArgument(i).getType();
1744 if (typ != inputTypes[i])
1745 return op->emitError() << regionName << " region argument " << (i + 1)
1746 << " type mismatch";
1747 }
1748 Operation *term = region.front().getTerminator();
1749 YieldOp yield = dyn_cast<YieldOp>(Val: term);
1750 if (!yield)
1751 return op->emitError() << regionName
1752 << " region must end with sparse_tensor.yield";
1753 if (!yield.hasSingleResult() ||
1754 yield.getSingleResult().getType() != outputType)
1755 return op->emitError() << regionName << " region yield type mismatch";
1756
1757 return success();
1758}
1759
1760LogicalResult BinaryOp::verify() {
1761 NamedAttrList attrs = (*this)->getAttrs();
1762 Type leftType = getX().getType();
1763 Type rightType = getY().getType();
1764 Type outputType = getOutput().getType();
1765 Region &overlap = getOverlapRegion();
1766 Region &left = getLeftRegion();
1767 Region &right = getRightRegion();
1768
1769 // Check correct number of block arguments and return type for each
1770 // non-empty region.
1771 if (!overlap.empty()) {
1772 if (failed(Result: verifyNumBlockArgs(op: this, region&: overlap, regionName: "overlap",
1773 inputTypes: TypeRange{leftType, rightType}, outputType)))
1774 return failure();
1775 }
1776 if (!left.empty()) {
1777 if (failed(Result: verifyNumBlockArgs(op: this, region&: left, regionName: "left", inputTypes: TypeRange{leftType},
1778 outputType)))
1779 return failure();
1780 } else if (getLeftIdentity()) {
1781 if (leftType != outputType)
1782 return emitError(message: "left=identity requires first argument to have the same "
1783 "type as the output");
1784 }
1785 if (!right.empty()) {
1786 if (failed(Result: verifyNumBlockArgs(op: this, region&: right, regionName: "right", inputTypes: TypeRange{rightType},
1787 outputType)))
1788 return failure();
1789 } else if (getRightIdentity()) {
1790 if (rightType != outputType)
1791 return emitError(message: "right=identity requires second argument to have the "
1792 "same type as the output");
1793 }
1794 return success();
1795}
1796
// Verifies the present/absent regions of a sparse unary op, and that the
// absent branch only yields values that are invariant w.r.t. the surrounding
// computation.
LogicalResult UnaryOp::verify() {
  Type inputType = getX().getType();
  Type outputType = getOutput().getType();

  // Check correct number of block arguments and return type for each
  // non-empty region.
  Region &present = getPresentRegion();
  if (!present.empty()) {
    if (failed(Result: verifyNumBlockArgs(op: this, region&: present, regionName: "present",
                                  inputTypes: TypeRange{inputType}, outputType)))
      return failure();
  }
  Region &absent = getAbsentRegion();
  if (!absent.empty()) {
    if (failed(Result: verifyNumBlockArgs(op: this, region&: absent, regionName: "absent", inputTypes: TypeRange{},
                                  outputType)))
      return failure();
    // Absent branch can only yield invariant values.
    Block *absentBlock = &absent.front();
    Block *parent = getOperation()->getBlock();
    Value absentVal =
        cast<YieldOp>(Val: absentBlock->getTerminator()).getSingleResult();
    // Yielding a block argument of the enclosing (linalg) block would make
    // the absent value depend on the current iteration.
    if (auto arg = dyn_cast<BlockArgument>(Val&: absentVal)) {
      if (arg.getOwner() == parent)
        return emitError(message: "absent region cannot yield linalg argument");
    } else if (Operation *def = absentVal.getDefiningOp()) {
      // Non-constant values computed inside the absent block or the parent
      // block are likewise iteration-dependent and rejected.
      if (!isa<arith::ConstantOp>(Val: def) &&
          (def->getBlock() == absentBlock || def->getBlock() == parent))
        return emitError(message: "absent region cannot yield locally computed value");
    }
  }
  return success();
}
1830
1831bool ConcatenateOp::needsExtraSort() {
1832 SparseTensorType dstStt = getSparseTensorType(val: *this);
1833 if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1834 return false;
1835
1836 bool allSameOrdered = llvm::all_of(Range: getInputs(), P: [dstStt](Value op) {
1837 return getSparseTensorType(val: op).hasSameDimToLvl(other: dstStt);
1838 });
1839 // TODO: When conDim != 0, as long as conDim corresponding to the first level
1840 // in all input/output buffers, and all input/output buffers have the same
1841 // dimToLvl, the tmp COO buffer is still unnecessary (e.g, concatenate
1842 // CSC matrices along column).
1843 bool directLowerable =
1844 allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
1845 return !directLowerable;
1846}
1847
1848LogicalResult ConcatenateOp::verify() {
1849 const auto dstTp = getSparseTensorType(val: *this);
1850 const Dimension concatDim = getDimension();
1851 const Dimension dimRank = dstTp.getDimRank();
1852
1853 if (getInputs().size() <= 1)
1854 return emitError(message: "Need at least two tensors to concatenate.");
1855
1856 if (concatDim >= dimRank)
1857 return emitError(message: llvm::formatv(
1858 Fmt: "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1859 Vals: concatDim, Vals: dimRank));
1860
1861 for (const auto &it : llvm::enumerate(First: getInputs())) {
1862 const auto i = it.index();
1863 const auto srcTp = getSparseTensorType(val: it.value());
1864 if (srcTp.hasDynamicDimShape())
1865 return emitError(message: llvm::formatv(Fmt: "Input tensor ${0} has dynamic shape", Vals: i));
1866 const Dimension srcDimRank = srcTp.getDimRank();
1867 if (srcDimRank != dimRank)
1868 return emitError(
1869 message: llvm::formatv(Fmt: "Input tensor ${0} has a different rank (rank={1}) "
1870 "from the output tensor (rank={2}).",
1871 Vals: i, Vals: srcDimRank, Vals: dimRank));
1872 }
1873
1874 for (Dimension d = 0; d < dimRank; d++) {
1875 const Size dstSh = dstTp.getDimShape()[d];
1876 if (d == concatDim) {
1877 if (ShapedType::isStatic(dValue: dstSh)) {
1878 // If we reach here, then all inputs have static shapes. So we
1879 // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
1880 // to avoid redundant assertions in the loop.
1881 Size sumSz = 0;
1882 for (const auto src : getInputs())
1883 sumSz += getSparseTensorType(val: src).getDimShape()[d];
1884 // If all dimension are statically known, the sum of all the input
1885 // dimensions should be equal to the output dimension.
1886 if (sumSz != dstSh)
1887 return emitError(
1888 message: "The concatenation dimension of the output tensor should be the "
1889 "sum of all the concatenation dimensions of the input tensors.");
1890 }
1891 } else {
1892 Size prev = dstSh;
1893 for (const auto src : getInputs()) {
1894 const auto sh = getSparseTensorType(val: src).getDimShape()[d];
1895 if (ShapedType::isStatic(dValue: prev) && sh != prev)
1896 return emitError(message: "All dimensions (expect for the concatenating one) "
1897 "should be equal.");
1898 prev = sh;
1899 }
1900 }
1901 }
1902
1903 return success();
1904}
1905
// Convenience builder for pushing a single value: forwards to the full
// builder with an absent `n` operand.
void PushBackOp::build(OpBuilder &builder, OperationState &result,
                       Value curSize, Value inBuffer, Value value) {
  build(odsBuilder&: builder, odsState&: result, curSize, inBuffer, value, n: Value());
}
1910
// Verifies that a constant repetition count `n`, when present, is >= 1.
// Dynamic (non-constant) `n` is accepted and checked at runtime, if at all.
LogicalResult PushBackOp::verify() {
  if (Value n = getN()) {
    std::optional<int64_t> nValue = getConstantIntValue(ofr: n);
    if (nValue && nValue.value() < 1)
      return emitOpError(message: "n must be not less than 1");
  }
  return success();
}
1919
// Verifies that exactly lvlRank - 1 coordinates are provided: the compress
// operates along the single remaining (innermost) level.
LogicalResult CompressOp::verify() {
  const auto stt = getSparseTensorType(val: getTensor());
  if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
    return emitOpError(message: "incorrect number of coordinates");
  return success();
}
1926
// Builds a ForeachOp and, when `bodyBuilder` is provided, creates its body
// block whose arguments are: dimRank coordinates, then the element value,
// then the reduction (init) variables.
void ForeachOp::build(
    OpBuilder &builder, OperationState &result, Value tensor,
    ValueRange initArgs, AffineMapAttr order,
    function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
        bodyBuilder) {
  build(odsBuilder&: builder, odsState&: result, results: initArgs.getTypes(), tensor, initArgs, order);
  // Builds foreach body.
  if (!bodyBuilder)
    return;
  const auto stt = getSparseTensorType(val: tensor);
  const Dimension dimRank = stt.getDimRank();

  // Starts with `dimRank`-many coordinates.
  SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
  // Followed by one value.
  blockArgTypes.push_back(Elt: stt.getElementType());
  // Followed by the reduction variables.
  blockArgTypes.append(in_start: initArgs.getTypes().begin(), in_end: initArgs.getTypes().end());

  SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());

  OpBuilder::InsertionGuard guard(builder);
  auto &region = *result.regions.front();
  Block *bodyBlock =
      builder.createBlock(parent: &region, insertPt: region.end(), argTypes: blockArgTypes, locs: blockArgLocs);
  // Invoke the user callback with (coords, value, reduction args) slices.
  bodyBuilder(builder, result.location,
              bodyBlock->getArguments().slice(N: 0, M: dimRank),
              bodyBlock->getArguments()[dimRank],
              bodyBlock->getArguments().drop_front(N: dimRank + 1));
}
1957
// Verifies the ForeachOp: optional traversal order must cover the level rank,
// the body block must carry dimRank coordinates + one value + the init args,
// and results/yields must mirror the init argument types.
LogicalResult ForeachOp::verify() {
  const auto t = getSparseTensorType(val: getTensor());
  const Dimension dimRank = t.getDimRank();
  const auto args = getBody()->getArguments();

  if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
    return emitError(message: "Level traverse order does not match tensor's level rank");

  if (dimRank + 1 + getInitArgs().size() != args.size())
    return emitError(message: "Unmatched number of arguments in the block");

  if (getNumResults() != getInitArgs().size())
    return emitError(message: "Mismatch in number of init arguments and results");

  if (getResultTypes() != getInitArgs().getTypes())
    return emitError(message: "Mismatch in types of init arguments and results");

  // Cannot mark this const, because the getters aren't.
  auto yield = cast<YieldOp>(Val: getBody()->getTerminator());
  if (yield.getNumOperands() != getNumResults() ||
      yield.getOperands().getTypes() != getResultTypes())
    return emitError(message: "Mismatch in types of yield values and results");

  // Leading block arguments are coordinates and must all be of index type.
  const auto iTp = IndexType::get(context: getContext());
  for (Dimension d = 0; d < dimRank; d++)
    if (args[d].getType() != iTp)
      return emitError(
          message: llvm::formatv(Fmt: "Expecting Index type for argument at index {0}", Vals&: d));

  // The value argument (right after the coordinates) must match the tensor's
  // element type.
  const auto elemTp = t.getElementType();
  const auto valueTp = args[dimRank].getType();
  if (elemTp != valueTp)
    return emitError(
        message: llvm::formatv(Fmt: "Unmatched element type between input tensor and "
                      "block argument, expected:{0}, got: {1}",
                      Vals: elemTp, Vals: valueTp));
  return success();
}
1996
1997OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
1998 if (getSparseTensorEncoding(type: getInputCoo().getType()) ==
1999 getSparseTensorEncoding(type: getResultCoo().getType()))
2000 return getInputCoo();
2001
2002 return {};
2003}
2004
// Verifies a COO reordering: both operands must be COO tensors with the same
// dim-to-lvl mapping and identical storage (pos/crd/element) types; only the
// level ordering may differ.
LogicalResult ReorderCOOOp::verify() {
  SparseTensorType srcStt = getSparseTensorType(val: getInputCoo());
  SparseTensorType dstStt = getSparseTensorType(val: getResultCoo());

  if (!srcStt.isCOOType() || !dstStt.isCOOType())
    return emitError(message: "Expected COO sparse tensors only");

  if (!srcStt.hasSameDimToLvl(other: dstStt))
    return emitError(message: "Unmatched dim2lvl map between input and result COO");

  if (srcStt.getPosType() != dstStt.getPosType() ||
      srcStt.getCrdType() != dstStt.getCrdType() ||
      srcStt.getElementType() != dstStt.getElementType())
    return emitError(message: "Unmatched storage format between input and result COO");

  return success();
}
2022
// Verifies the reduce region: two arguments of the operand type, yielding a
// value of that same type.
LogicalResult ReduceOp::verify() {
  Type inputType = getX().getType();
  Region &formula = getRegion();
  return verifyNumBlockArgs(op: this, region&: formula, regionName: "reduce",
                            inputTypes: TypeRange{inputType, inputType}, outputType: inputType);
}
2029
// Verifies the select region: one argument of the operand type, yielding an
// i1 predicate.
LogicalResult SelectOp::verify() {
  Builder b(getContext());
  Type inputType = getX().getType();
  Type boolType = b.getI1Type();
  Region &formula = getRegion();
  return verifyNumBlockArgs(op: this, region&: formula, regionName: "select", inputTypes: TypeRange{inputType},
                            outputType: boolType);
}
2038
2039LogicalResult SortOp::verify() {
2040 AffineMap xPerm = getPermMap();
2041 uint64_t nx = xPerm.getNumDims();
2042 if (nx < 1)
2043 return emitError(message: llvm::formatv(Fmt: "Expected rank(perm_map) > 1, got {0}", Vals&: nx));
2044
2045 if (!xPerm.isPermutation())
2046 return emitError(
2047 message: llvm::formatv(Fmt: "Expected a permutation map, got {0}", Vals&: xPerm));
2048
2049 // We can't check the size of the buffers when n or buffer dimensions aren't
2050 // compile-time constants.
2051 std::optional<int64_t> cn = getConstantIntValue(ofr: getN());
2052 if (!cn)
2053 return success();
2054
2055 // Verify dimensions.
2056 const auto checkDim = [&](Value v, Size minSize,
2057 const char *message) -> LogicalResult {
2058 const Size sh = getMemRefType(t&: v).getShape()[0];
2059 if (ShapedType::isStatic(dValue: sh) && sh < minSize)
2060 return emitError(
2061 message: llvm::formatv(Fmt: "{0} got {1} < {2}", Vals&: message, Vals: sh, Vals&: minSize));
2062 return success();
2063 };
2064 uint64_t n = cn.value();
2065 uint64_t ny = 0;
2066 if (auto nyAttr = getNyAttr())
2067 ny = nyAttr.getInt();
2068 if (failed(Result: checkDim(getXy(), n * (nx + ny),
2069 "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
2070 return failure();
2071 for (Value opnd : getYs())
2072 if (failed(Result: checkDim(opnd, n, "Expected dimension(y) >= n")))
2073 return failure();
2074
2075 return success();
2076}
2077
2078//===----------------------------------------------------------------------===//
2079// Sparse Tensor Iteration Operations.
2080//===----------------------------------------------------------------------===//
2081
// Returns the iteration-space type with the same encoding and level range as
// this iterator type.
IterSpaceType IteratorType::getIterSpaceType() const {
  return IterSpaceType::get(context: getContext(), encoding: getEncoding(), loLvl: getLoLvl(),
                            hiLvl: getHiLvl());
}
2086
// Returns the iterator type with the same encoding and level range as this
// iteration-space type.
IteratorType IterSpaceType::getIteratorType() const {
  return IteratorType::get(context: getContext(), encoding: getEncoding(), loLvl: getLoLvl(), hiLvl: getHiLvl());
}
2090
2091/// Parses a level range in the form "$lo `to` $hi"
2092/// or simply "$lo" if $hi - $lo = 1
2093static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
2094 Level &lvlHi) {
2095 if (parser.parseInteger(result&: lvlLo))
2096 return failure();
2097
2098 if (succeeded(Result: parser.parseOptionalKeyword(keyword: "to"))) {
2099 if (parser.parseInteger(result&: lvlHi))
2100 return failure();
2101 } else {
2102 lvlHi = lvlLo + 1;
2103 }
2104
2105 if (lvlHi <= lvlLo)
2106 return parser.emitError(loc: parser.getNameLoc(),
2107 message: "expect larger level upper bound than lower bound");
2108
2109 return success();
2110}
2111
/// Parses a level range in the form "$lo `to` $hi"
/// or simply "$lo" if $hi - $lo = 1
/// The parsed bounds are returned as index-typed IntegerAttrs (the form the
/// generated op parser expects).
static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
                                   IntegerAttr &lvlHiAttr) {
  Level lvlLo, lvlHi;
  if (parseLevelRange(parser, lvlLo, lvlHi))
    return failure();

  lvlLoAttr = IntegerAttr::get(type: parser.getBuilder().getIndexType(), value: lvlLo);
  lvlHiAttr = IntegerAttr::get(type: parser.getBuilder().getIndexType(), value: lvlHi);
  return success();
}
2124
2125/// Prints a level range in the form "$lo `to` $hi"
2126/// or simply "$lo" if $hi - $lo = 1
2127static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
2128
2129 if (lo + 1 == hi)
2130 p << lo;
2131 else
2132 p << lo << " to " << hi;
2133}
2134
/// Prints a level range in the form "$lo `to` $hi"
/// or simply "$lo" if $hi - $lo = 1
/// Attribute-based overload used by the generated op printers.
static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
                            IntegerAttr lvlHi) {
  unsigned lo = lvlLo.getValue().getZExtValue();
  unsigned hi = lvlHi.getValue().getZExtValue();
  printLevelRange(p, lo, hi);
}
2143
/// Parses a list of `optional` defined list in the form of
/// "(%val0, _, %val1, ...)", where `_` is used to annotate that the
/// corresponding value is not defined (e.g., to represent an undefined
/// coordinate in the sparse iteration space).
/// On success, `definedSet` has a bit set for each defined position and
/// `definedArgs` holds the parsed arguments in order.
static ParseResult parseOptionalDefinedList(
    OpAsmParser &parser, OperationState &state, I64BitSet &definedSet,
    SmallVectorImpl<OpAsmParser::Argument> &definedArgs,
    unsigned maxCnt = std::numeric_limits<unsigned>::max(),
    OpAsmParser::Delimiter delimiter = OpAsmParser::Delimiter::Paren) {
  unsigned cnt = 0;
  ParseResult crdList =
      parser.parseCommaSeparatedList(delimiter, parseElementFn: [&]() -> ParseResult {
        // "_" marks an undefined slot; anything else must be an SSA argument.
        if (parser.parseOptionalKeyword(keyword: "_")) {
          if (parser.parseArgument(result&: definedArgs.emplace_back()))
            return failure();
          definedSet.set(cnt);
        }
        cnt += 1;
        return success();
      });

  if (cnt > maxCnt)
    return parser.emitError(loc: parser.getNameLoc(),
                            message: "parsed more value than expected.");

  if (failed(Result: crdList)) {
    return parser.emitError(
        loc: parser.getNameLoc(),
        message: "expecting SSA value or \"_\" for level coordinates");
  }
  assert(definedArgs.size() == definedSet.count());
  return success();
}
2177
/// Prints the counterpart of parseOptionalDefinedList: a comma-separated list
/// of `size` positions where set bits in `definedSet` print the next block
/// argument and cleared bits print "_". Prints nothing if no bit is set.
static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size,
                                     Block::BlockArgListType blocksArgs,
                                     I64BitSet definedSet) {
  if (definedSet.empty())
    return;

  for (unsigned i = 0; i < size; i++) {
    if (definedSet[i]) {
      p << blocksArgs.front();
      blocksArgs = blocksArgs.drop_front();
    } else {
      p << "_";
    }
    if (i != size - 1)
      p << ", ";
  }
  // All provided block arguments must have been consumed.
  assert(blocksArgs.empty());
}
2196
/// Parses the optional "at(%crd0, _, ...)" clause listing the coordinates
/// used by a sparse iteration op, records the used levels in the op's
/// "crdUsedLvls" attribute, and types each coordinate as index.
static ParseResult
parseUsedCoordList(OpAsmParser &parser, OperationState &state,
                   SmallVectorImpl<OpAsmParser::Argument> &coords) {
  // Parse "at(%crd0, _, ...)"
  I64BitSet crdUsedLvlSet;
  if (succeeded(Result: parser.parseOptionalKeyword(keyword: "at")) &&
      failed(Result: parseOptionalDefinedList(parser, state, definedSet&: crdUsedLvlSet, definedArgs&: coords)))
    return failure();

  // Always use IndexType for the coordinate.
  for (auto &coord : coords)
    coord.type = parser.getBuilder().getIndexType();

  // Set the CrdUsedLvl bitset.
  state.addAttribute(name: "crdUsedLvls",
                     attr: parser.getBuilder().getI64IntegerAttr(value: crdUsedLvlSet));
  return success();
}
2215
/// Parses the custom assembly of a sparse iterate loop:
///   %iters, ... in %spaces, ... [at(...)] [iter_args(...)] : types [-> ret]
/// Populates `iterators` (typed from the iteration space types) and
/// `blockArgs` (iter_args followed by the used coordinates).
static ParseResult
parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
                       SmallVectorImpl<OpAsmParser::Argument> &iterators,
                       SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {
  SmallVector<OpAsmParser::UnresolvedOperand> spaces;
  SmallVector<OpAsmParser::UnresolvedOperand> initArgs;

  // Parse "%iters, ... in %spaces, ..."
  if (parser.parseArgumentList(result&: iterators) || parser.parseKeyword(keyword: "in") ||
      parser.parseOperandList(result&: spaces))
    return failure();

  if (iterators.size() != spaces.size())
    return parser.emitError(
        loc: parser.getNameLoc(),
        message: "mismatch in number of sparse iterators and sparse spaces");

  SmallVector<OpAsmParser::Argument> coords;
  if (failed(Result: parseUsedCoordList(parser, state, coords)))
    return failure();
  size_t numCrds = coords.size();

  // Parse "iter_args(%arg = %init, ...)"
  bool hasIterArgs = succeeded(Result: parser.parseOptionalKeyword(keyword: "iter_args"));
  if (hasIterArgs)
    if (parser.parseAssignmentList(lhs&: blockArgs, rhs&: initArgs))
      return failure();

  // Coordinates become the trailing block arguments.
  blockArgs.append(RHS: coords);

  SmallVector<Type> iterSpaceTps;
  // parse ": sparse_tensor.iter_space -> ret"
  if (parser.parseColon() || parser.parseTypeList(result&: iterSpaceTps))
    return failure();
  if (iterSpaceTps.size() != spaces.size())
    return parser.emitError(loc: parser.getNameLoc(),
                            message: "mismatch in number of iteration space operands "
                            "and iteration space types");

  // Each iterator's type is derived from its iteration space type.
  for (auto [it, tp] : llvm::zip_equal(t&: iterators, u&: iterSpaceTps)) {
    IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(Val&: tp);
    if (!spaceTp)
      return parser.emitError(loc: parser.getNameLoc(),
                              message: "expected sparse_tensor.iter_space type for "
                              "iteration space operands");
    it.type = spaceTp.getIteratorType();
  }

  if (hasIterArgs)
    if (parser.parseArrowTypeList(result&: state.types))
      return failure();

  // Resolves input operands.
  if (parser.resolveOperands(operands&: spaces, types&: iterSpaceTps, loc: parser.getNameLoc(),
                             result&: state.operands))
    return failure();

  if (hasIterArgs) {
    // Strip off trailing args that are used for coordinates.
    MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(N: numCrds);
    if (args.size() != initArgs.size() || args.size() != state.types.size()) {
      return parser.emitError(
          loc: parser.getNameLoc(),
          message: "mismatch in number of iteration arguments and return values");
    }

    // Iteration arguments take their types from the declared result types.
    for (auto [it, init, tp] : llvm::zip_equal(t&: args, u&: initArgs, args&: state.types)) {
      it.type = tp;
      if (parser.resolveOperand(operand: init, type: tp, result&: state.operands))
        return failure();
    }
  }
  return success();
}
2290
/// Parses the custom assembly of a sparse co-iterate loop:
///   (%spaces, ...) [at(...)] [iter_args(...)] : (space-types) [-> ret]
/// Resolved iteration-space values are returned in `spacesVals`; `blockArgs`
/// receives the iter_args followed by the used coordinates.
static ParseResult
parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state,
                         SmallVectorImpl<Value> &spacesVals,
                         SmallVectorImpl<OpAsmParser::Argument> &blockArgs) {

  // Parse "(%spaces, ...)"
  SmallVector<OpAsmParser::UnresolvedOperand> spaces;
  if (parser.parseOperandList(result&: spaces, delimiter: OpAsmParser::Delimiter::Paren))
    return failure();

  SmallVector<OpAsmParser::Argument> coords;
  if (failed(Result: parseUsedCoordList(parser, state, coords)))
    return failure();
  size_t numCrds = coords.size();

  // Parse "iter_args(%arg = %init, ...)"
  SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
  bool hasIterArgs = succeeded(Result: parser.parseOptionalKeyword(keyword: "iter_args"));
  if (hasIterArgs)
    if (parser.parseAssignmentList(lhs&: blockArgs, rhs&: initArgs))
      return failure();
  // Coordinates become the trailing block arguments.
  blockArgs.append(RHS: coords);

  SmallVector<Type> iterSpaceTps;
  // parse ": (sparse_tensor.iter_space, ...) -> ret"
  if (parser.parseColon() || parser.parseLParen() ||
      parser.parseTypeList(result&: iterSpaceTps) || parser.parseRParen())
    return failure();

  if (iterSpaceTps.size() != spaces.size())
    return parser.emitError(loc: parser.getNameLoc(),
                            message: "mismatch in number of iteration space operands "
                            "and iteration space types");

  if (hasIterArgs)
    if (parser.parseArrowTypeList(result&: state.types))
      return failure();

  // Resolves input sparse iteration spaces.
  if (parser.resolveOperands(operands&: spaces, types&: iterSpaceTps, loc: parser.getNameLoc(),
                             result&: spacesVals))
    return failure();
  state.operands.append(RHS: spacesVals);

  if (hasIterArgs) {
    // Strip off trailing args that used for coordinates.
    MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(N: numCrds);
    if (args.size() != initArgs.size() || args.size() != state.types.size()) {
      return parser.emitError(
          loc: parser.getNameLoc(),
          message: "mismatch in number of iteration arguments and return values");
    }

    // Iteration arguments take their types from the declared result types.
    for (auto [it, init, tp] : llvm::zip_equal(t&: args, u&: initArgs, args&: state.types)) {
      it.type = tp;
      if (parser.resolveOperand(operand: init, type: tp, result&: state.operands))
        return failure();
    }
  }
  return success();
}
2352
// Infers the iter_space result type from the tensor's encoding and the
// requested [loLvl, hiLvl) level range.
LogicalResult ExtractIterSpaceOp::inferReturnTypes(
    MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
    DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
    SmallVectorImpl<mlir::Type> &ret) {

  ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
  SparseTensorType stt = getSparseTensorType(val: adaptor.getTensor());
  ret.push_back(Elt: IterSpaceType::get(context: ctx, encoding: stt.getEncoding(), loLvl: adaptor.getLoLvl(),
                                    hiLvl: adaptor.getHiLvl()));
  return success();
}
2364
// Verifies the level range is non-empty, that a parent iterator is given
// exactly when extraction does not start at level 0, and that the parent
// iterator's range is consecutive with the extracted space.
LogicalResult ExtractIterSpaceOp::verify() {
  if (getLoLvl() >= getHiLvl())
    return emitOpError(message: "expected smaller level low than level high");

  TypedValue<IteratorType> pIter = getParentIter();
  if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
    return emitOpError(
        message: "parent iterator should be specified iff level lower bound equals 0");
  }

  if (pIter) {
    IterSpaceType spaceTp = getExtractedSpace().getType();
    if (pIter.getType().getEncoding() != spaceTp.getEncoding())
      return emitOpError(
          message: "mismatch in parent iterator encoding and iteration space encoding.");

    // The extracted space must start exactly where the parent iterator ends.
    if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
      return emitOpError(message: "parent iterator should be used to extract an "
                         "iteration space from a consecutive level.");
  }

  return success();
}
2388
// Verifies that the iterator matches the tensor's encoding and covers up to
// the tensor's last level (values live at the innermost level).
LogicalResult ExtractValOp::verify() {
  auto stt = getSparseTensorType(val: getTensor());
  auto itTp = getIterator().getType();

  if (stt.getEncoding() != itTp.getEncoding())
    return emitOpError(message: "mismatch in tensor encoding and iterator encoding.");

  if (stt.getLvlRank() != itTp.getHiLvl())
    return emitOpError(message: "must use last-level iterator to extract values. ");

  return success();
}
2401
/// Canonicalization pattern: drops coordinate block arguments of an IterateOp
/// that have no users, shrinking the op's crdUsedLvls bitset accordingly.
struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(IterateOp iterateOp,
                                PatternRewriter &rewriter) const override {
    I64BitSet newUsedLvls(0);
    llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
    // Classify each currently-used level: keep it if its coordinate argument
    // has users, otherwise schedule the argument for removal.
    for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
      if (auto crd = iterateOp.getLvlCrd(lvl: i)) {
        if (crd->getUsers().empty())
          toRemove.set(crd->getArgNumber());
        else
          newUsedLvls.set(i);
      }
    }

    // All coordinates are used.
    if (toRemove.none())
      return failure();

    // In-place update: shrink the bitset and erase the dead arguments.
    rewriter.startOpModification(op: iterateOp);
    iterateOp.setCrdUsedLvls(newUsedLvls);
    iterateOp.getBody()->eraseArguments(eraseIndices: toRemove);
    rewriter.finalizeOpModification(op: iterateOp);
    return success();
  }
};
2429
// Registers IterateOp canonicalizations (currently only unused-coordinate
// removal).
void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
                                            mlir::MLIRContext *context) {
  results.add<RemoveUnusedLvlCrds>(arg&: context);
}
2434
2435void IterateOp::build(OpBuilder &builder, OperationState &odsState,
2436 Value iterSpace, ValueRange initArgs) {
2437 unsigned rank = llvm::cast<IterSpaceType>(Val: iterSpace.getType()).getSpaceDim();
2438 // All ones.
2439 I64BitSet set((1 << rank) - 1);
2440 return build(odsBuilder&: builder, odsState, iterSpace, initArgs, crdUsedLvls: set);
2441}
2442
// Builds an IterateOp and its body block. The block arguments are laid out
// as: user-provided loop-carried args, then one index per used coordinate,
// then the sparse iterator itself (last).
void IterateOp::build(OpBuilder &builder, OperationState &odsState,
                      Value iterSpace, ValueRange initArgs,
                      I64BitSet crdUsedLvls) {
  OpBuilder::InsertionGuard guard(builder);

  odsState.addOperands(newOperands: iterSpace);
  odsState.addOperands(newOperands: initArgs);
  odsState.getOrAddProperties<Properties>().crdUsedLvls =
      builder.getIntegerAttr(type: builder.getIntegerType(width: 64), value: crdUsedLvls);
  Region *bodyRegion = odsState.addRegion();
  // The op yields exactly the loop-carried values.
  odsState.addTypes(newTypes: initArgs.getTypes());
  Block *bodyBlock = builder.createBlock(parent: bodyRegion);

  // Starts with a list of user-provided loop arguments.
  for (Value v : initArgs)
    bodyBlock->addArgument(type: v.getType(), loc: v.getLoc());

  // Follows by a list of used coordinates.
  for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
    bodyBlock->addArgument(type: builder.getIndexType(), loc: odsState.location);

  // Ends with sparse iterator
  bodyBlock->addArgument(
      type: llvm::cast<IterSpaceType>(Val: iterSpace.getType()).getIteratorType(),
      loc: odsState.location);
}
2469
2470ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
2471 OpAsmParser::Argument iterator;
2472 OpAsmParser::UnresolvedOperand iterSpace;
2473
2474 SmallVector<OpAsmParser::Argument> iters, iterArgs;
2475 if (parseSparseIterateLoop(parser, state&: result, iterators&: iters, blockArgs&: iterArgs))
2476 return failure();
2477 if (iters.size() != 1)
2478 return parser.emitError(loc: parser.getNameLoc(),
2479 message: "expected only one iterator/iteration space");
2480
2481 iterArgs.append(RHS: iters);
2482 Region *body = result.addRegion();
2483 if (parser.parseRegion(region&: *body, arguments: iterArgs))
2484 return failure();
2485
2486 IterateOp::ensureTerminator(region&: *body, builder&: parser.getBuilder(), loc: result.location);
2487
2488 // Parse the optional attribute list.
2489 if (parser.parseOptionalAttrDict(result&: result.attributes))
2490 return failure();
2491
2492 return success();
2493}
2494
2495/// Prints the initialization list in the form of
2496/// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
2497/// where 'inner' values are assumed to be region arguments and 'outer' values
2498/// are regular SSA values.
2499static void printInitializationList(OpAsmPrinter &p,
2500 Block::BlockArgListType blocksArgs,
2501 ValueRange initializers,
2502 StringRef prefix = "") {
2503 assert(blocksArgs.size() == initializers.size() &&
2504 "expected same length of arguments and initializers");
2505 if (initializers.empty())
2506 return;
2507
2508 p << prefix << '(';
2509 llvm::interleaveComma(c: llvm::zip(t&: blocksArgs, u&: initializers), os&: p, each_fn: [&](auto it) {
2510 p << std::get<0>(it) << " = " << std::get<1>(it);
2511 });
2512 p << ")";
2513}
2514
2515template <typename SparseLoopOp>
2516static LogicalResult verifySparseLoopOp(SparseLoopOp op) {
2517 if (op.getInitArgs().size() != op.getNumResults()) {
2518 return op.emitOpError(
2519 "mismatch in number of loop-carried values and defined values");
2520 }
2521 if (op.getCrdUsedLvls().max() > op.getSpaceDim())
2522 return op.emitOpError("required out-of-bound coordinates");
2523
2524 return success();
2525}
2526
// Both sparse loop ops share the same structural verification.
LogicalResult IterateOp::verify() { return verifySparseLoopOp(op: *this); }
LogicalResult CoIterateOp::verify() { return verifySparseLoopOp(op: *this); }
2529
2530void IterateOp::print(OpAsmPrinter &p) {
2531 p << " " << getIterator() << " in " << getIterSpace();
2532 if (!getCrdUsedLvls().empty()) {
2533 p << " at(";
2534 printOptionalDefinedList(p, size: getSpaceDim(), blocksArgs: getCrds(), definedSet: getCrdUsedLvls());
2535 p << ")";
2536 }
2537 printInitializationList(p, blocksArgs: getRegionIterArgs(), initializers: getInitArgs(), prefix: " iter_args");
2538
2539 p << " : " << getIterSpace().getType() << " ";
2540 if (!getInitArgs().empty())
2541 p.printArrowTypeList(types: getInitArgs().getTypes());
2542
2543 p << " ";
2544 p.printRegion(blocks&: getRegion(), /*printEntryBlockArgs=*/false,
2545 /*printBlockTerminators=*/!getInitArgs().empty());
2546}
2547
2548LogicalResult IterateOp::verifyRegions() {
2549 if (getIterator().getType() != getIterSpace().getType().getIteratorType())
2550 return emitOpError(message: "mismatch in iterator and iteration space type");
2551 if (getNumRegionIterArgs() != getNumResults())
2552 return emitOpError(
2553 message: "mismatch in number of basic block args and defined values");
2554
2555 auto initArgs = getInitArgs();
2556 auto iterArgs = getRegionIterArgs();
2557 auto yieldVals = getYieldedValues();
2558 auto opResults = getResults();
2559 if (!llvm::all_equal(Values: {initArgs.size(), iterArgs.size(), yieldVals.size(),
2560 opResults.size()})) {
2561 return emitOpError() << "number mismatch between iter args and results.";
2562 }
2563
2564 for (auto [i, init, iter, yield, ret] :
2565 llvm::enumerate(First&: initArgs, Rest&: iterArgs, Rest&: yieldVals, Rest&: opResults)) {
2566 if (init.getType() != ret.getType())
2567 return emitOpError() << "types mismatch between " << i
2568 << "th iter operand and defined value";
2569 if (iter.getType() != ret.getType())
2570 return emitOpError() << "types mismatch between " << i
2571 << "th iter region arg and defined value";
2572 if (yield.getType() != ret.getType())
2573 return emitOpError() << "types mismatch between " << i
2574 << "th yield value and defined value";
2575 }
2576
2577 return success();
2578}
2579
/// OpInterfaces' methods implemented by IterateOp.
// LoopLikeOpInterface: the single body region is the loop region.
SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
2582
// LoopLikeOpInterface: the loop-carried inits are the `initArgs` operands.
MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
  return getInitArgsMutable();
}
2586
// The loop-carried values are the leading block arguments of the body
// (they are followed by the used coordinates and the sparse iterator).
Block::BlockArgListType IterateOp::getRegionIterArgs() {
  return getRegion().getArguments().take_front(N: getNumRegionIterArgs());
}
2590
// LoopLikeOpInterface: the yielded values are the operands of the
// sparse_tensor.yield terminator of the (single-block) body region.
std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
  return cast<sparse_tensor::YieldOp>(
             Val: getRegion().getBlocks().front().getTerminator())
      .getResultsMutable();
}
2596
// LoopLikeOpInterface: every op result is a loop result.
std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
2598
// RegionBranchOpInterface: on entry to the body, the `initArgs` operands
// are forwarded to the region's iter args.
OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
  return getInitArgs();
}
2602
// RegionBranchOpInterface: enumerates the possible control-flow successors
// from `point` (same successor set regardless of the branch origin here).
void IterateOp::getSuccessorRegions(RegionBranchPoint point,
                                    SmallVectorImpl<RegionSuccessor> &regions) {
  // Both the operation itself and the region may be branching into the body
  // or back into the operation itself.
  regions.push_back(Elt: RegionSuccessor(&getRegion(), getRegionIterArgs()));
  // It is possible for loop not to enter the body.
  regions.push_back(Elt: RegionSuccessor(getResults()));
}
2611
2612void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
2613 ValueRange iterSpaces, ValueRange initArgs,
2614 unsigned numCases) {
2615 unsigned rank =
2616 cast<IterSpaceType>(Val: iterSpaces.front().getType()).getSpaceDim();
2617 // All ones.
2618 I64BitSet set((1 << rank) - 1);
2619 // Generates all-zero case bits (they only serve as placeholders), which are
2620 // supposed to be overriden later. We need to preallocate all the regions as
2621 // mlir::Region cannot be dynamically added later after the operation is
2622 // created.
2623 SmallVector<int64_t> caseBits(numCases, 0);
2624 ArrayAttr cases = builder.getI64ArrayAttr(values: caseBits);
2625 return CoIterateOp::build(odsBuilder&: builder, odsState, results: initArgs.getTypes(), iterSpaces,
2626 initArgs, crdUsedLvls: set, cases,
2627 /*caseRegionsCount=*/numCases);
2628}
2629
/// Parses `sparse_tensor.coiterate`: the shared loop header (spaces, used
/// coordinates, iter args), then one `case` keyword + region per case.
ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) {

  SmallVector<Value> spaces;
  // The block argument list of each regions, it is arranged in the order of
  // ([used coordinate list], [loop iterations args], [sparse iterator list]).
  SmallVector<OpAsmParser::Argument> blockArgs;
  if (parseSparseCoIterateLoop(parser, state&: result, spacesVals&: spaces, blockArgs))
    return failure();

  // Record the operand split between iteration spaces and init args so the
  // generated accessors can separate the variadic operand groups.
  result.addAttribute(name: "operandSegmentSizes",
                      attr: parser.getBuilder().getDenseI32ArrayAttr(
                          values: {static_cast<int32_t>(spaces.size()),
                           static_cast<int32_t>(result.types.size())}));

  SmallVector<Attribute> cases;
  while (succeeded(Result: parser.parseOptionalKeyword(keyword: "case"))) {
    // Parse one region per case.
    I64BitSet definedItSet;
    SmallVector<OpAsmParser::Argument> definedIts;
    if (parseOptionalDefinedList(parser, state&: result, definedSet&: definedItSet, definedArgs&: definedIts,
                                 maxCnt: spaces.size(), delimiter: OpAsmParser::Delimiter::None))
      return failure();

    // Remember which spaces this case covers (one I64 bit set per region).
    cases.push_back(Elt: parser.getBuilder().getI64IntegerAttr(value: definedItSet));

    for (auto [i, definedIdx] : llvm::enumerate(First: definedItSet.bits())) {
      // Resolve the iterator type based on the iteration space type.
      auto spaceTp = llvm::cast<IterSpaceType>(Val: spaces[definedIdx].getType());
      definedIts[i].type = spaceTp.getIteratorType();
    }
    // Region args = shared (crds, iter args) followed by this case's iterators.
    definedIts.insert(I: definedIts.begin(), From: blockArgs.begin(), To: blockArgs.end());
    Region *body = result.addRegion();
    if (parser.parseRegion(region&: *body, arguments: definedIts))
      return failure();

    CoIterateOp::ensureTerminator(region&: *body, builder&: parser.getBuilder(), loc: result.location);
  }

  result.addAttribute(name: "cases", attr: ArrayAttr::get(context: parser.getContext(), value: cases));

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result&: result.attributes))
    return failure();

  return success();
}
2676
/// Prints `sparse_tensor.coiterate`: the shared header followed by one
/// `case` line + region per case (mirrors CoIterateOp::parse).
void CoIterateOp::print(OpAsmPrinter &p) {
  p << " (";
  llvm::interleaveComma(c: getIterSpaces(), os&: p, each_fn: [&](auto s) { p << s; });
  p << ")";

  // Only print the `at(...)` clause when coordinates are actually used.
  if (!getCrdUsedLvls().empty()) {
    p << " at(";
    printOptionalDefinedList(p, size: getSpaceDim(), blocksArgs: getCrds(regionIdx: 0), definedSet: getCrdUsedLvls());
    p << ")";
  }

  printInitializationList(p, blocksArgs: getRegionIterArgs(regionIdx: 0), initializers: getInitArgs(), prefix: " iter_args");

  p << " : (" << getIterSpaces().getTypes() << ")";
  if (!getInitArgs().empty())
    p.printArrowTypeList(types: getInitArgs().getTypes());

  // One "case <iterators>" clause per region, each on its own line.
  for (unsigned idx = 0, e = getRegions().size(); idx < e; idx++) {
    p.printNewline();
    p << "case ";
    printOptionalDefinedList(p, size: getIterSpaces().size(), blocksArgs: getRegionIterators(regionIdx: idx),
                             definedSet: getRegionDefinedSpace(regionIdx: idx));
    p << " ";
    // The terminator is implicit when the loop carries no values.
    p.printRegion(blocks&: getRegion(i: idx), /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/!getInitArgs().empty());
  }
}
2704
// Returns the values yielded by the terminator of case region `regionIdx`.
ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) {
  return cast<sparse_tensor::YieldOp>(
             Val: getRegion(i: regionIdx).getBlocks().front().getTerminator())
      .getResults();
}
2710
/// Region verification for `coiterate`: every case region must carry the
/// same number and types of values as the loop's inits/results, and no two
/// case regions may cover the same set of iteration spaces.
LogicalResult CoIterateOp::verifyRegions() {
  for (unsigned r = 0, e = getNumRegions(); r < e; r++) {
    if (getNumRegionIterArgs() != getNumResults())
      return emitOpError(
          message: "mismatch in number of basic block args and defined values");

    auto initArgs = getInitArgs();
    auto iterArgs = getRegionIterArgs(regionIdx: r);
    auto yieldVals = getYieldedValues(regionIdx: r);
    auto opResults = getResults();
    if (!llvm::all_equal(Values: {initArgs.size(), iterArgs.size(), yieldVals.size(),
                          opResults.size()})) {
      return emitOpError()
             << "number mismatch between iter args and results on " << r
             << "th region";
    }

    // Per-position type agreement between inits, region args, yields, results.
    for (auto [i, init, iter, yield, ret] :
         llvm::enumerate(First&: initArgs, Rest&: iterArgs, Rest&: yieldVals, Rest&: opResults)) {
      if (init.getType() != ret.getType())
        return emitOpError()
               << "types mismatch between " << i
               << "th iter operand and defined value on " << r << "th region";
      if (iter.getType() != ret.getType())
        return emitOpError() << "types mismatch between " << i
                             << "th iter region arg and defined value on " << r
                             << "th region";
      if (yield.getType() != ret.getType())
        return emitOpError()
               << "types mismatch between " << i
               << "th yield value and defined value on " << r << "th region";
    }
  }

  // Each case must cover a distinct space subset; duplicates are rejected.
  auto cases = getRegionDefinedSpaces();
  llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end());
  if (set.size() != getNumRegions())
    return emitOpError(message: "contains duplicated cases.");

  return success();
}
2752
2753SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) {
2754 SmallVector<Region *> ret;
2755 I64BitSet caseBit = getRegionDefinedSpace(regionIdx);
2756 for (Region &r : getCaseRegions())
2757 if (getRegionDefinedSpace(regionIdx: r.getRegionNumber()).isSubSetOf(p: caseBit))
2758 ret.push_back(Elt: &r);
2759
2760 return ret;
2761}
2762
2763//===----------------------------------------------------------------------===//
2764// Sparse Tensor Dialect Setups.
2765//===----------------------------------------------------------------------===//
2766
2767/// Materialize a single constant operation from a given attribute value with
2768/// the desired resultant type.
2769Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
2770 Attribute value, Type type,
2771 Location loc) {
2772 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
2773 return op;
2774 return nullptr;
2775}
2776
// Registers all attributes, types, and operations of the sparse tensor
// dialect (from TableGen-generated lists), and declares the bufferization
// interface implementations that are attached lazily elsewhere.
void SparseTensorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
  // Promised interfaces avoid a hard dependency on the bufferization
  // dialect here; the actual implementations are registered on demand.
  declarePromisedInterfaces<
      bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
      NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
      ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
}
2795
2796#define GET_OP_CLASSES
2797#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
2798
2799#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
2800

// source code of mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp