//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "Detail/DimLvlMapParser.h"

#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"

// Forward declarations: the following custom print/parse methods are
// referenced by the code generated from SparseTensorTypes.td.
static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
                                         mlir::sparse_tensor::Level &,
                                         mlir::sparse_tensor::Level &);
static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level,
                            mlir::sparse_tensor::Level);

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Support hashing LevelType such that SparseTensorEncodingAttr can be hashed
// as well.
namespace mlir::sparse_tensor {
llvm::hash_code hash_value(LevelType lt) {
  return llvm::hash_value(static_cast<uint64_t>(lt));
}
} // namespace mlir::sparse_tensor

//===----------------------------------------------------------------------===//
// Local Convenience Methods.
//===----------------------------------------------------------------------===//

static constexpr bool acceptBitWidth(unsigned bitWidth) {
  switch (bitWidth) {
  case 0:
  case 8:
  case 16:
  case 32:
  case 64:
    return true;
  default:
    return false;
  }
}

static SmallVector<Size>
getSparseFieldShape(const SparseTensorEncodingAttr enc,
                    std::optional<ArrayRef<int64_t>> dimShape) {
  assert(enc);
  // With only the encoding, we cannot determine the static shape of the
  // leading batch levels, so we return a dynamically shaped memref instead.
  SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic);
  if (dimShape.has_value()) {
    // If the actual tensor shape is provided, we can then refine the leading
    // batch dimensions.
    SmallVector<int64_t> lvlShape =
        enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
    memrefShape.assign(lvlShape.begin(),
                       lvlShape.begin() + enc.getBatchLvlRank());
  }
  // Another dynamic dimension to store the sparse level.
  memrefShape.push_back(ShapedType::kDynamic);
  return memrefShape;
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect StorageLayout.
//===----------------------------------------------------------------------===//

static constexpr Level kInvalidLevel = -1u;
static constexpr Level kInvalidFieldIndex = -1u;
static constexpr FieldIndex kDataFieldStartingIdx = 0;

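// Note: as an illustrative sketch (assuming a 2-d CSR-like encoding with
// lvlTypes [dense, compressed]), the fields visited below are, in order:
//   positions[1], coordinates[1], values, storage_specifier
// with the metadata (storage specifier) always placed last.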
void StorageLayout::foreachField(
    llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) const {
  const auto lvlTypes = enc.getLvlTypes();
  const Level lvlRank = enc.getLvlRank();
  SmallVector<COOSegment> cooSegs = SparseTensorType(enc).getCOOSegments();
  FieldIndex fieldIdx = kDataFieldStartingIdx;

  ArrayRef cooSegsRef = cooSegs;
  // Per-level storage.
  for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
    const auto lt = lvlTypes[l];
    if (isWithPosLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
        return;
    }
    if (isWithCrdLT(lt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
        return;
    }
    if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
      if (!cooSegsRef.front().isSoA) {
        // AoS COO: all singletons are fused into one memref. Skip the entire
        // COO segment.
        l = cooSegsRef.front().lvlRange.second;
      } else {
        // SoA COO: each singleton level has its own memref.
        l++;
      }
      // Expire handled COO segment.
      cooSegsRef = cooSegsRef.drop_front();
    } else {
      // Non-COO levels.
      l++;
    }
  }
  // The values array.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
                 LevelFormat::Undef)))
    return;
  // Put metadata at the end.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
                 LevelFormat::Undef)))
    return;
}

void sparse_tensor::foreachFieldAndTypeInSparseTensor(
    SparseTensorType stt,
    llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
                            LevelType)>
        callback) {
  assert(stt.hasEncoding());

  SmallVector<int64_t> memrefShape =
      getSparseFieldShape(stt.getEncoding(), stt.getDimShape());

  const Type specType = StorageSpecifierType::get(stt.getEncoding());
  // memref<[batch] x ? x pos> positions
  const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
  // memref<[batch] x ? x crd> coordinates
  const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
  // memref<[batch] x ? x eltType> values
  const Type valMemType = MemRefType::get(memrefShape, stt.getElementType());

  StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
                                   callback](FieldIndex fieldIdx,
                                             SparseTensorFieldKind fieldKind,
                                             Level lvl, LevelType lt) -> bool {
    switch (fieldKind) {
    case SparseTensorFieldKind::StorageSpec:
      return callback(specType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::PosMemRef:
      return callback(posMemType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::CrdMemRef:
      return callback(crdMemType, fieldIdx, fieldKind, lvl, lt);
    case SparseTensorFieldKind::ValMemRef:
      return callback(valMemType, fieldIdx, fieldKind, lvl, lt);
    };
    llvm_unreachable("unrecognized field kind");
  });
}

unsigned StorageLayout::getNumFields() const {
  unsigned numFields = 0;
  foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level,
                            LevelType) -> bool {
    numFields++;
    return true;
  });
  return numFields;
}

unsigned StorageLayout::getNumDataFields() const {
  unsigned numFields = 0; // one value memref
  foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level,
                            LevelType) -> bool {
    if (fidx >= kDataFieldStartingIdx)
      numFields++;
    return true;
  });
  numFields -= 1; // the last field is StorageSpecifier
  assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
  return numFields;
}

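// Note: a sketch of the AoS COO special case handled below: for a hypothetical
// 2-d COO encoding with lvlTypes [compressed(nonunique), singleton], a request
// for the CrdMemRef of level 1 is redirected to the single fused coordinates
// memref of level 0 and reported with stride 2.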
std::pair<FieldIndex, unsigned>
StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
                                      std::optional<Level> lvl) const {
  FieldIndex fieldIdx = kInvalidFieldIndex;
  unsigned stride = 1;
  if (kind == SparseTensorFieldKind::CrdMemRef) {
    assert(lvl.has_value());
    const Level cooStart = SparseTensorType(enc).getAoSCOOStart();
    const Level lvlRank = enc.getLvlRank();
    if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
      lvl = cooStart;
      stride = lvlRank - cooStart;
    }
  }
  foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
                                      SparseTensorFieldKind fKind, Level fLvl,
                                      LevelType lt) -> bool {
    if ((lvl && fLvl == lvl.value() && kind == fKind) ||
        (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
      fieldIdx = fIdx;
      // Return false to break the iteration.
      return false;
    }
    return true;
  });
  assert(fieldIdx != kInvalidFieldIndex);
  return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Attribute Methods.
//===----------------------------------------------------------------------===//

std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
  return isDynamic(v) ? std::nullopt
                      : std::make_optional(static_cast<uint64_t>(v));
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
  return getStatic(getOffset());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
  return getStatic(getStride());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
  return getStatic(getSize());
}

bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
  return isDynamic(getOffset()) && isDynamic(getStride()) &&
         isDynamic(getSize());
}

std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
  return isDynamic(v) ? "?" : std::to_string(v);
}

void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
  assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
  os << '(';
  os << getStaticString(getOffset());
  os << ", ";
  os << getStaticString(getSize());
  os << ", ";
  os << getStaticString(getStride());
  os << ')';
}

void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
  print(printer.getStream());
}

static ParseResult parseOptionalStaticSlice(int64_t &result,
                                            AsmParser &parser) {
  auto parseResult = parser.parseOptionalInteger(result);
  if (parseResult.has_value()) {
    if (parseResult.value().succeeded() && result < 0) {
      parser.emitError(
          parser.getCurrentLocation(),
          "expect positive value or ? for slice offset/size/stride");
      return failure();
    }
    return parseResult.value();
  }

  // Else, expect a '?' representing a dynamic slice.
  result = SparseTensorDimSliceAttr::kDynamic;
  return parser.parseQuestion();
}

Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
  int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;

  if (failed(parser.parseLParen()) ||
      failed(parseOptionalStaticSlice(offset, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(size, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(stride, parser)) ||
      failed(parser.parseRParen()))
    return {};

  return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
                                                     offset, size, stride);
}

LogicalResult
SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 int64_t offset, int64_t size, int64_t stride) {
  if (!isDynamic(offset) && offset < 0)
    return emitError() << "expect non-negative value or ? for slice offset";
  if (!isDynamic(size) && size <= 0)
    return emitError() << "expect positive value or ? for slice size";
  if (!isDynamic(stride) && stride <= 0)
    return emitError() << "expect positive value or ? for slice stride";
  return success();
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), dimToLvl,
                                       AffineMap(), getPosWidth(),
                                       getCrdWidth());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
  return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
  return withDimToLvl(AffineMap());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
                                        unsigned crdWidth) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
                                       getDimToLvl(), getLvlToDim(), posWidth,
                                       crdWidth);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
  return withBitWidths(0, 0);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
                                       getDimToLvl(), getLvlToDim(),
                                       getPosWidth(), getCrdWidth(), dimSlices);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
  return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
}

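// Returns the number of leading batch levels, i.e., the index right after the
// last level with a batch format (the verifier ensures that batch levels can
// only appear as leading levels).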
uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const {
  ArrayRef<LevelType> lvlTypes = getLvlTypes();
  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  return std::distance(lastBatch, lvlTypes.rend());
}

bool SparseTensorEncodingAttr::isAllDense() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
}

bool SparseTensorEncodingAttr::isAllOrdered() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
}

Type SparseTensorEncodingAttr::getCrdElemType() const {
  if (!getImpl())
    return nullptr;
  if (getCrdWidth())
    return IntegerType::get(getContext(), getCrdWidth());
  return IndexType::get(getContext());
}

Type SparseTensorEncodingAttr::getPosElemType() const {
  if (!getImpl())
    return nullptr;
  if (getPosWidth())
    return IntegerType::get(getContext(), getPosWidth());
  return IndexType::get(getContext());
}

MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
    std::optional<ArrayRef<int64_t>> dimShape) const {
  SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
  return MemRefType::get(shape, getCrdElemType());
}

MemRefType SparseTensorEncodingAttr::getPosMemRefType(
    std::optional<ArrayRef<int64_t>> dimShape) const {
  SmallVector<Size> shape = getSparseFieldShape(*this, dimShape);
  return MemRefType::get(shape, getPosElemType());
}

bool SparseTensorEncodingAttr::isIdentity() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
}

bool SparseTensorEncodingAttr::isPermutation() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
}

Dimension SparseTensorEncodingAttr::getDimRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  const auto dimToLvl = getDimToLvl();
  return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
}

Level SparseTensorEncodingAttr::getLvlRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return getLvlTypes().size();
}

LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
  if (!getImpl())
    return LevelFormat::Batch;
  assert(l < getLvlRank() && "Level is out of bounds");
  return getLvlTypes()[l];
}

bool SparseTensorEncodingAttr::isSlice() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return !getDimSlices().empty();
}

SparseTensorDimSliceAttr
SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
  assert(isSlice() && "Is not a slice");
  const auto dimSlices = getDimSlices();
  assert(dim < dimSlices.size() && "Dimension is out of bounds");
  return dimSlices[dim];
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
  return getDimSlice(dim).getStaticOffset();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
  return getDimSlice(dim).getStaticStride();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
  return getStaticDimSliceOffset(toDim(*this, lvl));
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
  return getStaticDimSliceStride(toDim(*this, lvl));
}

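// As a sketch of the dim2lvl direction for a hypothetical 2x2-blocked BSR map
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2),
// a static dimension shape 6x8 translates to the level shape 3x4x2x2, while
// dynamic dimension sizes map to dynamic level sizes (except for the trailing
// `mod` levels, which are still known to be bounded by the block size).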
SmallVector<int64_t>
SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
                                         CrdTransDirectionKind dir) const {
  if (isIdentity())
    return SmallVector<int64_t>(srcShape);

  SmallVector<int64_t> ret;
  unsigned rank =
      dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
  ret.reserve(rank);

  if (isPermutation()) {
    for (unsigned r = 0; r < rank; r++) {
      unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r)
                                                             : toLvl(*this, r);
      ret.push_back(srcShape[trans]);
    }
    return ret;
  }

  // Handle non-permutation maps.
  AffineMap transMap =
      dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();

  SmallVector<AffineExpr> dimRep;
  dimRep.reserve(srcShape.size());
  for (int64_t sz : srcShape) {
    if (!ShapedType::isDynamic(sz)) {
      // Push back the max coordinate for the given dimension/level size.
      dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
    } else {
      // A dynamic size, use an AffineDimExpr to symbolize the value.
      dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
    }
  }

  for (AffineExpr exp : transMap.getResults()) {
    // Do constant propagation on the affine map.
    AffineExpr evalExp =
        simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
    // Use the llvm namespace here to avoid ambiguity.
    if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
      ret.push_back(c.getValue() + 1);
    } else {
      if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
          mod && mod.getKind() == AffineExprKind::Mod) {
        // We can still infer a static bound for expressions in the form
        // "d % constant" since d % constant \in [0, constant).
        if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
          ret.push_back(bound.getValue());
          continue;
        }
      }
      ret.push_back(ShapedType::kDynamic);
    }
  }
  assert(ret.size() == rank);
  return ret;
}

ValueRange
SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
                                        ValueRange crds,
                                        CrdTransDirectionKind dir) const {
  if (!getImpl())
    return crds;

  SmallVector<Type> retType(
      dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
      builder.getIndexType());
  auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this);
  return transOp.getOutCrds();
}

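// A representative example of the attribute syntax accepted below (a CSR
// encoding with explicit 32-bit overhead types):
//
//   #sparse_tensor.encoding<{
//     map = (d0, d1) -> (d0 : dense, d1 : compressed),
//     posWidth = 32,
//     crdWidth = 32
//   }>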
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
  // Open "<{" part.
  if (failed(parser.parseLess()))
    return {};
  if (failed(parser.parseLBrace()))
    return {};

  // Process the data from the parsed dictionary value into struct-like data.
  SmallVector<LevelType> lvlTypes;
  SmallVector<SparseTensorDimSliceAttr> dimSlices;
  AffineMap dimToLvl = {};
  AffineMap lvlToDim = {};
  unsigned posWidth = 0;
  unsigned crdWidth = 0;
  StringRef attrName;
  SmallVector<StringRef, 3> keys = {"map", "posWidth", "crdWidth"};
  while (succeeded(parser.parseOptionalKeyword(&attrName))) {
    // Detect admissible keyword.
    auto *it = find(keys, attrName);
    if (it == keys.end()) {
      parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
      return {};
    }
    unsigned keyWordIndex = it - keys.begin();
    // Consume the `=` after the key.
    if (failed(parser.parseEqual()))
      return {};
    // Dispatch on keyword.
    switch (keyWordIndex) {
    case 0: { // map
      ir_detail::DimLvlMapParser cParser(parser);
      auto res = cParser.parseDimLvlMap();
      if (failed(res))
        return {};
      const auto &dlm = *res;

      const Level lvlRank = dlm.getLvlRank();
      for (Level lvl = 0; lvl < lvlRank; lvl++)
        lvlTypes.push_back(dlm.getLvlType(lvl));

      const Dimension dimRank = dlm.getDimRank();
      for (Dimension dim = 0; dim < dimRank; dim++)
        dimSlices.push_back(dlm.getDimSlice(dim));
      // NOTE: the old syntax requires an all-or-nothing approach to
      // `dimSlices`; therefore, if any slice actually exists then we need
      // to convert null-DSA into default/nop DSA.
      const auto isDefined = [](SparseTensorDimSliceAttr slice) {
        return static_cast<bool>(slice.getImpl());
      };
      if (llvm::any_of(dimSlices, isDefined)) {
        const auto defaultSlice =
            SparseTensorDimSliceAttr::get(parser.getContext());
        for (Dimension dim = 0; dim < dimRank; dim++)
          if (!isDefined(dimSlices[dim]))
            dimSlices[dim] = defaultSlice;
      } else {
        dimSlices.clear();
      }

      dimToLvl = dlm.getDimToLvlMap(parser.getContext());
      lvlToDim = dlm.getLvlToDimMap(parser.getContext());
      break;
    }
    case 1: { // posWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral position bitwidth");
        return {};
      }
      posWidth = intAttr.getInt();
      break;
    }
    case 2: { // crdWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral index bitwidth");
        return {};
      }
      crdWidth = intAttr.getInt();
      break;
    }
    } // switch
    // Only the last item may omit the comma.
    if (parser.parseOptionalComma().failed())
      break;
  }

  // Close "}>" part.
  if (failed(parser.parseRBrace()))
    return {};
  if (failed(parser.parseGreater()))
    return {};

  // Construct struct-like storage for attribute.
  if (!lvlToDim || lvlToDim.isEmpty()) {
    lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
  }
  return parser.getChecked<SparseTensorEncodingAttr>(
      parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
      dimSlices);
}

void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
  auto map = static_cast<AffineMap>(getDimToLvl());
  // An empty affine map indicates the identity map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
  printer << "<{ map = ";
  printSymbols(map, printer);
  printer << '(';
  printDimensions(map, printer, getDimSlices());
  printer << ") -> (";
  printLevels(map, printer, getLvlTypes());
  printer << ')';
  // Print remaining members only for non-default values.
  if (getPosWidth())
    printer << ", posWidth = " << getPosWidth();
  if (getCrdWidth())
    printer << ", crdWidth = " << getCrdWidth();
  printer << " }>";
}

void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
                                            AsmPrinter &printer) const {
  if (map.getNumSymbols() == 0)
    return;
  printer << '[';
  for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
    printer << 's' << i << ", ";
  if (map.getNumSymbols() >= 1)
    printer << 's' << map.getNumSymbols() - 1;
  printer << ']';
}

void SparseTensorEncodingAttr::printDimensions(
    AffineMap &map, AsmPrinter &printer,
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  if (!dimSlices.empty()) {
    for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
      printer << 'd' << i << " : " << dimSlices[i] << ", ";
    if (map.getNumDims() >= 1) {
      printer << 'd' << map.getNumDims() - 1 << " : "
              << dimSlices[map.getNumDims() - 1];
    }
  } else {
    for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
      printer << 'd' << i << ", ";
    if (map.getNumDims() >= 1)
      printer << 'd' << map.getNumDims() - 1;
  }
}

void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
                                           ArrayRef<LevelType> lvlTypes) const {
  for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
    map.getResult(i).print(printer.getStream());
    printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
  }
  if (map.getNumResults() >= 1) {
    auto lastIndex = map.getNumResults() - 1;
    map.getResult(lastIndex).print(printer.getStream());
    printer << " : " << toMLIRString(lvlTypes[lastIndex]);
  }
}

LogicalResult SparseTensorEncodingAttr::verify(
    function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes,
    AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth,
    unsigned crdWidth, ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
  if (!acceptBitWidth(posWidth))
    return emitError() << "unexpected position bitwidth: " << posWidth;
  if (!acceptBitWidth(crdWidth))
    return emitError() << "unexpected coordinate bitwidth: " << crdWidth;
  if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT);
      it != std::end(lvlTypes)) {
    if (it == lvlTypes.begin() ||
        (!isCompressedLT(*(it - 1)) && !isLooseCompressedLT(*(it - 1))))
      return emitError() << "expected compressed or loose_compressed level "
                            "before singleton level";
    if (!std::all_of(it, lvlTypes.end(),
                     [](LevelType i) { return isSingletonLT(i); }))
      return emitError() << "expected all singleton lvlTypes "
                            "following a singleton level";
    // We can potentially support mixed SoA/AoS singleton levels.
    if (!std::all_of(it, lvlTypes.end(), [it](LevelType i) {
          return it->isa<LevelPropNonDefault::SoA>() ==
                 i.isa<LevelPropNonDefault::SoA>();
        })) {
      return emitError() << "expected all singleton lvlTypes stored in the "
                            "same memory layout (SoA vs AoS).";
    }
  }

  auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
  if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT))
    return emitError() << "Batch lvlType can only be leading levels.";

  // The SoA property can only be applied to singleton levels.
  auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) {
    return lt.isa<LevelPropNonDefault::SoA>();
  });
  if (llvm::any_of(soaLvls, [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      })) {
    return emitError() << "SoA is only applicable to singleton lvlTypes.";
  }

  // TODO: audit formats that actually are supported by backend.
  if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT);
      it != std::end(lvlTypes)) {
    if (it != lvlTypes.end() - 1)
      return emitError() << "expected n_out_of_m to be the last level type";
    if (!std::all_of(lvlTypes.begin(), it,
                     [](LevelType i) { return isDenseLT(i); }))
      return emitError() << "expected all dense lvlTypes "
                            "before a n_out_of_m level";
    if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) {
      if (!isBlockSparsity(dimToLvl)) {
        return emitError()
               << "expected 1xm block structure for n_out_of_m level";
      }
      auto sizes = getBlockSize(dimToLvl);
      unsigned coefficient = 0;
      for (const auto &elem : sizes) {
        if (elem != 0) {
          if (elem != coefficient && coefficient != 0) {
            return emitError() << "expected only one blocked level "
                                  "with the same coefficients";
          }
          coefficient = elem;
        }
      }
      if (coefficient != getM(*it)) {
        return emitError() << "expected coefficients of Affine expressions "
                              "to be equal to m of n_out_of_m level";
      }
    }
  }
  // Before we can check that the level-rank is consistent/coherent
  // across all fields, we need to define it. The source-of-truth for
  // the `getLvlRank` method is the length of the level-types array,
  // since it must always be provided and have full rank; therefore we
  // use that same source-of-truth here.
  const Level lvlRank = lvlTypes.size();
  if (lvlRank == 0)
    return emitError() << "expected a non-empty array for lvlTypes";
  // We save `dimRank` here because we'll also need it to verify `dimSlices`.
  const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
  if (dimToLvl) {
    if (dimToLvl.getNumResults() != lvlRank)
      return emitError()
             << "level-rank mismatch between dimToLvl and lvlTypes: "
             << dimToLvl.getNumResults() << " != " << lvlRank;
    auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
    // Symbols can't be inferred but are acceptable.
    if (!inferRes && dimToLvl.getNumSymbols() == 0)
      return emitError() << "failed to infer lvlToDim from dimToLvl";
    if (lvlToDim && (inferRes != lvlToDim))
      return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
    if (dimRank > lvlRank)
      return emitError() << "unexpected dimToLvl mapping from " << dimRank
                         << " to " << lvlRank;
  }
  if (!dimSlices.empty()) {
    if (dimSlices.size() != dimRank)
      return emitError()
             << "dimension-rank mismatch between dimSlices and dimToLvl: "
             << dimSlices.size() << " != " << dimRank;
    // Compiler support for `dimSlices` currently requires that the two
    // ranks agree. (However, it does allow `dimToLvl` to be a permutation.)
    if (dimRank != lvlRank)
      return emitError()
             << "dimSlices expected dimension-rank to match level-rank: "
             << dimRank << " != " << lvlRank;
  }
  return success();
}

LogicalResult SparseTensorEncodingAttr::verifyEncoding(
    ArrayRef<Size> dimShape, Type elementType,
    function_ref<InFlightDiagnostic()> emitError) const {
  // Check structural integrity. In particular, this ensures that the
  // level-rank is coherent across all the fields.
  if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
                    getPosWidth(), getCrdWidth(), getDimSlices())))
    return failure();
  // Check integrity with tensor type specifics. In particular, we
  // need only check that the dimension-rank of the tensor agrees with
  // the dimension-rank of the encoding.
  const Dimension dimRank = dimShape.size();
  if (dimRank == 0)
    return emitError() << "expected non-scalar sparse tensor";
  if (getDimRank() != dimRank)
    return emitError()
           << "dimension-rank mismatch between encoding and tensor shape: "
           << getDimRank() << " != " << dimRank;
  return success();
}

//===----------------------------------------------------------------------===//
// SparseTensorType Methods.
//===----------------------------------------------------------------------===//

bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
                                                      bool isUnique) const {
  if (!hasEncoding())
    return false;
  if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
    return false;
  for (Level l = startLvl + 1; l < lvlRank; ++l)
    if (!isSingletonLvl(l))
      return false;
  // If isUnique is true, then make sure that the last level is unique,
  // that is, when lvlRank == 1, the only compressed level is unique,
  // and when lvlRank > 1, the last singleton is unique.
  return !isUnique || isUniqueLvl(lvlRank - 1);
}

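// Returns the level at which the trailing AoS COO region starts, or lvlRank
// if there is no such region. For instance, for a hypothetical encoding with
// lvlTypes [compressed(nonunique), singleton] the AoS COO region starts at
// level 0, whereas a CSR-like encoding has no COO region at all.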
Level mlir::sparse_tensor::SparseTensorType::getAoSCOOStart() const {
  SmallVector<COOSegment> coo = getCOOSegments();
  assert(coo.size() == 1 || coo.empty());
  if (!coo.empty() && coo.front().isAoS()) {
    return coo.front().lvlRange.first;
  }
  return lvlRank;
}

SmallVector<COOSegment>
mlir::sparse_tensor::SparseTensorType::getCOOSegments() const {
  SmallVector<COOSegment> ret;
  if (!hasEncoding() || lvlRank <= 1)
    return ret;

  ArrayRef<LevelType> lts = getLvlTypes();
  Level l = 0;
  while (l < lvlRank) {
    auto lt = lts[l];
    if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) {
      auto cur = lts.begin() + l;
      auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
        return !lt.isa<LevelFormat::Singleton>();
      });
      unsigned cooLen = std::distance(cur, end);
      if (cooLen > 1) {
        // To support mixed SoA/AoS COO, we should break the segment when the
        // storage scheme changes; for now we faithfully assume that all
        // consecutive singleton levels have the same storage format, as
        // verified by the STEA verifier.
        ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
                                 lts[l + 1].isa<LevelPropNonDefault::SoA>()});
      }
      l += cooLen;
    } else {
      l++;
    }
  }
  return ret;
}

RankedTensorType
mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
  SmallVector<LevelType> lvlTypes;
  lvlTypes.reserve(lvlRank);
  // A non-unique compressed level at the beginning (unless this is
  // also the last level, then it is unique).
  lvlTypes.push_back(
      *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
  if (lvlRank > 1) {
    // Followed by n-2 non-unique singleton levels.
    std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
                *buildLevelType(LevelFormat::Singleton, ordered, false));
    // Ends with a unique singleton level.
    lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
  }
  auto enc = SparseTensorEncodingAttr::get(getContext(), lvlTypes,
                                           getDimToLvl(), getLvlToDim(),
                                           getPosWidth(), getCrdWidth());
  return RankedTensorType::get(getDimShape(), getElementType(), enc);
}

//===----------------------------------------------------------------------===//
// Convenience Methods.
//===----------------------------------------------------------------------===//

SparseTensorEncodingAttr
mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
  if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
    return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
  if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
    return mdtp.getEncoding();
  return nullptr;
}

AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
                                             MLIRContext *context) {
  auto map = static_cast<AffineMap>(dimToLvl);
  AffineMap lvlToDim;
  // Return an empty lvlToDim when inference is not successful.
  if (!map || map.getNumSymbols() != 0) {
    lvlToDim = AffineMap();
  } else if (map.isPermutation()) {
    lvlToDim = inversePermutation(map);
  } else if (isBlockSparsity(map)) {
    lvlToDim = inverseBlockSparsity(map, context);
  }
  return lvlToDim;
}

AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
                                                    MLIRContext *context) {
  SmallVector<AffineExpr> lvlExprs;
  auto numLvls = dimToLvl.getNumResults();
  lvlExprs.reserve(numLvls);
  // lvlExprComponents stores information of the floordiv and mod operations
  // applied to the same dimension, so as to build the lvlToDim map.
  std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
  for (unsigned i = 0, n = numLvls; i < n; i++) {
    auto result = dimToLvl.getResult(i);
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      if (result.getKind() == AffineExprKind::FloorDiv) {
        // Position of the dimension in dimToLvl.
        auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
               "expected only one floordiv for each dimension");
        SmallVector<AffineExpr, 3> components;
        // Level variable for floordiv.
        components.push_back(getAffineDimExpr(i, context));
        // Multiplier.
        components.push_back(binOp.getRHS());
        // Map key is the position of the dimension.
        lvlExprComponents[pos] = components;
      } else if (result.getKind() == AffineExprKind::Mod) {
        auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
               "expected floordiv before mod");
        // Add level variable for mod to the same vector
        // of the corresponding floordiv.
        lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
      } else {
        assert(false && "expected floordiv or mod");
      }
    } else {
      lvlExprs.push_back(getAffineDimExpr(i, context));
    }
  }
  // Build lvlExprs from lvlExprComponents.
  // For example, for il = i floordiv 2 and ii = i mod 2, the components
  // would be [il, 2, ii]. It could be used to build the AffineExpr
  // i = il * 2 + ii in lvlToDim.
  for (auto &components : lvlExprComponents) {
    assert(components.second.size() == 3 &&
           "expected 3 components to build lvlExprs");
    auto mulOp = getAffineBinaryOpExpr(
        AffineExprKind::Mul, components.second[0], components.second[1]);
    auto addOp = getAffineBinaryOpExpr(AffineExprKind::Add, mulOp,
                                       components.second[2]);
    lvlExprs.push_back(addOp);
  }
  return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
}

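// As an illustrative sketch: for the 2x3-blocked map
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// this returns [2, 3] (one entry per `mod` level), whereas plain dimension
// results contribute a 0 entry.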
SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) {
  assert(isBlockSparsity(dimToLvl) &&
         "expected dimToLvl to be block sparsity for calling getBlockSize");
  SmallVector<unsigned> blockSize;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      if (result.getKind() == AffineExprKind::Mod) {
        blockSize.push_back(
            dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
      }
    } else {
      blockSize.push_back(0);
    }
  }
  return blockSize;
}

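// A dimToLvl map expresses block sparsity if every dimension is either passed
// through unchanged or split into exactly one matching floordiv/mod pair with
// the same positive constant, and at least one such pair exists. For example,
//   (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
// is block sparse, while (d0, d1) -> (d0, d1 mod 4) is not (no matching
// floordiv).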
bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
  if (!dimToLvl)
    return false;
  std::map<unsigned, int64_t> coefficientMap;
  bool hasBlock = false;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      // Check for "dim op const".
      auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
      auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
      if (!dimOp || !conOp || conOp.getValue() <= 0)
        return false;
      // Inspect "dim / const" or "dim % const".
      auto pos = dimOp.getPosition();
      if (binOp.getKind() == AffineExprKind::FloorDiv) {
        // Expect only one floordiv for each dimension.
        if (coefficientMap.find(pos) != coefficientMap.end())
          return false;
        // Record coefficient of the floordiv.
        coefficientMap[pos] = conOp.getValue();
      } else if (binOp.getKind() == AffineExprKind::Mod) {
        // Expect floordiv before mod.
        if (coefficientMap.find(pos) == coefficientMap.end())
          return false;
        // Expect mod to have the same coefficient as floordiv.
        if (conOp.getValue() != coefficientMap[pos])
          return false;
        hasBlock = true;
      } else {
        return false;
      }
    } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) {
      auto pos = dimOp.getPosition();
      // Expect dim to be unset.
      if (coefficientMap.find(pos) != coefficientMap.end())
        return false;
      coefficientMap[pos] = 0;
    } else {
      return false;
    }
  }
  return hasBlock;
}

bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
  auto hasNonIdentityMap = [](Value v) {
    auto stt = tryGetSparseTensorType(v);
    return stt && !stt->isIdentity();
  };

  return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
         llvm::any_of(op->getResults(), hasNonIdentityMap);
}

Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
  if (enc) {
    assert(enc.isPermutation() && "Non permutation map not supported");
    if (const auto dimToLvl = enc.getDimToLvl())
      return dimToLvl.getDimPosition(l);
  }
  return l;
}

Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) {
  if (enc) {
    assert(enc.isPermutation() && "Non permutation map not supported");
    if (const auto lvlToDim = enc.getLvlToDim())
      return lvlToDim.getDimPosition(d);
  }
  return d;
}

/// We normalize the sparse tensor encoding attribute by always using
/// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well
/// as other variants) lead to the same storage specifier type, and by
/// stripping irrelevant fields that do not alter the sparse tensor memory
/// layout.
static SparseTensorEncodingAttr
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
  SmallVector<LevelType> lts;
  for (auto lt : enc.getLvlTypes())
    lts.push_back(lt.stripStorageIrrelevantProperties());

  return SparseTensorEncodingAttr::get(
      enc.getContext(), lts,
      AffineMap(), // dimToLvl (irrelevant to storage specifier)
      AffineMap(), // lvlToDim (irrelevant to storage specifier)
      // Always use `index` for memSize and lvlSize instead of reusing
      // `getPosWidth` and `getCrdWidth`. This allows us to reuse the same SSA
      // value for different bitwidths, and it avoids casting between index
      // and integer (returned by DimOp).
      0, 0, enc.getDimSlices());
}

StorageSpecifierType
StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
  return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Operations.
//===----------------------------------------------------------------------===//

static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
  return success(lvl < getSparseTensorType(tensor).getLvlRank());
}

static LogicalResult isMatchingWidth(Value mem, unsigned width) {
  const Type etp = getMemRefType(mem).getElementType();
  return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
}

static LogicalResult verifySparsifierGetterSetter(
    StorageSpecifierKind mdKind, std::optional<Level> lvl,
    TypedValue<StorageSpecifierType> md, Operation *op) {
  if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
    return op->emitError(
        "redundant level argument for querying value memory size");
  }

  const auto enc = md.getType().getEncoding();
  const Level lvlRank = enc.getLvlRank();

  if (mdKind == StorageSpecifierKind::DimOffset ||
      mdKind == StorageSpecifierKind::DimStride)
    if (!enc.isSlice())
      return op->emitError("requested slice data on non-slice tensor");

  if (mdKind != StorageSpecifierKind::ValMemSize) {
    if (!lvl)
      return op->emitError("missing level argument");

    const Level l = lvl.value();
    if (l >= lvlRank)
      return op->emitError("requested level is out of bounds");

    if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
      return op->emitError(
          "requested position memory size on a singleton level");
  }
  return success();
}

static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
  switch (kind) {
  case SparseTensorFieldKind::CrdMemRef:
    return stt.getCrdType();
  case SparseTensorFieldKind::PosMemRef:
    return stt.getPosType();
  case SparseTensorFieldKind::ValMemRef:
    return stt.getElementType();
  case SparseTensorFieldKind::StorageSpec:
    return nullptr;
  }
  llvm_unreachable("Unrecognizable FieldKind");
}

static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
                                      SparseTensorType stt,
                                      RankedTensorType valTp,
                                      TypeRange lvlTps) {
  if (requiresStaticShape && !stt.hasStaticDimShape())
    return op->emitError("the sparse-tensor must have static shape");
  if (!stt.hasEncoding())
    return op->emitError("the sparse-tensor must have an encoding attribute");

  // Verifies the trailing COO.
  Level cooStartLvl = stt.getAoSCOOStart();
  if (cooStartLvl < stt.getLvlRank()) {
    // We only support trailing COO for now; it must be the last input.
    auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
    // The coordinates should have shape <? x rank>.
    unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
    if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
      return op->emitError("input/output trailing COO level-ranks don't match");
    }
  }

  // Verifies that all types match.
  StorageLayout layout(stt.getEncoding());
  if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
    return op->emitError("inconsistent number of fields between input/output");

  unsigned idx = 0;
  bool misMatch = false;
  layout.foreachField([&idx, &misMatch, stt, valTp,
                       lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
                               Level lvl, LevelType lt) -> bool {
    if (fKind == SparseTensorFieldKind::StorageSpec)
      return true;

    Type inputTp = nullptr;
    if (fKind == SparseTensorFieldKind::ValMemRef) {
      inputTp = valTp;
    } else {
      assert(fid == idx && stt.getLvlType(lvl) == lt);
      inputTp = lvlTps[idx++];
    }
    // The input element type and expected element type should match.
    Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
    Type expElemTp = getFieldElemType(stt, fKind);
    if (inpElemTp != expElemTp) {
      misMatch = true;
      return false; // to terminate the iteration
    }
    return true;
  });

  if (misMatch)
    return op->emitError("input/output element-types don't match");
  return success();
}

LogicalResult AssembleOp::verify() {
  const auto valuesTp = getRankedTensorType(getValues());
  const auto lvlsTp = getLevels().getTypes();
  const auto resTp = getSparseTensorType(getResult());
  return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
}

LogicalResult DisassembleOp::verify() {
  if (getOutValues().getType() != getRetValues().getType())
    return emitError("output values and return value type mismatch");

  for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
    if (ot.getType() != rt.getType())
      return emitError("output levels and return levels type mismatch");

  const auto valuesTp = getRankedTensorType(getRetValues());
  const auto lvlsTp = getRetLevels().getTypes();
  const auto srcTp = getSparseTensorType(getTensor());
  return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
}

LogicalResult ConvertOp::verify() {
  if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
    if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
      if (tp1.getRank() != tp2.getRank())
        return emitError("unexpected conversion mismatch in rank");
      auto dstEnc =
          llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
      if (dstEnc && dstEnc.isSlice())
        return emitError("cannot convert to a sparse tensor slice");

      auto shape1 = tp1.getShape();
      auto shape2 = tp2.getShape();
      // Accept size matches between the source and the destination type
      // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
      // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
      for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
        if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
          return emitError("unexpected conversion mismatch in dimension ") << d;
      return success();
    }
  }
  return emitError("unexpected type in convert");
}

OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
  if (getType() == getSource().getType())
    return getSource();
  return {};
}

bool ConvertOp::needsExtraSort() {
  SparseTensorType srcStt = getSparseTensorType(getSource());
  SparseTensorType dstStt = getSparseTensorType(getDest());

  // We do not need an extra sort when returning unordered sparse tensors or
  // dense tensors since dense tensors support random access.
  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
    return false;

  if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
      srcStt.hasSameDimToLvl(dstStt)) {
    return false;
  }

  // Source and dest tensors are ordered in different ways. We only do direct
  // dense to sparse conversion when the dense input is defined by a sparse
  // constant. Note that we can theoretically always directly convert from
  // dense inputs by rotating dense loops, but it leads to bad cache locality
  // and hurts performance.
  if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
    if (isa<SparseElementsAttr>(constOp.getValue()))
      return false;

  return true;
}

LogicalResult CrdTranslateOp::verify() {
  uint64_t inRank = getEncoder().getLvlRank();
  uint64_t outRank = getEncoder().getDimRank();

  if (getDirection() == CrdTransDirectionKind::dim2lvl)
    std::swap(inRank, outRank);

  if (inRank != getInCrds().size() || outRank != getOutCrds().size())
    return emitError("Coordinate rank mismatch with encoding");

  return success();
}

LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
                                   SmallVectorImpl<OpFoldResult> &results) {
  if (getEncoder().isIdentity()) {
    results.assign(getInCrds().begin(), getInCrds().end());
    return success();
  }
  if (getEncoder().isPermutation()) {
    AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
                         ? getEncoder().getDimToLvl()
                         : getEncoder().getLvlToDim();
    for (AffineExpr exp : perm.getResults())
      results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
    return success();
  }

  // Fuse dim2lvl/lvl2dim pairs.
  auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
  bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
                   return v.getDefiningOp() == def;
                 });
  if (!sameDef)
    return failure();

  bool oppositeDir = def.getDirection() != getDirection();
  bool sameOracle =
      def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
  bool sameCount = def.getNumResults() == getInCrds().size();
  if (!oppositeDir || !sameOracle || !sameCount)
    return failure();

  // The definition produces the coordinates in the same order as the input
  // coordinates.
  bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
                                [](auto valuePair) {
                                  auto [lhs, rhs] = valuePair;
                                  return lhs == rhs;
                                });

  if (!sameOrder)
    return failure();
  // l1 = dim2lvl (lvl2dim l0)
  // ==> l0
  results.append(def.getInCrds().begin(), def.getInCrds().end());
  return success();
}

void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
                  int64_t index) {
  Value val = builder.create<arith::ConstantIndexOp>(state.location, index);
  return build(builder, state, source, val);
}

LogicalResult LvlOp::verify() {
  if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
    auto stt = getSparseTensorType(getSource());
    if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
      return emitError(
          "Level index exceeds the rank of the input sparse tensor");
  }
  return success();
}

std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
  return getConstantIntValue(getIndex());
}

Speculation::Speculatability LvlOp::getSpeculatability() {
  auto constantIndex = getConstantLvlIndex();
  if (!constantIndex)
    return Speculation::NotSpeculatable;

  assert(constantIndex <
         cast<RankedTensorType>(getSource().getType()).getRank());
  return Speculation::Speculatable;
}

OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
  auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!lvlIndex)
    return {};

  Level lvl = lvlIndex.getAPSInt().getZExtValue();
  auto stt = getSparseTensorType(getSource());
  if (lvl >= stt.getLvlRank()) {
    // Follows the same convention used by the tensor.dim operation. Out of
    // bound indices produce undefined behavior but are still valid IR. Don't
    // choke on them.
    return {};
  }

  // Helper lambda to build an IndexAttr.
  auto getIndexAttr = [this](int64_t lvlSz) {
    return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
  };

  SmallVector<Size> lvlShape = stt.getLvlShape();
  if (!ShapedType::isDynamic(lvlShape[lvl]))
    return getIndexAttr(lvlShape[lvl]);

  return {};
}

void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                             SparseTensorEncodingAttr dstEnc, Value source) {
  auto srcStt = getSparseTensorType(source);
  SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
  SmallVector<int64_t> dstDimShape =
      dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
  auto dstTp =
      RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
  return build(odsBuilder, odsState, dstTp, source);
}

LogicalResult ReinterpretMapOp::verify() {
  auto srcStt = getSparseTensorType(getSource());
  auto dstStt = getSparseTensorType(getDest());
  ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes();
  ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes();

  if (srcLvlTps.size() != dstLvlTps.size())
    return emitError("Level rank mismatch between source/dest tensors");

  for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
    if (srcLvlTp != dstLvlTp)
      return emitError("Level type mismatch between source/dest tensors");

  if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
      srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
    return emitError("Crd/Pos width mismatch between source/dest tensors");
  }

  if (srcStt.getElementType() != dstStt.getElementType())
    return emitError("Element type mismatch between source/dest tensors");

  SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
  SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
  for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
    if (srcLvlSz != dstLvlSz) {
      // Should we allow one side to be dynamic size, e.g., <?x?> should be
      // compatible to <3x4>? For now, we require all the level sizes to be
      // *exactly* matched for simplicity.
      return emitError("Level size mismatch between source/dest tensors");
    }
  }

  return success();
}

OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
  if (getSource().getType() == getDest().getType())
    return getSource();

  if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
    // A -> B, B -> A ==> A
    if (def.getSource().getType() == getDest().getType())
      return def.getSource();
  }
  return {};
}

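// Infers the memref type returned by the to_positions/to_coordinates/
// to_coordinates_buffer/to_values operations: the batch level shape followed
// by one dynamic dimension, using a dynamically strided layout for
// to_coordinates when the requested level lies within an AoS COO region.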
template <typename ToBufferOp>
static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
                                           OpaqueProperties prop,
                                           RegionRange region,
                                           SmallVectorImpl<mlir::Type> &ret) {
  typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
  Type elemTp = nullptr;
  bool withStride = false;
  if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
    elemTp = stt.getPosType();
  } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
                       std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
    elemTp = stt.getCrdType();
    if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
      withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
  } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
    elemTp = stt.getElementType();
  }

  assert(elemTp && "unhandled operation.");
  SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
  bufShape.push_back(ShapedType::kDynamic);

  auto layout = withStride ? StridedLayoutAttr::get(stt.getContext(),
                                                    ShapedType::kDynamic,
                                                    {ShapedType::kDynamic})
                           : StridedLayoutAttr();
  ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
  return success();
}

LogicalResult ToPositionsOp::verify() {
  auto stt = getSparseTensorType(getTensor());
  if (failed(lvlIsInBounds(getLevel(), getTensor())))
    return emitError("requested level is out of bounds");
  if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
    return emitError("unexpected type for positions");
  return success();
}

LogicalResult
ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                                ValueRange ops, DictionaryAttr attr,
                                OpaqueProperties prop, RegionRange region,
                                SmallVectorImpl<mlir::Type> &ret) {
  return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
}

LogicalResult ToCoordinatesOp::verify() {
  auto stt = getSparseTensorType(getTensor());
  if (failed(lvlIsInBounds(getLevel(), getTensor())))
    return emitError("requested level is out of bounds");
  if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
    return emitError("unexpected type for coordinates");
  return success();
}

LogicalResult
ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                                  ValueRange ops, DictionaryAttr attr,
                                  OpaqueProperties prop, RegionRange region,
                                  SmallVectorImpl<mlir::Type> &ret) {
  return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
}

LogicalResult ToCoordinatesBufferOp::verify() {
  auto stt = getSparseTensorType(getTensor());
  if (stt.getAoSCOOStart() >= stt.getLvlRank())
    return emitError("expected sparse tensor with a COO region");
  return success();
}

LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
    MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
    DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
    SmallVectorImpl<mlir::Type> &ret) {
  return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
                                                      ret);
}

LogicalResult ToValuesOp::verify() {
  auto stt = getSparseTensorType(getTensor());
  auto mtp = getMemRefType(getResult());
  if (stt.getElementType() != mtp.getElementType())
    return emitError("unexpected mismatch in element types");
  return success();
}

LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
                                           std::optional<Location> loc,
                                           ValueRange ops, DictionaryAttr attr,
                                           OpaqueProperties prop,
                                           RegionRange region,
                                           SmallVectorImpl<mlir::Type> &ret) {
  return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
}

LogicalResult ToSliceOffsetOp::verify() {
  auto rank = getRankedTensorType(getSlice()).getRank();
  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
    return emitError("requested dimension out of bounds");
  return success();
}

LogicalResult ToSliceStrideOp::verify() {
  auto rank = getRankedTensorType(getSlice()).getRank();
  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
    return emitError("requested dimension out of bounds");
  return success();
}

LogicalResult GetStorageSpecifierOp::verify() {
  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
                                      getSpecifier(), getOperation());
}

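/// Returns the SetStorageSpecifierOp that defines the specifier operand of
/// the given getter/setter op, or null if the specifier does not come from a
/// set operation.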
template <typename SpecifierOp>
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
  return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
}

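/// A get folds to the value stored by the closest preceding set with the same
/// specifier kind and level, looking through any intervening sets that update
/// other fields of the specifier.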
OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
  const StorageSpecifierKind kind = getSpecifierKind();
  const auto lvl = getLevel();
  for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
    if (kind == op.getSpecifierKind() && lvl == op.getLevel())
      return op.getValue();
  return {};
}

LogicalResult SetStorageSpecifierOp::verify() {
  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
                                      getSpecifier(), getOperation());
}

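/// Verifies that the given region has exactly the expected block arguments
/// (matching `inputTypes`) and is terminated by a sparse_tensor.yield of a
/// single value of `outputType`.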
template <class T>
static LogicalResult verifyNumBlockArgs(T *op, Region &region,
                                        const char *regionName,
                                        TypeRange inputTypes, Type outputType) {
  unsigned numArgs = region.getNumArguments();
  unsigned expectedNum = inputTypes.size();
  if (numArgs != expectedNum)
    return op->emitError() << regionName << " region must have exactly "
                           << expectedNum << " arguments";

  for (unsigned i = 0; i < numArgs; i++) {
    Type typ = region.getArgument(i).getType();
    if (typ != inputTypes[i])
      return op->emitError() << regionName << " region argument " << (i + 1)
                             << " type mismatch";
  }
  Operation *term = region.front().getTerminator();
  YieldOp yield = dyn_cast<YieldOp>(term);
  if (!yield)
    return op->emitError() << regionName
                           << " region must end with sparse_tensor.yield";
  if (!yield.hasSingleResult() ||
      yield.getSingleResult().getType() != outputType)
    return op->emitError() << regionName << " region yield type mismatch";

  return success();
}

LogicalResult BinaryOp::verify() {
  NamedAttrList attrs = (*this)->getAttrs();
  Type leftType = getX().getType();
  Type rightType = getY().getType();
  Type outputType = getOutput().getType();
  Region &overlap = getOverlapRegion();
  Region &left = getLeftRegion();
  Region &right = getRightRegion();

  // Check correct number of block arguments and return type for each
  // non-empty region.
  if (!overlap.empty()) {
    if (failed(verifyNumBlockArgs(this, overlap, "overlap",
                                  TypeRange{leftType, rightType}, outputType)))
      return failure();
  }
  if (!left.empty()) {
    if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
                                  outputType)))
      return failure();
  } else if (getLeftIdentity()) {
    if (leftType != outputType)
      return emitError("left=identity requires first argument to have the same "
                       "type as the output");
  }
  if (!right.empty()) {
    if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
                                  outputType)))
      return failure();
  } else if (getRightIdentity()) {
    if (rightType != outputType)
      return emitError("right=identity requires second argument to have the "
                       "same type as the output");
  }
  return success();
}

LogicalResult UnaryOp::verify() {
  Type inputType = getX().getType();
  Type outputType = getOutput().getType();

  // Check correct number of block arguments and return type for each
  // non-empty region.
  Region &present = getPresentRegion();
  if (!present.empty()) {
    if (failed(verifyNumBlockArgs(this, present, "present",
                                  TypeRange{inputType}, outputType)))
      return failure();
  }
  Region &absent = getAbsentRegion();
  if (!absent.empty()) {
    if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
                                  outputType)))
      return failure();
    // Absent branch can only yield invariant values.
    Block *absentBlock = &absent.front();
    Block *parent = getOperation()->getBlock();
    Value absentVal =
        cast<YieldOp>(absentBlock->getTerminator()).getSingleResult();
    if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
      if (arg.getOwner() == parent)
        return emitError("absent region cannot yield linalg argument");
    } else if (Operation *def = absentVal.getDefiningOp()) {
      if (!isa<arith::ConstantOp>(def) &&
          (def->getBlock() == absentBlock || def->getBlock() == parent))
        return emitError("absent region cannot yield locally computed value");
    }
  }
  return success();
}

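/// Returns true when the (ordered) result cannot be produced by simply
/// appending the inputs in order, so lowering needs a temporary COO buffer
/// and an extra sort.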
bool ConcatenateOp::needsExtraSort() {
  SparseTensorType dstStt = getSparseTensorType(*this);
  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
    return false;

  bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
    return getSparseTensorType(op).hasSameDimToLvl(dstStt);
  });
  // TODO: When conDim != 0, as long as conDim corresponds to the first level
  // in all input/output buffers, and all input/output buffers have the same
  // dimToLvl, the tmp COO buffer is still unnecessary (e.g., concatenating
  // CSC matrices along the column dimension).
  bool directLowerable =
      allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
  return !directLowerable;
}

LogicalResult ConcatenateOp::verify() {
  const auto dstTp = getSparseTensorType(*this);
  const Dimension concatDim = getDimension();
  const Dimension dimRank = dstTp.getDimRank();

  if (getInputs().size() <= 1)
    return emitError("Need at least two tensors to concatenate.");

  if (concatDim >= dimRank)
    return emitError(llvm::formatv(
        "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
        concatDim, dimRank));

  for (const auto &it : llvm::enumerate(getInputs())) {
    const auto i = it.index();
    const auto srcTp = getSparseTensorType(it.value());
    if (srcTp.hasDynamicDimShape())
      return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
    const Dimension srcDimRank = srcTp.getDimRank();
    if (srcDimRank != dimRank)
      return emitError(
          llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
                        "from the output tensor (rank={2}).",
                        i, srcDimRank, dimRank));
  }

  for (Dimension d = 0; d < dimRank; d++) {
    const Size dstSh = dstTp.getDimShape()[d];
    if (d == concatDim) {
      if (!ShapedType::isDynamic(dstSh)) {
        // If we reach here, then all inputs have static shapes. So we
        // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
        // to avoid redundant assertions in the loop.
        Size sumSz = 0;
        for (const auto src : getInputs())
          sumSz += getSparseTensorType(src).getDimShape()[d];
        // If all dimensions are statically known, the sum of all the input
        // dimensions should be equal to the output dimension.
        if (sumSz != dstSh)
          return emitError(
              "The concatenation dimension of the output tensor should be the "
              "sum of all the concatenation dimensions of the input tensors.");
      }
    } else {
      Size prev = dstSh;
      for (const auto src : getInputs()) {
        const auto sh = getSparseTensorType(src).getDimShape()[d];
        if (!ShapedType::isDynamic(prev) && sh != prev)
          return emitError("All dimensions (except for the concatenating one) "
                           "should be equal.");
        prev = sh;
      }
    }
  }

  return success();
}

void PushBackOp::build(OpBuilder &builder, OperationState &result,
                       Value curSize, Value inBuffer, Value value) {
  build(builder, result, curSize, inBuffer, value, Value());
}

LogicalResult PushBackOp::verify() {
  if (Value n = getN()) {
    std::optional<int64_t> nValue = getConstantIntValue(n);
    if (nValue && nValue.value() < 1)
      return emitOpError("n must be not less than 1");
  }
  return success();
}

LogicalResult CompressOp::verify() {
  const auto stt = getSparseTensorType(getTensor());
  if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
    return emitOpError("incorrect number of coordinates");
  return success();
}

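/// Builds a foreach op and, when a body builder is provided, populates its
/// body block, whose arguments are the `dimRank` coordinates, then the
/// element value, and finally the reduction (init) arguments, in that order.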
void ForeachOp::build(
    OpBuilder &builder, OperationState &result, Value tensor,
    ValueRange initArgs, AffineMapAttr order,
    function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
        bodyBuilder) {
  build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
  // Builds foreach body.
  if (!bodyBuilder)
    return;
  const auto stt = getSparseTensorType(tensor);
  const Dimension dimRank = stt.getDimRank();

  // Starts with `dimRank`-many coordinates.
  SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
  // Followed by one value.
  blockArgTypes.push_back(stt.getElementType());
  // Followed by the reduction variables.
  blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());

  SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());

  OpBuilder::InsertionGuard guard(builder);
  auto &region = *result.regions.front();
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuilder(builder, result.location,
              bodyBlock->getArguments().slice(0, dimRank),
              bodyBlock->getArguments()[dimRank],
              bodyBlock->getArguments().drop_front(dimRank + 1));
}

LogicalResult ForeachOp::verify() {
  const auto t = getSparseTensorType(getTensor());
  const Dimension dimRank = t.getDimRank();
  const auto args = getBody()->getArguments();

  if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
    return emitError("Level traverse order does not match tensor's level rank");

  if (dimRank + 1 + getInitArgs().size() != args.size())
    return emitError("Unmatched number of arguments in the block");

  if (getNumResults() != getInitArgs().size())
    return emitError("Mismatch in number of init arguments and results");

  if (getResultTypes() != getInitArgs().getTypes())
    return emitError("Mismatch in types of init arguments and results");

  // Cannot mark this const, because the getters aren't.
  auto yield = cast<YieldOp>(getBody()->getTerminator());
  if (yield.getNumOperands() != getNumResults() ||
      yield.getOperands().getTypes() != getResultTypes())
    return emitError("Mismatch in types of yield values and results");

  const auto iTp = IndexType::get(getContext());
  for (Dimension d = 0; d < dimRank; d++)
    if (args[d].getType() != iTp)
      return emitError(
          llvm::formatv("Expecting Index type for argument at index {0}", d));

  const auto elemTp = t.getElementType();
  const auto valueTp = args[dimRank].getType();
  if (elemTp != valueTp)
    return emitError(
        llvm::formatv("Unmatched element type between input tensor and "
                      "block argument, expected:{0}, got: {1}",
                      elemTp, valueTp));
  return success();
}

OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
  if (getSparseTensorEncoding(getInputCoo().getType()) ==
      getSparseTensorEncoding(getResultCoo().getType()))
    return getInputCoo();

  return {};
}

LogicalResult ReorderCOOOp::verify() {
  SparseTensorType srcStt = getSparseTensorType(getInputCoo());
  SparseTensorType dstStt = getSparseTensorType(getResultCoo());

  if (!srcStt.isCOOType() || !dstStt.isCOOType())
    return emitError("Expected COO sparse tensors only");

  if (!srcStt.hasSameDimToLvl(dstStt))
    return emitError("Unmatched dim2lvl map between input and result COO");

  if (srcStt.getPosType() != dstStt.getPosType() ||
      srcStt.getCrdType() != dstStt.getCrdType() ||
      srcStt.getElementType() != dstStt.getElementType())
    return emitError("Unmatched storage format between input and result COO");

  return success();
}

LogicalResult ReduceOp::verify() {
  Type inputType = getX().getType();
  Region &formula = getRegion();
  return verifyNumBlockArgs(this, formula, "reduce",
                            TypeRange{inputType, inputType}, inputType);
}

LogicalResult SelectOp::verify() {
  Builder b(getContext());
  Type inputType = getX().getType();
  Type boolType = b.getI1Type();
  Region &formula = getRegion();
  return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
                            boolType);
}

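/// The xy buffer must provide at least n * (nx + ny) elements, where each of
/// the n entries contributes nx = rank(perm_map) key components and ny extra
/// components; each y buffer must provide at least n elements.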
LogicalResult SortOp::verify() {
  AffineMap xPerm = getPermMap();
  uint64_t nx = xPerm.getNumDims();
  if (nx < 1)
    return emitError(
        llvm::formatv("Expected rank(perm_map) >= 1, got {0}", nx));

  if (!xPerm.isPermutation())
    return emitError(
        llvm::formatv("Expected a permutation map, got {0}", xPerm));

  // We can't check the size of the buffers when n or buffer dimensions aren't
  // compile-time constants.
  std::optional<int64_t> cn = getConstantIntValue(getN());
  if (!cn)
    return success();

  // Verify dimensions.
  const auto checkDim = [&](Value v, Size minSize,
                            const char *message) -> LogicalResult {
    const Size sh = getMemRefType(v).getShape()[0];
    if (!ShapedType::isDynamic(sh) && sh < minSize)
      return emitError(
          llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
    return success();
  };
  uint64_t n = cn.value();
  uint64_t ny = 0;
  if (auto nyAttr = getNyAttr())
    ny = nyAttr.getInt();
  if (failed(checkDim(getXy(), n * (nx + ny),
                      "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
    return failure();
  for (Value opnd : getYs())
    if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
      return failure();

  return success();
}

//===----------------------------------------------------------------------===//
// Sparse Tensor Iteration Operations.
//===----------------------------------------------------------------------===//

IterSpaceType IteratorType::getIterSpaceType() const {
  return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
                            getHiLvl());
}

IteratorType IterSpaceType::getIteratorType() const {
  return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
}

/// Parses a level range in the form "$lo `to` $hi"
/// or simply "$lo" if $hi - $lo = 1.
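/// For example, "2 to 5" yields the range [2, 5), while a bare "2" is
/// shorthand for the single-level range [2, 3).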
static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
                                   Level &lvlHi) {
  if (parser.parseInteger(lvlLo))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("to"))) {
    if (parser.parseInteger(lvlHi))
      return failure();
  } else {
    lvlHi = lvlLo + 1;
  }

  if (lvlHi <= lvlLo)
    return parser.emitError(parser.getNameLoc(),
                            "expect larger level upper bound than lower bound");

  return success();
}

/// Parses a level range in the form "$lo `to` $hi"
/// or simply "$lo" if $hi - $lo = 1.
static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
                                   IntegerAttr &lvlHiAttr) {
  Level lvlLo, lvlHi;
  if (parseLevelRange(parser, lvlLo, lvlHi))
    return failure();

  lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
  lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
  return success();
}

/// Prints a level range in the form "$lo `to` $hi"
/// or simply "$lo" if $hi - $lo = 1.
static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
  if (lo + 1 == hi)
    p << lo;
  else
    p << lo << " to " << hi;
}

/// Prints a level range in the form "$lo `to` $hi"
/// or simply "$lo" if $hi - $lo = 1.
static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
                            IntegerAttr lvlHi) {
  unsigned lo = lvlLo.getValue().getZExtValue();
  unsigned hi = lvlHi.getValue().getZExtValue();
  printLevelRange(p, lo, hi);
}

LogicalResult ExtractIterSpaceOp::inferReturnTypes(
    MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
    DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
    SmallVectorImpl<mlir::Type> &ret) {
  ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
  SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
  ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
                                   adaptor.getHiLvl()));
  return success();
}

LogicalResult ExtractIterSpaceOp::verify() {
  if (getLoLvl() >= getHiLvl())
    return emitOpError("expected level lower bound to be smaller than level "
                       "upper bound");

  TypedValue<IteratorType> pIter = getParentIter();
  if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
    return emitOpError(
        "parent iterator should be specified iff level lower bound equals 0");
  }

  if (pIter) {
    IterSpaceType spaceTp = getResultSpace().getType();
    if (pIter.getType().getEncoding() != spaceTp.getEncoding())
      return emitOpError(
          "mismatch in parent iterator encoding and iteration space encoding.");

    if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
      return emitOpError("parent iterator should be used to extract an "
                         "iteration space from a consecutive level.");
  }

  return success();
}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
                                                    Attribute value, Type type,
                                                    Location loc) {
  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
    return op;
  return nullptr;
}

namespace {
struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface {
  using OpAsmDialectInterface::OpAsmDialectInterface;

  AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
    if (isa<SparseTensorEncodingAttr>(attr)) {
      os << "sparse";
      return AliasResult::OverridableAlias;
    }
    return AliasResult::NoAlias;
  }
};
} // namespace

void SparseTensorDialect::initialize() {
  addInterface<SparseTensorAsmDialectInterface>();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
  declarePromisedInterfaces<
      bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
      NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
      ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"

#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
