1 | //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include <utility> |
10 | |
11 | #include "Detail/DimLvlMapParser.h" |
12 | |
13 | #include "mlir/Dialect/SparseTensor/IR/Enums.h" |
14 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
15 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h" |
16 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" |
17 | |
18 | #include "mlir/Dialect/Arith/IR/Arith.h" |
19 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
20 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
21 | #include "mlir/IR/Builders.h" |
22 | #include "mlir/IR/DialectImplementation.h" |
23 | #include "mlir/IR/Matchers.h" |
24 | #include "mlir/IR/OpImplementation.h" |
25 | #include "mlir/IR/PatternMatch.h" |
26 | #include "llvm/ADT/TypeSwitch.h" |
27 | #include "llvm/Support/FormatVariadic.h" |
28 | |
29 | #define GET_ATTRDEF_CLASSES |
30 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" |
31 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc" |
32 | |
33 | // Forward declarations, following custom print/parsing methods are referenced |
34 | // by the generated code for SparseTensorTypes.td. |
35 | static mlir::ParseResult parseLevelRange(mlir::AsmParser &, |
36 | mlir::sparse_tensor::Level &, |
37 | mlir::sparse_tensor::Level &); |
38 | static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level, |
39 | mlir::sparse_tensor::Level); |
40 | |
41 | #define GET_TYPEDEF_CLASSES |
42 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc" |
43 | |
44 | using namespace mlir; |
45 | using namespace mlir::sparse_tensor; |
46 | |
47 | // Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as |
48 | // well. |
49 | namespace mlir::sparse_tensor { |
50 | llvm::hash_code hash_value(LevelType lt) { |
51 | return llvm::hash_value(static_cast<uint64_t>(lt)); |
52 | } |
53 | } // namespace mlir::sparse_tensor |
54 | |
55 | //===----------------------------------------------------------------------===// |
56 | // Local Convenience Methods. |
57 | //===----------------------------------------------------------------------===// |
58 | |
59 | static constexpr bool acceptBitWidth(unsigned bitWidth) { |
60 | switch (bitWidth) { |
61 | case 0: |
62 | case 8: |
63 | case 16: |
64 | case 32: |
65 | case 64: |
66 | return true; |
67 | default: |
68 | return false; |
69 | } |
70 | } |
71 | |
72 | static SmallVector<Size> |
73 | getSparseFieldShape(const SparseTensorEncodingAttr enc, |
74 | std::optional<ArrayRef<int64_t>> dimShape) { |
75 | assert(enc); |
76 | // With only encoding, we can not determine the static shape for leading |
77 | // batch levels, we therefore return a dynamic shape memref instead. |
78 | SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic); |
79 | if (dimShape.has_value()) { |
80 | // If the actual tensor shape is provided, we can then refine the leading |
81 | // batch dimension. |
82 | SmallVector<int64_t> lvlShape = |
83 | enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl); |
84 | memrefShape.assign(lvlShape.begin(), |
85 | lvlShape.begin() + enc.getBatchLvlRank()); |
86 | } |
87 | // Another dynamic dimension to store the sparse level. |
88 | memrefShape.push_back(ShapedType::kDynamic); |
89 | return memrefShape; |
90 | } |
91 | |
92 | //===----------------------------------------------------------------------===// |
93 | // SparseTensorDialect StorageLayout. |
94 | //===----------------------------------------------------------------------===// |
95 | |
96 | static constexpr Level kInvalidLevel = -1u; |
97 | static constexpr Level kInvalidFieldIndex = -1u; |
98 | static constexpr FieldIndex kDataFieldStartingIdx = 0; |
99 | |
100 | void StorageLayout::foreachField( |
101 | llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level, |
102 | LevelType)> |
103 | callback) const { |
104 | const auto lvlTypes = enc.getLvlTypes(); |
105 | const Level lvlRank = enc.getLvlRank(); |
106 | SmallVector<COOSegment> cooSegs = SparseTensorType(enc).getCOOSegments(); |
107 | FieldIndex fieldIdx = kDataFieldStartingIdx; |
108 | |
109 | ArrayRef cooSegsRef = cooSegs; |
110 | // Per-level storage. |
111 | for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) { |
112 | const auto lt = lvlTypes[l]; |
113 | if (isWithPosLT(lt)) { |
114 | if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt))) |
115 | return; |
116 | } |
117 | if (isWithCrdLT(lt)) { |
118 | if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt))) |
119 | return; |
120 | } |
121 | if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) { |
122 | if (!cooSegsRef.front().isSoA) { |
123 | // AoS COO, all singletons are fused into one memrefs. Skips the entire |
124 | // COO segement. |
125 | l = cooSegsRef.front().lvlRange.second; |
126 | } else { |
127 | // SoA COO, each singleton level has one memref. |
128 | l++; |
129 | } |
130 | // Expire handled COO segment. |
131 | cooSegsRef = cooSegsRef.drop_front(); |
132 | } else { |
133 | // Non COO levels. |
134 | l++; |
135 | } |
136 | } |
137 | // The values array. |
138 | if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel, |
139 | LevelFormat::Undef))) |
140 | return; |
141 | // Put metadata at the end. |
142 | if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel, |
143 | LevelFormat::Undef))) |
144 | return; |
145 | } |
146 | |
147 | void sparse_tensor::foreachFieldAndTypeInSparseTensor( |
148 | SparseTensorType stt, |
149 | llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level, |
150 | LevelType)> |
151 | callback) { |
152 | assert(stt.hasEncoding()); |
153 | |
154 | SmallVector<int64_t> memrefShape = |
155 | getSparseFieldShape(stt.getEncoding(), stt.getDimShape()); |
156 | |
157 | const Type specType = StorageSpecifierType::get(stt.getEncoding()); |
158 | // memref<[batch] x ? x pos> positions |
159 | const Type posMemType = MemRefType::get(memrefShape, stt.getPosType()); |
160 | // memref<[batch] x ? x crd> coordinates |
161 | const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType()); |
162 | // memref<[batch] x ? x eltType> values |
163 | const Type valMemType = MemRefType::get(memrefShape, stt.getElementType()); |
164 | |
165 | StorageLayout(stt).foreachField(callback: [specType, posMemType, crdMemType, valMemType, |
166 | callback](FieldIndex fieldIdx, |
167 | SparseTensorFieldKind fieldKind, |
168 | Level lvl, LevelType lt) -> bool { |
169 | switch (fieldKind) { |
170 | case SparseTensorFieldKind::StorageSpec: |
171 | return callback(specType, fieldIdx, fieldKind, lvl, lt); |
172 | case SparseTensorFieldKind::PosMemRef: |
173 | return callback(posMemType, fieldIdx, fieldKind, lvl, lt); |
174 | case SparseTensorFieldKind::CrdMemRef: |
175 | return callback(crdMemType, fieldIdx, fieldKind, lvl, lt); |
176 | case SparseTensorFieldKind::ValMemRef: |
177 | return callback(valMemType, fieldIdx, fieldKind, lvl, lt); |
178 | }; |
179 | llvm_unreachable("unrecognized field kind" ); |
180 | }); |
181 | } |
182 | |
183 | unsigned StorageLayout::getNumFields() const { |
184 | unsigned numFields = 0; |
185 | foreachField(callback: [&numFields](FieldIndex, SparseTensorFieldKind, Level, |
186 | LevelType) -> bool { |
187 | numFields++; |
188 | return true; |
189 | }); |
190 | return numFields; |
191 | } |
192 | |
193 | unsigned StorageLayout::getNumDataFields() const { |
194 | unsigned numFields = 0; // one value memref |
195 | foreachField(callback: [&numFields](FieldIndex fidx, SparseTensorFieldKind, Level, |
196 | LevelType) -> bool { |
197 | if (fidx >= kDataFieldStartingIdx) |
198 | numFields++; |
199 | return true; |
200 | }); |
201 | numFields -= 1; // the last field is StorageSpecifier |
202 | assert(numFields == getNumFields() - kDataFieldStartingIdx - 1); |
203 | return numFields; |
204 | } |
205 | |
206 | std::pair<FieldIndex, unsigned> |
207 | StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind, |
208 | std::optional<Level> lvl) const { |
209 | FieldIndex fieldIdx = kInvalidFieldIndex; |
210 | unsigned stride = 1; |
211 | if (kind == SparseTensorFieldKind::CrdMemRef) { |
212 | assert(lvl.has_value()); |
213 | const Level cooStart = SparseTensorType(enc).getAoSCOOStart(); |
214 | const Level lvlRank = enc.getLvlRank(); |
215 | if (lvl.value() >= cooStart && lvl.value() < lvlRank) { |
216 | lvl = cooStart; |
217 | stride = lvlRank - cooStart; |
218 | } |
219 | } |
220 | foreachField(callback: [lvl, kind, &fieldIdx](FieldIndex fIdx, |
221 | SparseTensorFieldKind fKind, Level fLvl, |
222 | LevelType lt) -> bool { |
223 | if ((lvl && fLvl == lvl.value() && kind == fKind) || |
224 | (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) { |
225 | fieldIdx = fIdx; |
226 | // Returns false to break the iteration. |
227 | return false; |
228 | } |
229 | return true; |
230 | }); |
231 | assert(fieldIdx != kInvalidFieldIndex); |
232 | return std::pair<FieldIndex, unsigned>(fieldIdx, stride); |
233 | } |
234 | |
235 | //===----------------------------------------------------------------------===// |
236 | // SparseTensorDialect Attribute Methods. |
237 | //===----------------------------------------------------------------------===// |
238 | |
239 | std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) { |
240 | return isDynamic(v) ? std::nullopt |
241 | : std::make_optional(static_cast<uint64_t>(v)); |
242 | } |
243 | |
244 | std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const { |
245 | return getStatic(getOffset()); |
246 | } |
247 | |
248 | std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const { |
249 | return getStatic(getStride()); |
250 | } |
251 | |
252 | std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const { |
253 | return getStatic(getSize()); |
254 | } |
255 | |
256 | bool SparseTensorDimSliceAttr::isCompletelyDynamic() const { |
257 | return isDynamic(getOffset()) && isDynamic(getStride()) && |
258 | isDynamic(getSize()); |
259 | } |
260 | |
261 | std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) { |
262 | return isDynamic(v) ? "?" : std::to_string(v); |
263 | } |
264 | |
265 | void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const { |
266 | assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr" ); |
267 | os << '('; |
268 | os << getStaticString(getOffset()); |
269 | os << ", " ; |
270 | os << getStaticString(getSize()); |
271 | os << ", " ; |
272 | os << getStaticString(getStride()); |
273 | os << ')'; |
274 | } |
275 | |
276 | void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const { |
277 | print(printer.getStream()); |
278 | } |
279 | |
280 | static ParseResult parseOptionalStaticSlice(int64_t &result, |
281 | AsmParser &parser) { |
282 | auto parseResult = parser.parseOptionalInteger(result); |
283 | if (parseResult.has_value()) { |
284 | if (parseResult.value().succeeded() && result < 0) { |
285 | parser.emitError( |
286 | loc: parser.getCurrentLocation(), |
287 | message: "expect positive value or ? for slice offset/size/stride" ); |
288 | return failure(); |
289 | } |
290 | return parseResult.value(); |
291 | } |
292 | |
293 | // Else, and '?' which represented dynamic slice |
294 | result = SparseTensorDimSliceAttr::kDynamic; |
295 | return parser.parseQuestion(); |
296 | } |
297 | |
298 | Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) { |
299 | int64_t offset = kDynamic, size = kDynamic, stride = kDynamic; |
300 | |
301 | if (failed(parser.parseLParen()) || |
302 | failed(parseOptionalStaticSlice(offset, parser)) || |
303 | failed(parser.parseComma()) || |
304 | failed(parseOptionalStaticSlice(size, parser)) || |
305 | failed(parser.parseComma()) || |
306 | failed(parseOptionalStaticSlice(stride, parser)) || |
307 | failed(parser.parseRParen())) |
308 | return {}; |
309 | |
310 | return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(), |
311 | offset, size, stride); |
312 | } |
313 | |
314 | LogicalResult |
315 | SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError, |
316 | int64_t offset, int64_t size, int64_t stride) { |
317 | if (!isDynamic(offset) && offset < 0) |
318 | return emitError() << "expect non-negative value or ? for slice offset" ; |
319 | if (!isDynamic(size) && size <= 0) |
320 | return emitError() << "expect positive value or ? for slice size" ; |
321 | if (!isDynamic(stride) && stride <= 0) |
322 | return emitError() << "expect positive value or ? for slice stride" ; |
323 | return success(); |
324 | } |
325 | |
326 | SparseTensorEncodingAttr |
327 | SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const { |
328 | assert(getImpl() && "Uninitialized SparseTensorEncodingAttr" ); |
329 | return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), dimToLvl, |
330 | AffineMap(), getPosWidth(), |
331 | getCrdWidth()); |
332 | } |
333 | |
334 | SparseTensorEncodingAttr |
335 | SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const { |
336 | return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap()); |
337 | } |
338 | |
339 | SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const { |
340 | return withDimToLvl(AffineMap()); |
341 | } |
342 | |
343 | SparseTensorEncodingAttr |
344 | SparseTensorEncodingAttr::withBitWidths(unsigned posWidth, |
345 | unsigned crdWidth) const { |
346 | assert(getImpl() && "Uninitialized SparseTensorEncodingAttr" ); |
347 | return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), |
348 | getDimToLvl(), getLvlToDim(), posWidth, |
349 | crdWidth); |
350 | } |
351 | |
352 | SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const { |
353 | return withBitWidths(0, 0); |
354 | } |
355 | |
356 | SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices( |
357 | ArrayRef<SparseTensorDimSliceAttr> dimSlices) const { |
358 | return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), |
359 | getDimToLvl(), getLvlToDim(), |
360 | getPosWidth(), getCrdWidth(), dimSlices); |
361 | } |
362 | |
363 | SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const { |
364 | return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{}); |
365 | } |
366 | |
367 | uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const { |
368 | ArrayRef<LevelType> lvlTypes = getLvlTypes(); |
369 | auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT); |
370 | return std::distance(lastBatch, lvlTypes.rend()); |
371 | } |
372 | |
373 | bool SparseTensorEncodingAttr::isAllDense() const { |
374 | return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT); |
375 | } |
376 | |
377 | bool SparseTensorEncodingAttr::isAllOrdered() const { |
378 | return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT); |
379 | } |
380 | |
381 | Type SparseTensorEncodingAttr::getCrdElemType() const { |
382 | if (!getImpl()) |
383 | return nullptr; |
384 | if (getCrdWidth()) |
385 | return IntegerType::get(getContext(), getCrdWidth()); |
386 | return IndexType::get(getContext()); |
387 | } |
388 | |
389 | Type SparseTensorEncodingAttr::getPosElemType() const { |
390 | if (!getImpl()) |
391 | return nullptr; |
392 | if (getPosWidth()) |
393 | return IntegerType::get(getContext(), getPosWidth()); |
394 | return IndexType::get(getContext()); |
395 | } |
396 | |
397 | MemRefType SparseTensorEncodingAttr::getCrdMemRefType( |
398 | std::optional<ArrayRef<int64_t>> dimShape) const { |
399 | SmallVector<Size> shape = getSparseFieldShape(*this, dimShape); |
400 | return MemRefType::get(shape, getCrdElemType()); |
401 | } |
402 | |
403 | MemRefType SparseTensorEncodingAttr::getPosMemRefType( |
404 | std::optional<ArrayRef<int64_t>> dimShape) const { |
405 | SmallVector<Size> shape = getSparseFieldShape(*this, dimShape); |
406 | return MemRefType::get(shape, getPosElemType()); |
407 | } |
408 | |
409 | bool SparseTensorEncodingAttr::isIdentity() const { |
410 | return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity(); |
411 | } |
412 | |
413 | bool SparseTensorEncodingAttr::isPermutation() const { |
414 | return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation(); |
415 | } |
416 | |
417 | Dimension SparseTensorEncodingAttr::getDimRank() const { |
418 | assert(getImpl() && "Uninitialized SparseTensorEncodingAttr" ); |
419 | const auto dimToLvl = getDimToLvl(); |
420 | return dimToLvl ? dimToLvl.getNumDims() : getLvlRank(); |
421 | } |
422 | |
423 | Level SparseTensorEncodingAttr::getLvlRank() const { |
424 | assert(getImpl() && "Uninitialized SparseTensorEncodingAttr" ); |
425 | return getLvlTypes().size(); |
426 | } |
427 | |
428 | LevelType SparseTensorEncodingAttr::getLvlType(Level l) const { |
429 | if (!getImpl()) |
430 | return LevelFormat::Batch; |
431 | assert(l < getLvlRank() && "Level is out of bounds" ); |
432 | return getLvlTypes()[l]; |
433 | } |
434 | |
435 | bool SparseTensorEncodingAttr::isSlice() const { |
436 | assert(getImpl() && "Uninitialized SparseTensorEncodingAttr" ); |
437 | return !getDimSlices().empty(); |
438 | } |
439 | |
440 | SparseTensorDimSliceAttr |
441 | SparseTensorEncodingAttr::getDimSlice(Dimension dim) const { |
442 | assert(isSlice() && "Is not a slice" ); |
443 | const auto dimSlices = getDimSlices(); |
444 | assert(dim < dimSlices.size() && "Dimension is out of bounds" ); |
445 | return dimSlices[dim]; |
446 | } |
447 | |
448 | std::optional<uint64_t> |
449 | SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const { |
450 | return getDimSlice(dim).getStaticOffset(); |
451 | } |
452 | |
453 | std::optional<uint64_t> |
454 | SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const { |
455 | return getDimSlice(dim).getStaticStride(); |
456 | } |
457 | |
458 | std::optional<uint64_t> |
459 | SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const { |
460 | return getStaticDimSliceOffset(toDim(*this, lvl)); |
461 | } |
462 | |
463 | std::optional<uint64_t> |
464 | SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const { |
465 | return getStaticDimSliceStride(toDim(*this, lvl)); |
466 | } |
467 | |
468 | SmallVector<int64_t> |
469 | SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape, |
470 | CrdTransDirectionKind dir) const { |
471 | if (isIdentity()) |
472 | return SmallVector<int64_t>(srcShape); |
473 | |
474 | SmallVector<int64_t> ret; |
475 | unsigned rank = |
476 | dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank(); |
477 | ret.reserve(rank); |
478 | |
479 | if (isPermutation()) { |
480 | for (unsigned r = 0; r < rank; r++) { |
481 | unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r) |
482 | : toLvl(*this, r); |
483 | ret.push_back(srcShape[trans]); |
484 | } |
485 | return ret; |
486 | } |
487 | |
488 | // Handle non-permutation maps. |
489 | AffineMap transMap = |
490 | dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim(); |
491 | |
492 | SmallVector<AffineExpr> dimRep; |
493 | dimRep.reserve(srcShape.size()); |
494 | for (int64_t sz : srcShape) { |
495 | if (!ShapedType::isDynamic(sz)) { |
496 | // Push back the max coordinate for the given dimension/level size. |
497 | dimRep.push_back(getAffineConstantExpr(sz - 1, getContext())); |
498 | } else { |
499 | // A dynamic size, use a AffineDimExpr to symbolize the value. |
500 | dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext())); |
501 | } |
502 | }; |
503 | |
504 | for (AffineExpr exp : transMap.getResults()) { |
505 | // Do constant propagation on the affine map. |
506 | AffineExpr evalExp = |
507 | simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0); |
508 | // use llvm namespace here to avoid ambiguity |
509 | if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) { |
510 | ret.push_back(c.getValue() + 1); |
511 | } else { |
512 | if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp); |
513 | mod && mod.getKind() == AffineExprKind::Mod) { |
514 | // We can still infer a static bound for expressions in form |
515 | // "d % constant" since d % constant \in [0, constant). |
516 | if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) { |
517 | ret.push_back(bound.getValue()); |
518 | continue; |
519 | } |
520 | } |
521 | ret.push_back(ShapedType::kDynamic); |
522 | } |
523 | } |
524 | assert(ret.size() == rank); |
525 | return ret; |
526 | } |
527 | |
528 | ValueRange |
529 | SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc, |
530 | ValueRange crds, |
531 | CrdTransDirectionKind dir) const { |
532 | if (!getImpl()) |
533 | return crds; |
534 | |
535 | SmallVector<Type> retType( |
536 | dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(), |
537 | builder.getIndexType()); |
538 | auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this); |
539 | return transOp.getOutCrds(); |
540 | } |
541 | |
542 | Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { |
543 | // Open "<{" part. |
544 | if (failed(parser.parseLess())) |
545 | return {}; |
546 | if (failed(parser.parseLBrace())) |
547 | return {}; |
548 | |
549 | // Process the data from the parsed dictionary value into struct-like data. |
550 | SmallVector<LevelType> lvlTypes; |
551 | SmallVector<SparseTensorDimSliceAttr> dimSlices; |
552 | AffineMap dimToLvl = {}; |
553 | AffineMap lvlToDim = {}; |
554 | unsigned posWidth = 0; |
555 | unsigned crdWidth = 0; |
556 | StringRef attrName; |
557 | SmallVector<StringRef, 3> keys = {"map" , "posWidth" , "crdWidth" }; |
558 | while (succeeded(parser.parseOptionalKeyword(&attrName))) { |
559 | // Detect admissible keyword. |
560 | auto *it = find(keys, attrName); |
561 | if (it == keys.end()) { |
562 | parser.emitError(parser.getNameLoc(), "unexpected key: " ) << attrName; |
563 | return {}; |
564 | } |
565 | unsigned keyWordIndex = it - keys.begin(); |
566 | // Consume the `=` after keys |
567 | if (failed(parser.parseEqual())) |
568 | return {}; |
569 | // Dispatch on keyword. |
570 | switch (keyWordIndex) { |
571 | case 0: { // map |
572 | ir_detail::DimLvlMapParser cParser(parser); |
573 | auto res = cParser.parseDimLvlMap(); |
574 | if (failed(res)) |
575 | return {}; |
576 | const auto &dlm = *res; |
577 | |
578 | const Level lvlRank = dlm.getLvlRank(); |
579 | for (Level lvl = 0; lvl < lvlRank; lvl++) |
580 | lvlTypes.push_back(dlm.getLvlType(lvl)); |
581 | |
582 | const Dimension dimRank = dlm.getDimRank(); |
583 | for (Dimension dim = 0; dim < dimRank; dim++) |
584 | dimSlices.push_back(dlm.getDimSlice(dim)); |
585 | // NOTE: the old syntax requires an all-or-nothing approach to |
586 | // `dimSlices`; therefore, if any slice actually exists then we need |
587 | // to convert null-DSA into default/nop DSA. |
588 | const auto isDefined = [](SparseTensorDimSliceAttr slice) { |
589 | return static_cast<bool>(slice.getImpl()); |
590 | }; |
591 | if (llvm::any_of(dimSlices, isDefined)) { |
592 | const auto defaultSlice = |
593 | SparseTensorDimSliceAttr::get(parser.getContext()); |
594 | for (Dimension dim = 0; dim < dimRank; dim++) |
595 | if (!isDefined(dimSlices[dim])) |
596 | dimSlices[dim] = defaultSlice; |
597 | } else { |
598 | dimSlices.clear(); |
599 | } |
600 | |
601 | dimToLvl = dlm.getDimToLvlMap(parser.getContext()); |
602 | lvlToDim = dlm.getLvlToDimMap(parser.getContext()); |
603 | break; |
604 | } |
605 | case 1: { // posWidth |
606 | Attribute attr; |
607 | if (failed(parser.parseAttribute(attr))) |
608 | return {}; |
609 | auto intAttr = llvm::dyn_cast<IntegerAttr>(attr); |
610 | if (!intAttr) { |
611 | parser.emitError(parser.getNameLoc(), |
612 | "expected an integral position bitwidth" ); |
613 | return {}; |
614 | } |
615 | posWidth = intAttr.getInt(); |
616 | break; |
617 | } |
618 | case 2: { // crdWidth |
619 | Attribute attr; |
620 | if (failed(parser.parseAttribute(attr))) |
621 | return {}; |
622 | auto intAttr = llvm::dyn_cast<IntegerAttr>(attr); |
623 | if (!intAttr) { |
624 | parser.emitError(parser.getNameLoc(), |
625 | "expected an integral index bitwidth" ); |
626 | return {}; |
627 | } |
628 | crdWidth = intAttr.getInt(); |
629 | break; |
630 | } |
631 | } // switch |
632 | // Only last item can omit the comma. |
633 | if (parser.parseOptionalComma().failed()) |
634 | break; |
635 | } |
636 | |
637 | // Close "}>" part. |
638 | if (failed(parser.parseRBrace())) |
639 | return {}; |
640 | if (failed(parser.parseGreater())) |
641 | return {}; |
642 | |
643 | // Construct struct-like storage for attribute. |
644 | if (!lvlToDim || lvlToDim.isEmpty()) { |
645 | lvlToDim = inferLvlToDim(dimToLvl, parser.getContext()); |
646 | } |
647 | return parser.getChecked<SparseTensorEncodingAttr>( |
648 | parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth, |
649 | dimSlices); |
650 | } |
651 | |
652 | void SparseTensorEncodingAttr::print(AsmPrinter &printer) const { |
653 | auto map = static_cast<AffineMap>(getDimToLvl()); |
654 | // Empty affine map indicates identity map |
655 | if (!map) |
656 | map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext()); |
657 | printer << "<{ map = " ; |
658 | printSymbols(map, printer); |
659 | printer << '('; |
660 | printDimensions(map, printer, getDimSlices()); |
661 | printer << ") -> (" ; |
662 | printLevels(map, printer, getLvlTypes()); |
663 | printer << ')'; |
664 | // Print remaining members only for non-default values. |
665 | if (getPosWidth()) |
666 | printer << ", posWidth = " << getPosWidth(); |
667 | if (getCrdWidth()) |
668 | printer << ", crdWidth = " << getCrdWidth(); |
669 | printer << " }>" ; |
670 | } |
671 | |
672 | void SparseTensorEncodingAttr::printSymbols(AffineMap &map, |
673 | AsmPrinter &printer) const { |
674 | if (map.getNumSymbols() == 0) |
675 | return; |
676 | printer << '['; |
677 | for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++) |
678 | printer << 's' << i << ", " ; |
679 | if (map.getNumSymbols() >= 1) |
680 | printer << 's' << map.getNumSymbols() - 1; |
681 | printer << ']'; |
682 | } |
683 | |
684 | void SparseTensorEncodingAttr::printDimensions( |
685 | AffineMap &map, AsmPrinter &printer, |
686 | ArrayRef<SparseTensorDimSliceAttr> dimSlices) const { |
687 | if (!dimSlices.empty()) { |
688 | for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++) |
689 | printer << 'd' << i << " : " << dimSlices[i] << ", " ; |
690 | if (map.getNumDims() >= 1) { |
691 | printer << 'd' << map.getNumDims() - 1 << " : " |
692 | << dimSlices[map.getNumDims() - 1]; |
693 | } |
694 | } else { |
695 | for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++) |
696 | printer << 'd' << i << ", " ; |
697 | if (map.getNumDims() >= 1) |
698 | printer << 'd' << map.getNumDims() - 1; |
699 | } |
700 | } |
701 | |
702 | void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer, |
703 | ArrayRef<LevelType> lvlTypes) const { |
704 | for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) { |
705 | map.getResult(i).print(printer.getStream()); |
706 | printer << " : " << toMLIRString(lvlTypes[i]) << ", " ; |
707 | } |
708 | if (map.getNumResults() >= 1) { |
709 | auto lastIndex = map.getNumResults() - 1; |
710 | map.getResult(lastIndex).print(printer.getStream()); |
711 | printer << " : " << toMLIRString(lvlTypes[lastIndex]); |
712 | } |
713 | } |
714 | |
715 | LogicalResult SparseTensorEncodingAttr::verify( |
716 | function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes, |
717 | AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth, |
718 | unsigned crdWidth, ArrayRef<SparseTensorDimSliceAttr> dimSlices) { |
719 | if (!acceptBitWidth(posWidth)) |
720 | return emitError() << "unexpected position bitwidth: " << posWidth; |
721 | if (!acceptBitWidth(crdWidth)) |
722 | return emitError() << "unexpected coordinate bitwidth: " << crdWidth; |
723 | if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT); |
724 | it != std::end(lvlTypes)) { |
725 | if (it == lvlTypes.begin() || |
726 | (!isCompressedLT(*(it - 1)) && !isLooseCompressedLT(*(it - 1)))) |
727 | return emitError() << "expected compressed or loose_compressed level " |
728 | "before singleton level" ; |
729 | if (!std::all_of(it, lvlTypes.end(), |
730 | [](LevelType i) { return isSingletonLT(i); })) |
731 | return emitError() << "expected all singleton lvlTypes " |
732 | "following a singleton level" ; |
733 | // We can potentially support mixed SoA/AoS singleton levels. |
734 | if (!std::all_of(it, lvlTypes.end(), [it](LevelType i) { |
735 | return it->isa<LevelPropNonDefault::SoA>() == |
736 | i.isa<LevelPropNonDefault::SoA>(); |
737 | })) { |
738 | return emitError() << "expected all singleton lvlTypes stored in the " |
739 | "same memory layout (SoA vs AoS)." ; |
740 | } |
741 | } |
742 | |
743 | auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT); |
744 | if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT)) |
745 | return emitError() << "Batch lvlType can only be leading levels." ; |
746 | |
747 | // SoA property can only be applied on singleton level. |
748 | auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) { |
749 | return lt.isa<LevelPropNonDefault::SoA>(); |
750 | }); |
751 | if (llvm::any_of(soaLvls, [](LevelType lt) { |
752 | return !lt.isa<LevelFormat::Singleton>(); |
753 | })) { |
754 | return emitError() << "SoA is only applicable to singleton lvlTypes." ; |
755 | } |
756 | |
757 | // TODO: audit formats that actually are supported by backend. |
758 | if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT); |
759 | it != std::end(lvlTypes)) { |
760 | if (it != lvlTypes.end() - 1) |
761 | return emitError() << "expected n_out_of_m to be the last level type" ; |
762 | if (!std::all_of(lvlTypes.begin(), it, |
763 | [](LevelType i) { return isDenseLT(i); })) |
764 | return emitError() << "expected all dense lvlTypes " |
765 | "before a n_out_of_m level" ; |
766 | if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) { |
767 | if (!isBlockSparsity(dimToLvl)) { |
768 | return emitError() |
769 | << "expected 1xm block structure for n_out_of_m level" ; |
770 | } |
771 | auto sizes = getBlockSize(dimToLvl); |
772 | unsigned coefficient = 0; |
773 | for (const auto &elem : sizes) { |
774 | if (elem != 0) { |
775 | if (elem != coefficient && coefficient != 0) { |
776 | return emitError() << "expected only one blocked level " |
777 | "with the same coefficients" ; |
778 | } |
779 | coefficient = elem; |
780 | } |
781 | } |
782 | if (coefficient != getM(*it)) { |
783 | return emitError() << "expected coeffiencts of Affine expressions " |
784 | "to be equal to m of n_out_of_m level" ; |
785 | } |
786 | } |
787 | } |
788 | // Before we can check that the level-rank is consistent/coherent |
789 | // across all fields, we need to define it. The source-of-truth for |
790 | // the `getLvlRank` method is the length of the level-types array, |
791 | // since it must always be provided and have full rank; therefore we |
792 | // use that same source-of-truth here. |
793 | const Level lvlRank = lvlTypes.size(); |
794 | if (lvlRank == 0) |
795 | return emitError() << "expected a non-empty array for lvlTypes" ; |
796 | // We save `dimRank` here because we'll also need it to verify `dimSlices`. |
797 | const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank; |
798 | if (dimToLvl) { |
799 | if (dimToLvl.getNumResults() != lvlRank) |
800 | return emitError() |
801 | << "level-rank mismatch between dimToLvl and lvlTypes: " |
802 | << dimToLvl.getNumResults() << " != " << lvlRank; |
803 | auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext()); |
804 | // Symbols can't be inferred but are acceptable. |
805 | if (!inferRes && dimToLvl.getNumSymbols() == 0) |
806 | return emitError() << "failed to infer lvlToDim from dimToLvl" ; |
807 | if (lvlToDim && (inferRes != lvlToDim)) |
808 | return emitError() << "expected lvlToDim to be an inverse of dimToLvl" ; |
809 | if (dimRank > lvlRank) |
810 | return emitError() << "unexpected dimToLvl mapping from " << dimRank |
811 | << " to " << lvlRank; |
812 | } |
813 | if (!dimSlices.empty()) { |
814 | if (dimSlices.size() != dimRank) |
815 | return emitError() |
816 | << "dimension-rank mismatch between dimSlices and dimToLvl: " |
817 | << dimSlices.size() << " != " << dimRank; |
818 | // Compiler support for `dimSlices` currently requires that the two |
819 | // ranks agree. (However, it does allow `dimToLvl` to be a permutation.) |
820 | if (dimRank != lvlRank) |
821 | return emitError() |
822 | << "dimSlices expected dimension-rank to match level-rank: " |
823 | << dimRank << " != " << lvlRank; |
824 | } |
825 | return success(); |
826 | } |
827 | |
828 | LogicalResult SparseTensorEncodingAttr::verifyEncoding( |
829 | ArrayRef<Size> dimShape, Type elementType, |
830 | function_ref<InFlightDiagnostic()> emitError) const { |
831 | // Check structural integrity. In particular, this ensures that the |
832 | // level-rank is coherent across all the fields. |
833 | if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(), |
834 | getPosWidth(), getCrdWidth(), getDimSlices()))) |
835 | return failure(); |
836 | // Check integrity with tensor type specifics. In particular, we |
837 | // need only check that the dimension-rank of the tensor agrees with |
838 | // the dimension-rank of the encoding. |
839 | const Dimension dimRank = dimShape.size(); |
840 | if (dimRank == 0) |
841 | return emitError() << "expected non-scalar sparse tensor" ; |
842 | if (getDimRank() != dimRank) |
843 | return emitError() |
844 | << "dimension-rank mismatch between encoding and tensor shape: " |
845 | << getDimRank() << " != " << dimRank; |
846 | return success(); |
847 | } |
848 | |
849 | //===----------------------------------------------------------------------===// |
850 | // SparseTensorType Methods. |
851 | //===----------------------------------------------------------------------===// |
852 | |
853 | bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl, |
854 | bool isUnique) const { |
855 | if (!hasEncoding()) |
856 | return false; |
857 | if (!isCompressedLvl(l: startLvl) && !isLooseCompressedLvl(l: startLvl)) |
858 | return false; |
859 | for (Level l = startLvl + 1; l < lvlRank; ++l) |
860 | if (!isSingletonLvl(l)) |
861 | return false; |
862 | // If isUnique is true, then make sure that the last level is unique, |
863 | // that is, when lvlRank == 1, the only compressed level is unique, |
864 | // and when lvlRank > 1, the last singleton is unique. |
865 | return !isUnique || isUniqueLvl(l: lvlRank - 1); |
866 | } |
867 | |
868 | Level mlir::sparse_tensor::SparseTensorType::getAoSCOOStart() const { |
869 | SmallVector<COOSegment> coo = getCOOSegments(); |
870 | assert(coo.size() == 1 || coo.empty()); |
871 | if (!coo.empty() && coo.front().isAoS()) { |
872 | return coo.front().lvlRange.first; |
873 | } |
874 | return lvlRank; |
875 | } |
876 | |
877 | SmallVector<COOSegment> |
878 | mlir::sparse_tensor::SparseTensorType::getCOOSegments() const { |
879 | SmallVector<COOSegment> ret; |
880 | if (!hasEncoding() || lvlRank <= 1) |
881 | return ret; |
882 | |
883 | ArrayRef<LevelType> lts = getLvlTypes(); |
884 | Level l = 0; |
885 | while (l < lvlRank) { |
886 | auto lt = lts[l]; |
887 | if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) { |
888 | auto cur = lts.begin() + l; |
889 | auto end = std::find_if(first: cur + 1, last: lts.end(), pred: [](LevelType lt) { |
890 | return !lt.isa<LevelFormat::Singleton>(); |
891 | }); |
892 | unsigned cooLen = std::distance(first: cur, last: end); |
893 | if (cooLen > 1) { |
894 | // To support mixed SoA/AoS COO, we should break the segment when the |
895 | // storage scheme changes, for now we faithfully assume that all |
896 | // consecutive singleton levels have the same storage format as verified |
897 | // STEA. |
898 | ret.push_back(Elt: COOSegment{.lvlRange: std::make_pair(x&: l, y: l + cooLen), |
899 | .isSoA: lts[l + 1].isa<LevelPropNonDefault::SoA>()}); |
900 | } |
901 | l += cooLen; |
902 | } else { |
903 | l++; |
904 | } |
905 | } |
906 | return ret; |
907 | } |
908 | |
909 | RankedTensorType |
910 | mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const { |
911 | SmallVector<LevelType> lvlTypes; |
912 | lvlTypes.reserve(N: lvlRank); |
913 | // A non-unique compressed level at beginning (unless this is |
914 | // also the last level, then it is unique). |
915 | lvlTypes.push_back( |
916 | Elt: *buildLevelType(lf: LevelFormat::Compressed, ordered, unique: lvlRank == 1)); |
917 | if (lvlRank > 1) { |
918 | // Followed by n-2 non-unique singleton levels. |
919 | std::fill_n(std::back_inserter(x&: lvlTypes), lvlRank - 2, |
920 | *buildLevelType(lf: LevelFormat::Singleton, ordered, unique: false)); |
921 | // Ends by a unique singleton level. |
922 | lvlTypes.push_back(Elt: *buildLevelType(lf: LevelFormat::Singleton, ordered, unique: true)); |
923 | } |
924 | auto enc = SparseTensorEncodingAttr::get(getContext(), lvlTypes, |
925 | getDimToLvl(), getLvlToDim(), |
926 | getPosWidth(), getCrdWidth()); |
927 | return RankedTensorType::get(getDimShape(), getElementType(), enc); |
928 | } |
929 | |
930 | //===----------------------------------------------------------------------===// |
931 | // Convenience Methods. |
932 | //===----------------------------------------------------------------------===// |
933 | |
934 | SparseTensorEncodingAttr |
935 | mlir::sparse_tensor::getSparseTensorEncoding(Type type) { |
936 | if (auto ttp = llvm::dyn_cast<RankedTensorType>(type)) |
937 | return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding()); |
938 | if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type)) |
939 | return mdtp.getEncoding(); |
940 | return nullptr; |
941 | } |
942 | |
943 | AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl, |
944 | MLIRContext *context) { |
945 | auto map = static_cast<AffineMap>(dimToLvl); |
946 | AffineMap lvlToDim; |
947 | // Return an empty lvlToDim when inference is not successful. |
948 | if (!map || map.getNumSymbols() != 0) { |
949 | lvlToDim = AffineMap(); |
950 | } else if (map.isPermutation()) { |
951 | lvlToDim = inversePermutation(map); |
952 | } else if (isBlockSparsity(dimToLvl: map)) { |
953 | lvlToDim = inverseBlockSparsity(dimToLvl: map, context); |
954 | } |
955 | return lvlToDim; |
956 | } |
957 | |
958 | AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl, |
959 | MLIRContext *context) { |
960 | SmallVector<AffineExpr> lvlExprs; |
961 | auto numLvls = dimToLvl.getNumResults(); |
962 | lvlExprs.reserve(N: numLvls); |
963 | // lvlExprComponents stores information of the floordiv and mod operations |
964 | // applied to the same dimension, so as to build the lvlToDim map. |
965 | std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents; |
966 | for (unsigned i = 0, n = numLvls; i < n; i++) { |
967 | auto result = dimToLvl.getResult(idx: i); |
968 | if (auto binOp = dyn_cast<AffineBinaryOpExpr>(Val&: result)) { |
969 | if (result.getKind() == AffineExprKind::FloorDiv) { |
970 | // Position of the dimension in dimToLvl. |
971 | auto pos = dyn_cast<AffineDimExpr>(Val: binOp.getLHS()).getPosition(); |
972 | assert(lvlExprComponents.find(pos) == lvlExprComponents.end() && |
973 | "expected only one floordiv for each dimension" ); |
974 | SmallVector<AffineExpr, 3> components; |
975 | // Level variable for floordiv. |
976 | components.push_back(Elt: getAffineDimExpr(position: i, context)); |
977 | // Multiplier. |
978 | components.push_back(Elt: binOp.getRHS()); |
979 | // Map key is the position of the dimension. |
980 | lvlExprComponents[pos] = components; |
981 | } else if (result.getKind() == AffineExprKind::Mod) { |
982 | auto pos = dyn_cast<AffineDimExpr>(Val: binOp.getLHS()).getPosition(); |
983 | assert(lvlExprComponents.find(pos) != lvlExprComponents.end() && |
984 | "expected floordiv before mod" ); |
985 | // Add level variable for mod to the same vector |
986 | // of the corresponding floordiv. |
987 | lvlExprComponents[pos].push_back(Elt: getAffineDimExpr(position: i, context)); |
988 | } else { |
989 | assert(false && "expected floordiv or mod" ); |
990 | } |
991 | } else { |
992 | lvlExprs.push_back(Elt: getAffineDimExpr(position: i, context)); |
993 | } |
994 | } |
995 | // Build lvlExprs from lvlExprComponents. |
996 | // For example, for il = i floordiv 2 and ii = i mod 2, the components |
997 | // would be [il, 2, ii]. It could be used to build the AffineExpr |
998 | // i = il * 2 + ii in lvlToDim. |
999 | for (auto &components : lvlExprComponents) { |
1000 | assert(components.second.size() == 3 && |
1001 | "expected 3 components to build lvlExprs" ); |
1002 | auto mulOp = getAffineBinaryOpExpr( |
1003 | kind: AffineExprKind::Mul, lhs: components.second[0], rhs: components.second[1]); |
1004 | auto addOp = |
1005 | getAffineBinaryOpExpr(kind: AffineExprKind::Add, lhs: mulOp, rhs: components.second[2]); |
1006 | lvlExprs.push_back(Elt: addOp); |
1007 | } |
1008 | return dimToLvl.get(dimCount: dimToLvl.getNumResults(), symbolCount: 0, results: lvlExprs, context); |
1009 | } |
1010 | |
1011 | SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) { |
1012 | assert(isBlockSparsity(dimToLvl) && |
1013 | "expected dimToLvl to be block sparsity for calling getBlockSize" ); |
1014 | SmallVector<unsigned> blockSize; |
1015 | for (auto result : dimToLvl.getResults()) { |
1016 | if (auto binOp = dyn_cast<AffineBinaryOpExpr>(Val&: result)) { |
1017 | if (result.getKind() == AffineExprKind::Mod) { |
1018 | blockSize.push_back( |
1019 | Elt: dyn_cast<AffineConstantExpr>(Val: binOp.getRHS()).getValue()); |
1020 | } |
1021 | } else { |
1022 | blockSize.push_back(Elt: 0); |
1023 | } |
1024 | } |
1025 | return blockSize; |
1026 | } |
1027 | |
1028 | bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) { |
1029 | if (!dimToLvl) |
1030 | return false; |
1031 | std::map<unsigned, int64_t> coeffientMap; |
1032 | bool hasBlock = false; |
1033 | for (auto result : dimToLvl.getResults()) { |
1034 | if (auto binOp = dyn_cast<AffineBinaryOpExpr>(Val&: result)) { |
1035 | // Check for "dim op const". |
1036 | auto dimOp = dyn_cast<AffineDimExpr>(Val: binOp.getLHS()); |
1037 | auto conOp = dyn_cast<AffineConstantExpr>(Val: binOp.getRHS()); |
1038 | if (!dimOp || !conOp || conOp.getValue() <= 0) |
1039 | return false; |
1040 | // Inspect "dim / const" or "dim % const". |
1041 | auto pos = dimOp.getPosition(); |
1042 | if (binOp.getKind() == AffineExprKind::FloorDiv) { |
1043 | // Expect only one floordiv for each dimension. |
1044 | if (coeffientMap.find(x: pos) != coeffientMap.end()) |
1045 | return false; |
1046 | // Record coefficient of the floordiv. |
1047 | coeffientMap[pos] = conOp.getValue(); |
1048 | } else if (binOp.getKind() == AffineExprKind::Mod) { |
1049 | // Expect floordiv before mod. |
1050 | if (coeffientMap.find(x: pos) == coeffientMap.end()) |
1051 | return false; |
1052 | // Expect mod to have the same coefficient as floordiv. |
1053 | if (conOp.getValue() != coeffientMap[pos]) |
1054 | return false; |
1055 | hasBlock = true; |
1056 | } else { |
1057 | return false; |
1058 | } |
1059 | } else if (auto dimOp = dyn_cast<AffineDimExpr>(Val&: result)) { |
1060 | auto pos = dimOp.getPosition(); |
1061 | // Expect dim to be unset. |
1062 | if (coeffientMap.find(x: pos) != coeffientMap.end()) |
1063 | return false; |
1064 | coeffientMap[pos] = 0; |
1065 | } else { |
1066 | return false; |
1067 | } |
1068 | } |
1069 | return hasBlock; |
1070 | } |
1071 | |
1072 | bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) { |
1073 | auto hasNonIdentityMap = [](Value v) { |
1074 | auto stt = tryGetSparseTensorType(v); |
1075 | return stt && !stt->isIdentity(); |
1076 | }; |
1077 | |
1078 | return llvm::any_of(Range: op->getOperands(), P: hasNonIdentityMap) || |
1079 | llvm::any_of(Range: op->getResults(), P: hasNonIdentityMap); |
1080 | } |
1081 | |
1082 | Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) { |
1083 | if (enc) { |
1084 | assert(enc.isPermutation() && "Non permutation map not supported" ); |
1085 | if (const auto dimToLvl = enc.getDimToLvl()) |
1086 | return dimToLvl.getDimPosition(l); |
1087 | } |
1088 | return l; |
1089 | } |
1090 | |
1091 | Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) { |
1092 | if (enc) { |
1093 | assert(enc.isPermutation() && "Non permutation map not supported" ); |
1094 | if (const auto lvlToDim = enc.getLvlToDim()) |
1095 | return lvlToDim.getDimPosition(d); |
1096 | } |
1097 | return d; |
1098 | } |
1099 | |
1100 | /// We normalized sparse tensor encoding attribute by always using |
1101 | /// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well |
1102 | /// as other variants) lead to the same storage specifier type, and stripping |
1103 | /// irrelevant fields that do not alter the sparse tensor memory layout. |
1104 | static SparseTensorEncodingAttr |
1105 | getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) { |
1106 | SmallVector<LevelType> lts; |
1107 | for (auto lt : enc.getLvlTypes()) |
1108 | lts.push_back(lt.stripStorageIrrelevantProperties()); |
1109 | |
1110 | return SparseTensorEncodingAttr::get( |
1111 | enc.getContext(), lts, |
1112 | AffineMap(), // dimToLvl (irrelevant to storage specifier) |
1113 | AffineMap(), // lvlToDim (irrelevant to storage specifier) |
1114 | // Always use `index` for memSize and lvlSize instead of reusing |
1115 | // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA |
1116 | // value for different bitwidth, it also avoids casting between index and |
1117 | // integer (returned by DimOp) |
1118 | 0, 0, enc.getDimSlices()); |
1119 | } |
1120 | |
1121 | StorageSpecifierType |
1122 | StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) { |
1123 | return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding)); |
1124 | } |
1125 | |
1126 | //===----------------------------------------------------------------------===// |
1127 | // SparseTensorDialect Operations. |
1128 | //===----------------------------------------------------------------------===// |
1129 | |
1130 | static LogicalResult lvlIsInBounds(Level lvl, Value tensor) { |
1131 | return success(isSuccess: lvl < getSparseTensorType(val: tensor).getLvlRank()); |
1132 | } |
1133 | |
1134 | static LogicalResult isMatchingWidth(Value mem, unsigned width) { |
1135 | const Type etp = getMemRefType(mem).getElementType(); |
1136 | return success(isSuccess: width == 0 ? etp.isIndex() : etp.isInteger(width)); |
1137 | } |
1138 | |
1139 | static LogicalResult verifySparsifierGetterSetter( |
1140 | StorageSpecifierKind mdKind, std::optional<Level> lvl, |
1141 | TypedValue<StorageSpecifierType> md, Operation *op) { |
1142 | if (mdKind == StorageSpecifierKind::ValMemSize && lvl) { |
1143 | return op->emitError( |
1144 | message: "redundant level argument for querying value memory size" ); |
1145 | } |
1146 | |
1147 | const auto enc = md.getType().getEncoding(); |
1148 | const Level lvlRank = enc.getLvlRank(); |
1149 | |
1150 | if (mdKind == StorageSpecifierKind::DimOffset || |
1151 | mdKind == StorageSpecifierKind::DimStride) |
1152 | if (!enc.isSlice()) |
1153 | return op->emitError(message: "requested slice data on non-slice tensor" ); |
1154 | |
1155 | if (mdKind != StorageSpecifierKind::ValMemSize) { |
1156 | if (!lvl) |
1157 | return op->emitError(message: "missing level argument" ); |
1158 | |
1159 | const Level l = lvl.value(); |
1160 | if (l >= lvlRank) |
1161 | return op->emitError(message: "requested level is out of bounds" ); |
1162 | |
1163 | if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l)) |
1164 | return op->emitError( |
1165 | message: "requested position memory size on a singleton level" ); |
1166 | } |
1167 | return success(); |
1168 | } |
1169 | |
1170 | static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) { |
1171 | switch (kind) { |
1172 | case SparseTensorFieldKind::CrdMemRef: |
1173 | return stt.getCrdType(); |
1174 | case SparseTensorFieldKind::PosMemRef: |
1175 | return stt.getPosType(); |
1176 | case SparseTensorFieldKind::ValMemRef: |
1177 | return stt.getElementType(); |
1178 | case SparseTensorFieldKind::StorageSpec: |
1179 | return nullptr; |
1180 | } |
1181 | llvm_unreachable("Unrecognizable FieldKind" ); |
1182 | } |
1183 | |
1184 | static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, |
1185 | SparseTensorType stt, |
1186 | RankedTensorType valTp, |
1187 | TypeRange lvlTps) { |
1188 | if (requiresStaticShape && !stt.hasStaticDimShape()) |
1189 | return op->emitError(message: "the sparse-tensor must have static shape" ); |
1190 | if (!stt.hasEncoding()) |
1191 | return op->emitError(message: "the sparse-tensor must have an encoding attribute" ); |
1192 | |
1193 | // Verifies the trailing COO. |
1194 | Level cooStartLvl = stt.getAoSCOOStart(); |
1195 | if (cooStartLvl < stt.getLvlRank()) { |
1196 | // We only supports trailing COO for now, must be the last input. |
1197 | auto cooTp = llvm::cast<ShapedType>(lvlTps.back()); |
1198 | // The coordinates should be in shape of <? x rank> |
1199 | unsigned expCOORank = stt.getLvlRank() - cooStartLvl; |
1200 | if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) { |
1201 | op->emitError(message: "input/output trailing COO level-ranks don't match" ); |
1202 | } |
1203 | } |
1204 | |
1205 | // Verifies that all types match. |
1206 | StorageLayout layout(stt.getEncoding()); |
1207 | if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref |
1208 | return op->emitError(message: "inconsistent number of fields between input/output" ); |
1209 | |
1210 | unsigned idx = 0; |
1211 | bool misMatch = false; |
1212 | layout.foreachField(callback: [&idx, &misMatch, stt, valTp, |
1213 | lvlTps](FieldIndex fid, SparseTensorFieldKind fKind, |
1214 | Level lvl, LevelType lt) -> bool { |
1215 | if (fKind == SparseTensorFieldKind::StorageSpec) |
1216 | return true; |
1217 | |
1218 | Type inputTp = nullptr; |
1219 | if (fKind == SparseTensorFieldKind::ValMemRef) { |
1220 | inputTp = valTp; |
1221 | } else { |
1222 | assert(fid == idx && stt.getLvlType(lvl) == lt); |
1223 | inputTp = lvlTps[idx++]; |
1224 | } |
1225 | // The input element type and expected element type should match. |
1226 | Type inpElemTp = llvm::cast<TensorType>(Val&: inputTp).getElementType(); |
1227 | Type expElemTp = getFieldElemType(stt, kind: fKind); |
1228 | if (inpElemTp != expElemTp) { |
1229 | misMatch = true; |
1230 | return false; // to terminate the iteration |
1231 | } |
1232 | return true; |
1233 | }); |
1234 | |
1235 | if (misMatch) |
1236 | return op->emitError(message: "input/output element-types don't match" ); |
1237 | return success(); |
1238 | } |
1239 | |
1240 | LogicalResult AssembleOp::verify() { |
1241 | const auto valuesTp = getRankedTensorType(getValues()); |
1242 | const auto lvlsTp = getLevels().getTypes(); |
1243 | const auto resTp = getSparseTensorType(getResult()); |
1244 | return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp); |
1245 | } |
1246 | |
1247 | LogicalResult DisassembleOp::verify() { |
1248 | if (getOutValues().getType() != getRetValues().getType()) |
1249 | return emitError("output values and return value type mismatch" ); |
1250 | |
1251 | for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels())) |
1252 | if (ot.getType() != rt.getType()) |
1253 | return emitError("output levels and return levels type mismatch" ); |
1254 | |
1255 | const auto valuesTp = getRankedTensorType(getRetValues()); |
1256 | const auto lvlsTp = getRetLevels().getTypes(); |
1257 | const auto srcTp = getSparseTensorType(getTensor()); |
1258 | return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp); |
1259 | } |
1260 | |
1261 | LogicalResult ConvertOp::verify() { |
1262 | if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) { |
1263 | if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) { |
1264 | if (tp1.getRank() != tp2.getRank()) |
1265 | return emitError("unexpected conversion mismatch in rank" ); |
1266 | auto dstEnc = |
1267 | llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding()); |
1268 | if (dstEnc && dstEnc.isSlice()) |
1269 | return emitError("cannot convert to a sparse tensor slice" ); |
1270 | |
1271 | auto shape1 = tp1.getShape(); |
1272 | auto shape2 = tp2.getShape(); |
1273 | // Accept size matches between the source and the destination type |
1274 | // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or |
1275 | // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10). |
1276 | for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++) |
1277 | if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic) |
1278 | return emitError("unexpected conversion mismatch in dimension " ) << d; |
1279 | return success(); |
1280 | } |
1281 | } |
1282 | return emitError("unexpected type in convert" ); |
1283 | } |
1284 | |
1285 | OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) { |
1286 | if (getType() == getSource().getType()) |
1287 | return getSource(); |
1288 | return {}; |
1289 | } |
1290 | |
1291 | bool ConvertOp::needsExtraSort() { |
1292 | SparseTensorType srcStt = getSparseTensorType(getSource()); |
1293 | SparseTensorType dstStt = getSparseTensorType(getDest()); |
1294 | |
1295 | // We do not need an extra sort when returning unordered sparse tensors or |
1296 | // dense tensor since dense tensor support random access. |
1297 | if (dstStt.isAllDense() || !dstStt.isAllOrdered()) |
1298 | return false; |
1299 | |
1300 | if (srcStt.isAllOrdered() && dstStt.isAllOrdered() && |
1301 | srcStt.hasSameDimToLvl(dstStt)) { |
1302 | return false; |
1303 | } |
1304 | |
1305 | // Source and dest tensors are ordered in different ways. We only do direct |
1306 | // dense to sparse conversion when the dense input is defined by a sparse |
1307 | // constant. Note that we can theoretically always directly convert from dense |
1308 | // inputs by rotating dense loops but it leads to bad cache locality and hurt |
1309 | // performance. |
1310 | if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>()) |
1311 | if (isa<SparseElementsAttr>(constOp.getValue())) |
1312 | return false; |
1313 | |
1314 | return true; |
1315 | } |
1316 | |
1317 | LogicalResult CrdTranslateOp::verify() { |
1318 | uint64_t inRank = getEncoder().getLvlRank(); |
1319 | uint64_t outRank = getEncoder().getDimRank(); |
1320 | |
1321 | if (getDirection() == CrdTransDirectionKind::dim2lvl) |
1322 | std::swap(inRank, outRank); |
1323 | |
1324 | if (inRank != getInCrds().size() || outRank != getOutCrds().size()) |
1325 | return emitError("Coordinate rank mismatch with encoding" ); |
1326 | |
1327 | return success(); |
1328 | } |
1329 | |
1330 | LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor, |
1331 | SmallVectorImpl<OpFoldResult> &results) { |
1332 | if (getEncoder().isIdentity()) { |
1333 | results.assign(getInCrds().begin(), getInCrds().end()); |
1334 | return success(); |
1335 | } |
1336 | if (getEncoder().isPermutation()) { |
1337 | AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl |
1338 | ? getEncoder().getDimToLvl() |
1339 | : getEncoder().getLvlToDim(); |
1340 | for (AffineExpr exp : perm.getResults()) |
1341 | results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]); |
1342 | return success(); |
1343 | } |
1344 | |
1345 | // Fuse dim2lvl/lvl2dim pairs. |
1346 | auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>(); |
1347 | bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) { |
1348 | return v.getDefiningOp() == def; |
1349 | }); |
1350 | if (!sameDef) |
1351 | return failure(); |
1352 | |
1353 | bool oppositeDir = def.getDirection() != getDirection(); |
1354 | bool sameOracle = |
1355 | def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl(); |
1356 | bool sameCount = def.getNumResults() == getInCrds().size(); |
1357 | if (!oppositeDir || !sameOracle || !sameCount) |
1358 | return failure(); |
1359 | |
1360 | // The definition produces the coordinates in the same order as the input |
1361 | // coordinates. |
1362 | bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()), |
1363 | [](auto valuePair) { |
1364 | auto [lhs, rhs] = valuePair; |
1365 | return lhs == rhs; |
1366 | }); |
1367 | |
1368 | if (!sameOrder) |
1369 | return failure(); |
1370 | // l1 = dim2lvl (lvl2dim l0) |
1371 | // ==> l0 |
1372 | results.append(def.getInCrds().begin(), def.getInCrds().end()); |
1373 | return success(); |
1374 | } |
1375 | |
1376 | void LvlOp::build(OpBuilder &builder, OperationState &state, Value source, |
1377 | int64_t index) { |
1378 | Value val = builder.create<arith::ConstantIndexOp>(state.location, index); |
1379 | return build(builder, state, source, val); |
1380 | } |
1381 | |
1382 | LogicalResult LvlOp::verify() { |
1383 | if (std::optional<uint64_t> lvl = getConstantLvlIndex()) { |
1384 | auto stt = getSparseTensorType(getSource()); |
1385 | if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank()) |
1386 | emitError("Level index exceeds the rank of the input sparse tensor" ); |
1387 | } |
1388 | return success(); |
1389 | } |
1390 | |
1391 | std::optional<uint64_t> LvlOp::getConstantLvlIndex() { |
1392 | return getConstantIntValue(getIndex()); |
1393 | } |
1394 | |
1395 | Speculation::Speculatability LvlOp::getSpeculatability() { |
1396 | auto constantIndex = getConstantLvlIndex(); |
1397 | if (!constantIndex) |
1398 | return Speculation::NotSpeculatable; |
1399 | |
1400 | assert(constantIndex < |
1401 | cast<RankedTensorType>(getSource().getType()).getRank()); |
1402 | return Speculation::Speculatable; |
1403 | } |
1404 | |
1405 | OpFoldResult LvlOp::fold(FoldAdaptor adaptor) { |
1406 | auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex()); |
1407 | if (!lvlIndex) |
1408 | return {}; |
1409 | |
1410 | Level lvl = lvlIndex.getAPSInt().getZExtValue(); |
1411 | auto stt = getSparseTensorType(getSource()); |
1412 | if (lvl >= stt.getLvlRank()) { |
    // Follows the same convention used by the tensor.dim operation:
    // out-of-bound indices produce undefined behavior but are still valid IR,
    // so don't choke on them.
1416 | return {}; |
1417 | } |
1418 | |
1419 | // Helper lambda to build an IndexAttr. |
1420 | auto getIndexAttr = [this](int64_t lvlSz) { |
1421 | return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz)); |
1422 | }; |
1423 | |
1424 | SmallVector<Size> lvlShape = stt.getLvlShape(); |
1425 | if (!ShapedType::isDynamic(lvlShape[lvl])) |
1426 | return getIndexAttr(lvlShape[lvl]); |
1427 | |
1428 | return {}; |
1429 | } |
1430 | |
1431 | void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState, |
1432 | SparseTensorEncodingAttr dstEnc, Value source) { |
1433 | auto srcStt = getSparseTensorType(source); |
1434 | SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape(); |
1435 | SmallVector<int64_t> dstDimShape = |
1436 | dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim); |
1437 | auto dstTp = |
1438 | RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc); |
1439 | return build(odsBuilder, odsState, dstTp, source); |
1440 | } |
1441 | |
1442 | LogicalResult ReinterpretMapOp::verify() { |
1443 | auto srcStt = getSparseTensorType(getSource()); |
1444 | auto dstStt = getSparseTensorType(getDest()); |
1445 | ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes(); |
1446 | ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes(); |
1447 | |
1448 | if (srcLvlTps.size() != dstLvlTps.size()) |
1449 | return emitError("Level rank mismatch between source/dest tensors" ); |
1450 | |
1451 | for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps)) |
1452 | if (srcLvlTp != dstLvlTp) |
1453 | return emitError("Level type mismatch between source/dest tensors" ); |
1454 | |
1455 | if (srcStt.getPosWidth() != dstStt.getPosWidth() || |
1456 | srcStt.getCrdWidth() != dstStt.getCrdWidth()) { |
1457 | return emitError("Crd/Pos width mismatch between source/dest tensors" ); |
1458 | } |
1459 | |
1460 | if (srcStt.getElementType() != dstStt.getElementType()) |
1461 | return emitError("Element type mismatch between source/dest tensors" ); |
1462 | |
1463 | SmallVector<Size> srcLvlShape = srcStt.getLvlShape(); |
1464 | SmallVector<Size> dstLvlShape = dstStt.getLvlShape(); |
1465 | for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) { |
1466 | if (srcLvlSz != dstLvlSz) { |
      // Should we allow one side to have a dynamic size, e.g., should <?x?> be
      // compatible with <3x4>? For now, we require all level sizes to match
      // *exactly* for simplicity.
1470 | return emitError("Level size mismatch between source/dest tensors" ); |
1471 | } |
1472 | } |
1473 | |
1474 | return success(); |
1475 | } |
1476 | |
1477 | OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) { |
1478 | if (getSource().getType() == getDest().getType()) |
1479 | return getSource(); |
1480 | |
1481 | if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) { |
1482 | // A -> B, B -> A ==> A |
1483 | if (def.getSource().getType() == getDest().getType()) |
1484 | return def.getSource(); |
1485 | } |
1486 | return {}; |
1487 | } |
1488 | |
1489 | template <typename ToBufferOp> |
1490 | static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr, |
1491 | OpaqueProperties prop, |
1492 | RegionRange region, |
1493 | SmallVectorImpl<mlir::Type> &ret) { |
1494 | typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region); |
1495 | SparseTensorType stt = getSparseTensorType(adaptor.getTensor()); |
1496 | Type elemTp = nullptr; |
1497 | bool withStride = false; |
1498 | if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) { |
1499 | elemTp = stt.getPosType(); |
1500 | } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> || |
1501 | std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) { |
1502 | elemTp = stt.getCrdType(); |
1503 | if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>) |
1504 | withStride = stt.getAoSCOOStart() <= adaptor.getLevel(); |
1505 | } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) { |
1506 | elemTp = stt.getElementType(); |
1507 | } |
1508 | |
1509 | assert(elemTp && "unhandled operation." ); |
1510 | SmallVector<int64_t> bufShape = stt.getBatchLvlShape(); |
1511 | bufShape.push_back(ShapedType::kDynamic); |
1512 | |
  auto layout = withStride ? StridedLayoutAttr::get(stt.getContext(),
                                                    ShapedType::kDynamic,
                                                    {ShapedType::kDynamic})
                           : StridedLayoutAttr();
1517 | ret.emplace_back(MemRefType::get(bufShape, elemTp, layout)); |
1518 | return success(); |
1519 | } |
1520 | |
1521 | LogicalResult ToPositionsOp::verify() { |
1522 | auto stt = getSparseTensorType(getTensor()); |
1523 | if (failed(lvlIsInBounds(getLevel(), getTensor()))) |
1524 | return emitError("requested level is out of bounds" ); |
1525 | if (failed(isMatchingWidth(getResult(), stt.getPosWidth()))) |
1526 | return emitError("unexpected type for positions" ); |
1527 | return success(); |
1528 | } |
1529 | |
1530 | LogicalResult |
1531 | ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, |
1532 | ValueRange ops, DictionaryAttr attr, |
1533 | OpaqueProperties prop, RegionRange region, |
1534 | SmallVectorImpl<mlir::Type> &ret) { |
1535 | return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret); |
1536 | } |
1537 | |
1538 | LogicalResult ToCoordinatesOp::verify() { |
1539 | auto stt = getSparseTensorType(getTensor()); |
1540 | if (failed(lvlIsInBounds(getLevel(), getTensor()))) |
1541 | return emitError("requested level is out of bounds" ); |
1542 | if (failed(isMatchingWidth(getResult(), stt.getCrdWidth()))) |
1543 | return emitError("unexpected type for coordinates" ); |
1544 | return success(); |
1545 | } |
1546 | |
1547 | LogicalResult |
1548 | ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, |
1549 | ValueRange ops, DictionaryAttr attr, |
1550 | OpaqueProperties prop, RegionRange region, |
1551 | SmallVectorImpl<mlir::Type> &ret) { |
1552 | return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret); |
1553 | } |
1554 | |
1555 | LogicalResult ToCoordinatesBufferOp::verify() { |
1556 | auto stt = getSparseTensorType(getTensor()); |
1557 | if (stt.getAoSCOOStart() >= stt.getLvlRank()) |
1558 | return emitError("expected sparse tensor with a COO region" ); |
1559 | return success(); |
1560 | } |
1561 | |
1562 | LogicalResult ToCoordinatesBufferOp::inferReturnTypes( |
1563 | MLIRContext *ctx, std::optional<Location> loc, ValueRange ops, |
1564 | DictionaryAttr attr, OpaqueProperties prop, RegionRange region, |
1565 | SmallVectorImpl<mlir::Type> &ret) { |
1566 | return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region, |
1567 | ret); |
1568 | } |
1569 | |
1570 | LogicalResult ToValuesOp::verify() { |
1571 | auto stt = getSparseTensorType(getTensor()); |
1572 | auto mtp = getMemRefType(getResult()); |
1573 | if (stt.getElementType() != mtp.getElementType()) |
1574 | return emitError("unexpected mismatch in element types" ); |
1575 | return success(); |
1576 | } |
1577 | |
1578 | LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx, |
1579 | std::optional<Location> loc, |
1580 | ValueRange ops, DictionaryAttr attr, |
1581 | OpaqueProperties prop, |
1582 | RegionRange region, |
1583 | SmallVectorImpl<mlir::Type> &ret) { |
1584 | return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret); |
1585 | } |
1586 | |
1587 | LogicalResult ToSliceOffsetOp::verify() { |
1588 | auto rank = getRankedTensorType(getSlice()).getRank(); |
1589 | if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0) |
1590 | return emitError("requested dimension out of bound" ); |
1591 | return success(); |
1592 | } |
1593 | |
1594 | LogicalResult ToSliceStrideOp::verify() { |
1595 | auto rank = getRankedTensorType(getSlice()).getRank(); |
1596 | if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0) |
1597 | return emitError("requested dimension out of bound" ); |
1598 | return success(); |
1599 | } |
1600 | |
1601 | LogicalResult GetStorageSpecifierOp::verify() { |
1602 | return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(), |
1603 | getSpecifier(), getOperation()); |
1604 | } |
1605 | |
1606 | template <typename SpecifierOp> |
1607 | static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) { |
1608 | return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>(); |
1609 | } |
1610 | |
1611 | OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) { |
1612 | const StorageSpecifierKind kind = getSpecifierKind(); |
1613 | const auto lvl = getLevel(); |
1614 | for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op)) |
1615 | if (kind == op.getSpecifierKind() && lvl == op.getLevel()) |
1616 | return op.getValue(); |
1617 | return {}; |
1618 | } |
1619 | |
1620 | LogicalResult SetStorageSpecifierOp::verify() { |
1621 | return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(), |
1622 | getSpecifier(), getOperation()); |
1623 | } |
1624 | |
1625 | template <class T> |
1626 | static LogicalResult verifyNumBlockArgs(T *op, Region ®ion, |
1627 | const char *regionName, |
1628 | TypeRange inputTypes, Type outputType) { |
1629 | unsigned numArgs = region.getNumArguments(); |
1630 | unsigned expectedNum = inputTypes.size(); |
1631 | if (numArgs != expectedNum) |
1632 | return op->emitError() << regionName << " region must have exactly " |
1633 | << expectedNum << " arguments" ; |
1634 | |
1635 | for (unsigned i = 0; i < numArgs; i++) { |
1636 | Type typ = region.getArgument(i).getType(); |
1637 | if (typ != inputTypes[i]) |
1638 | return op->emitError() << regionName << " region argument " << (i + 1) |
1639 | << " type mismatch" ; |
1640 | } |
1641 | Operation *term = region.front().getTerminator(); |
1642 | YieldOp yield = dyn_cast<YieldOp>(term); |
1643 | if (!yield) |
1644 | return op->emitError() << regionName |
1645 | << " region must end with sparse_tensor.yield" ; |
1646 | if (!yield.hasSingleResult() || |
1647 | yield.getSingleResult().getType() != outputType) |
1648 | return op->emitError() << regionName << " region yield type mismatch" ; |
1649 | |
1650 | return success(); |
1651 | } |
1652 | |
1653 | LogicalResult BinaryOp::verify() { |
1654 | NamedAttrList attrs = (*this)->getAttrs(); |
1655 | Type leftType = getX().getType(); |
1656 | Type rightType = getY().getType(); |
1657 | Type outputType = getOutput().getType(); |
1658 | Region &overlap = getOverlapRegion(); |
1659 | Region &left = getLeftRegion(); |
1660 | Region &right = getRightRegion(); |
1661 | |
1662 | // Check correct number of block arguments and return type for each |
1663 | // non-empty region. |
1664 | if (!overlap.empty()) { |
1665 | if (failed(verifyNumBlockArgs(this, overlap, "overlap" , |
1666 | TypeRange{leftType, rightType}, outputType))) |
1667 | return failure(); |
1668 | } |
1669 | if (!left.empty()) { |
1670 | if (failed(verifyNumBlockArgs(this, left, "left" , TypeRange{leftType}, |
1671 | outputType))) |
1672 | return failure(); |
1673 | } else if (getLeftIdentity()) { |
1674 | if (leftType != outputType) |
1675 | return emitError("left=identity requires first argument to have the same " |
1676 | "type as the output" ); |
1677 | } |
1678 | if (!right.empty()) { |
1679 | if (failed(verifyNumBlockArgs(this, right, "right" , TypeRange{rightType}, |
1680 | outputType))) |
1681 | return failure(); |
1682 | } else if (getRightIdentity()) { |
1683 | if (rightType != outputType) |
1684 | return emitError("right=identity requires second argument to have the " |
1685 | "same type as the output" ); |
1686 | } |
1687 | return success(); |
1688 | } |
1689 | |
1690 | LogicalResult UnaryOp::verify() { |
1691 | Type inputType = getX().getType(); |
1692 | Type outputType = getOutput().getType(); |
1693 | |
1694 | // Check correct number of block arguments and return type for each |
1695 | // non-empty region. |
1696 | Region &present = getPresentRegion(); |
1697 | if (!present.empty()) { |
1698 | if (failed(verifyNumBlockArgs(this, present, "present" , |
1699 | TypeRange{inputType}, outputType))) |
1700 | return failure(); |
1701 | } |
1702 | Region &absent = getAbsentRegion(); |
1703 | if (!absent.empty()) { |
1704 | if (failed(verifyNumBlockArgs(this, absent, "absent" , TypeRange{}, |
1705 | outputType))) |
1706 | return failure(); |
1707 | // Absent branch can only yield invariant values. |
1708 | Block *absentBlock = &absent.front(); |
1709 | Block *parent = getOperation()->getBlock(); |
1710 | Value absentVal = |
1711 | cast<YieldOp>(absentBlock->getTerminator()).getSingleResult(); |
1712 | if (auto arg = dyn_cast<BlockArgument>(absentVal)) { |
1713 | if (arg.getOwner() == parent) |
1714 | return emitError("absent region cannot yield linalg argument" ); |
1715 | } else if (Operation *def = absentVal.getDefiningOp()) { |
1716 | if (!isa<arith::ConstantOp>(def) && |
1717 | (def->getBlock() == absentBlock || def->getBlock() == parent)) |
1718 | return emitError("absent region cannot yield locally computed value" ); |
1719 | } |
1720 | } |
1721 | return success(); |
1722 | } |
1723 | |
1724 | bool ConcatenateOp::needsExtraSort() { |
1725 | SparseTensorType dstStt = getSparseTensorType(*this); |
1726 | if (dstStt.isAllDense() || !dstStt.isAllOrdered()) |
1727 | return false; |
1728 | |
1729 | bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) { |
1730 | return getSparseTensorType(op).hasSameDimToLvl(dstStt); |
1731 | }); |
  // TODO: When conDim != 0, as long as conDim corresponds to the first level
  // in all input/output buffers, and all input/output buffers have the same
  // dimToLvl, the tmp COO buffer is still unnecessary (e.g., concatenating
  // CSC matrices along the column).
1736 | bool directLowerable = |
1737 | allSameOrdered && getDimension() == 0 && dstStt.isIdentity(); |
1738 | return !directLowerable; |
1739 | } |
1740 | |
1741 | LogicalResult ConcatenateOp::verify() { |
1742 | const auto dstTp = getSparseTensorType(*this); |
1743 | const Dimension concatDim = getDimension(); |
1744 | const Dimension dimRank = dstTp.getDimRank(); |
1745 | |
1746 | if (getInputs().size() <= 1) |
1747 | return emitError("Need at least two tensors to concatenate." ); |
1748 | |
1749 | if (concatDim >= dimRank) |
1750 | return emitError(llvm::formatv( |
1751 | "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})" , |
1752 | concatDim, dimRank)); |
1753 | |
1754 | for (const auto &it : llvm::enumerate(getInputs())) { |
1755 | const auto i = it.index(); |
1756 | const auto srcTp = getSparseTensorType(it.value()); |
1757 | if (srcTp.hasDynamicDimShape()) |
1758 | return emitError(llvm::formatv("Input tensor ${0} has dynamic shape" , i)); |
1759 | const Dimension srcDimRank = srcTp.getDimRank(); |
1760 | if (srcDimRank != dimRank) |
1761 | return emitError( |
1762 | llvm::formatv("Input tensor ${0} has a different rank (rank={1}) " |
1763 | "from the output tensor (rank={2})." , |
1764 | i, srcDimRank, dimRank)); |
1765 | } |
1766 | |
1767 | for (Dimension d = 0; d < dimRank; d++) { |
1768 | const Size dstSh = dstTp.getDimShape()[d]; |
1769 | if (d == concatDim) { |
1770 | if (!ShapedType::isDynamic(dstSh)) { |
1771 | // If we reach here, then all inputs have static shapes. So we |
1772 | // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)` |
1773 | // to avoid redundant assertions in the loop. |
1774 | Size sumSz = 0; |
1775 | for (const auto src : getInputs()) |
1776 | sumSz += getSparseTensorType(src).getDimShape()[d]; |
        // If all dimensions are statically known, the sum of all the input
        // dimensions should equal the output dimension.
1779 | if (sumSz != dstSh) |
1780 | return emitError( |
1781 | "The concatenation dimension of the output tensor should be the " |
1782 | "sum of all the concatenation dimensions of the input tensors." ); |
1783 | } |
1784 | } else { |
1785 | Size prev = dstSh; |
1786 | for (const auto src : getInputs()) { |
1787 | const auto sh = getSparseTensorType(src).getDimShape()[d]; |
1788 | if (!ShapedType::isDynamic(prev) && sh != prev) |
1789 | return emitError("All dimensions (expect for the concatenating one) " |
1790 | "should be equal." ); |
1791 | prev = sh; |
1792 | } |
1793 | } |
1794 | } |
1795 | |
1796 | return success(); |
1797 | } |
1798 | |
1799 | void PushBackOp::build(OpBuilder &builder, OperationState &result, |
1800 | Value curSize, Value inBuffer, Value value) { |
1801 | build(builder, result, curSize, inBuffer, value, Value()); |
1802 | } |
1803 | |
1804 | LogicalResult PushBackOp::verify() { |
1805 | if (Value n = getN()) { |
1806 | std::optional<int64_t> nValue = getConstantIntValue(n); |
1807 | if (nValue && nValue.value() < 1) |
1808 | return emitOpError("n must be not less than 1" ); |
1809 | } |
1810 | return success(); |
1811 | } |
1812 | |
1813 | LogicalResult CompressOp::verify() { |
1814 | const auto stt = getSparseTensorType(getTensor()); |
1815 | if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size())) |
1816 | return emitOpError("incorrect number of coordinates" ); |
1817 | return success(); |
1818 | } |
1819 | |
1820 | void ForeachOp::build( |
1821 | OpBuilder &builder, OperationState &result, Value tensor, |
1822 | ValueRange initArgs, AffineMapAttr order, |
1823 | function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)> |
1824 | bodyBuilder) { |
1825 | build(builder, result, initArgs.getTypes(), tensor, initArgs, order); |
1826 | // Builds foreach body. |
1827 | if (!bodyBuilder) |
1828 | return; |
1829 | const auto stt = getSparseTensorType(tensor); |
1830 | const Dimension dimRank = stt.getDimRank(); |
1831 | |
1832 | // Starts with `dimRank`-many coordinates. |
1833 | SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType()); |
1834 | // Followed by one value. |
1835 | blockArgTypes.push_back(stt.getElementType()); |
1836 | // Followed by the reduction variables. |
1837 | blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end()); |
1838 | |
1839 | SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc()); |
1840 | |
1841 | OpBuilder::InsertionGuard guard(builder); |
1842 | auto ®ion = *result.regions.front(); |
1843 | Block *bodyBlock = |
1844 | builder.createBlock(®ion, region.end(), blockArgTypes, blockArgLocs); |
1845 | bodyBuilder(builder, result.location, |
1846 | bodyBlock->getArguments().slice(0, dimRank), |
1847 | bodyBlock->getArguments()[dimRank], |
1848 | bodyBlock->getArguments().drop_front(dimRank + 1)); |
1849 | } |
1850 | |
1851 | LogicalResult ForeachOp::verify() { |
1852 | const auto t = getSparseTensorType(getTensor()); |
1853 | const Dimension dimRank = t.getDimRank(); |
1854 | const auto args = getBody()->getArguments(); |
1855 | |
1856 | if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank()) |
1857 | return emitError("Level traverse order does not match tensor's level rank" ); |
1858 | |
1859 | if (dimRank + 1 + getInitArgs().size() != args.size()) |
1860 | return emitError("Unmatched number of arguments in the block" ); |
1861 | |
1862 | if (getNumResults() != getInitArgs().size()) |
1863 | return emitError("Mismatch in number of init arguments and results" ); |
1864 | |
1865 | if (getResultTypes() != getInitArgs().getTypes()) |
1866 | return emitError("Mismatch in types of init arguments and results" ); |
1867 | |
1868 | // Cannot mark this const, because the getters aren't. |
1869 | auto yield = cast<YieldOp>(getBody()->getTerminator()); |
1870 | if (yield.getNumOperands() != getNumResults() || |
1871 | yield.getOperands().getTypes() != getResultTypes()) |
1872 | return emitError("Mismatch in types of yield values and results" ); |
1873 | |
1874 | const auto iTp = IndexType::get(getContext()); |
1875 | for (Dimension d = 0; d < dimRank; d++) |
1876 | if (args[d].getType() != iTp) |
1877 | emitError( |
1878 | llvm::formatv("Expecting Index type for argument at index {0}" , d)); |
1879 | |
  const auto elemTp = t.getElementType();
  const auto valueTp = args[dimRank].getType();
  if (elemTp != valueTp)
    return emitError(
        llvm::formatv("Unmatched element type between input tensor and "
                      "block argument, expected: {0}, got: {1}",
                      elemTp, valueTp));
1886 | return success(); |
1887 | } |
1888 | |
1889 | OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) { |
1890 | if (getSparseTensorEncoding(getInputCoo().getType()) == |
1891 | getSparseTensorEncoding(getResultCoo().getType())) |
1892 | return getInputCoo(); |
1893 | |
1894 | return {}; |
1895 | } |
1896 | |
1897 | LogicalResult ReorderCOOOp::verify() { |
1898 | SparseTensorType srcStt = getSparseTensorType(getInputCoo()); |
1899 | SparseTensorType dstStt = getSparseTensorType(getResultCoo()); |
1900 | |
  if (!srcStt.isCOOType() || !dstStt.isCOOType())
    return emitError("Expected COO sparse tensors only");

  if (!srcStt.hasSameDimToLvl(dstStt))
    return emitError("Unmatched dim2lvl map between input and result COO");

  if (srcStt.getPosType() != dstStt.getPosType() ||
      srcStt.getCrdType() != dstStt.getCrdType() ||
      srcStt.getElementType() != dstStt.getElementType())
    return emitError("Unmatched storage format between input and result COO");
1911 | |
1912 | return success(); |
1913 | } |
1914 | |
1915 | LogicalResult ReduceOp::verify() { |
1916 | Type inputType = getX().getType(); |
1917 | Region &formula = getRegion(); |
1918 | return verifyNumBlockArgs(this, formula, "reduce" , |
1919 | TypeRange{inputType, inputType}, inputType); |
1920 | } |
1921 | |
1922 | LogicalResult SelectOp::verify() { |
1923 | Builder b(getContext()); |
1924 | Type inputType = getX().getType(); |
1925 | Type boolType = b.getI1Type(); |
1926 | Region &formula = getRegion(); |
1927 | return verifyNumBlockArgs(this, formula, "select" , TypeRange{inputType}, |
1928 | boolType); |
1929 | } |
1930 | |
1931 | LogicalResult SortOp::verify() { |
1932 | AffineMap xPerm = getPermMap(); |
1933 | uint64_t nx = xPerm.getNumDims(); |
1934 | if (nx < 1) |
1935 | emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}" , nx)); |
1936 | |
1937 | if (!xPerm.isPermutation()) |
1938 | emitError(llvm::formatv("Expected a permutation map, got {0}" , xPerm)); |
1939 | |
1940 | // We can't check the size of the buffers when n or buffer dimensions aren't |
1941 | // compile-time constants. |
1942 | std::optional<int64_t> cn = getConstantIntValue(getN()); |
1943 | if (!cn) |
1944 | return success(); |
1945 | |
1946 | // Verify dimensions. |
1947 | const auto checkDim = [&](Value v, Size minSize, const char *message) { |
1948 | const Size sh = getMemRefType(v).getShape()[0]; |
1949 | if (!ShapedType::isDynamic(sh) && sh < minSize) |
1950 | emitError(llvm::formatv("{0} got {1} < {2}" , message, sh, minSize)); |
1951 | }; |
1952 | uint64_t n = cn.value(); |
1953 | uint64_t ny = 0; |
1954 | if (auto nyAttr = getNyAttr()) |
1955 | ny = nyAttr.getInt(); |
1956 | checkDim(getXy(), n * (nx + ny), |
1957 | "Expected dimension(xy) >= n * (rank(perm_map) + ny)" ); |
1958 | for (Value opnd : getYs()) |
1959 | checkDim(opnd, n, "Expected dimension(y) >= n" ); |
1960 | |
1961 | return success(); |
1962 | } |
1963 | |
1964 | //===----------------------------------------------------------------------===// |
1965 | // Sparse Tensor Iteration Operations. |
1966 | //===----------------------------------------------------------------------===// |
1967 | |
1968 | IterSpaceType IteratorType::getIterSpaceType() const { |
1969 | return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(), |
1970 | getHiLvl()); |
1971 | } |
1972 | |
1973 | IteratorType IterSpaceType::getIteratorType() const { |
1974 | return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl()); |
1975 | } |
1976 | |
1977 | /// Parses a level range in the form "$lo `to` $hi" |
1978 | /// or simply "$lo" if $hi - $lo = 1 |
1979 | static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo, |
1980 | Level &lvlHi) { |
1981 | if (parser.parseInteger(result&: lvlLo)) |
1982 | return failure(); |
1983 | |
1984 | if (succeeded(result: parser.parseOptionalKeyword(keyword: "to" ))) { |
1985 | if (parser.parseInteger(result&: lvlHi)) |
1986 | return failure(); |
1987 | } else { |
1988 | lvlHi = lvlLo + 1; |
1989 | } |
1990 | |
1991 | if (lvlHi <= lvlLo) |
1992 | parser.emitError(loc: parser.getNameLoc(), |
1993 | message: "expect larger level upper bound than lower bound" ); |
1994 | |
1995 | return success(); |
1996 | } |
1997 | |
1998 | /// Parses a level range in the form "$lo `to` $hi" |
1999 | /// or simply "$lo" if $hi - $lo = 1 |
2000 | static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr, |
2001 | IntegerAttr &lvlHiAttr) { |
2002 | Level lvlLo, lvlHi; |
2003 | if (parseLevelRange(parser, lvlLo, lvlHi)) |
2004 | return failure(); |
2005 | |
2006 | lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo); |
2007 | lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi); |
2008 | return success(); |
2009 | } |
2010 | |
2011 | /// Prints a level range in the form "$lo `to` $hi" |
2012 | /// or simply "$lo" if $hi - $lo = 1 |
2013 | static void printLevelRange(AsmPrinter &p, Level lo, Level hi) { |
2014 | |
2015 | if (lo + 1 == hi) |
2016 | p << lo; |
2017 | else |
2018 | p << lo << " to " << hi; |
2019 | } |
2020 | |
2021 | /// Prints a level range in the form "$lo `to` $hi" |
2022 | /// or simply "$lo" if $hi - $lo = 1 |
2023 | static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo, |
2024 | IntegerAttr lvlHi) { |
2025 | unsigned lo = lvlLo.getValue().getZExtValue(); |
2026 | unsigned hi = lvlHi.getValue().getZExtValue(); |
2027 | printLevelRange(p, lo, hi); |
2028 | } |
2029 | |
2030 | LogicalResult ExtractIterSpaceOp::inferReturnTypes( |
2031 | MLIRContext *ctx, std::optional<Location> loc, ValueRange ops, |
2032 | DictionaryAttr attr, OpaqueProperties prop, RegionRange region, |
2033 | SmallVectorImpl<mlir::Type> &ret) { |
2034 | |
2035 | ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region); |
2036 | SparseTensorType stt = getSparseTensorType(adaptor.getTensor()); |
2037 | ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(), |
2038 | adaptor.getHiLvl())); |
2039 | return success(); |
2040 | } |
2041 | |
2042 | LogicalResult ExtractIterSpaceOp::verify() { |
2043 | if (getLoLvl() >= getHiLvl()) |
2044 | return emitOpError("expected smaller level low than level high" ); |
2045 | |
2046 | TypedValue<IteratorType> pIter = getParentIter(); |
2047 | if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) { |
2048 | return emitOpError( |
2049 | "parent iterator should be specified iff level lower bound equals 0" ); |
2050 | } |
2051 | |
2052 | if (pIter) { |
2053 | IterSpaceType spaceTp = getResultSpace().getType(); |
2054 | if (pIter.getType().getEncoding() != spaceTp.getEncoding()) |
2055 | return emitOpError( |
2056 | "mismatch in parent iterator encoding and iteration space encoding." ); |
2057 | |
2058 | if (spaceTp.getLoLvl() != pIter.getType().getHiLvl()) |
2059 | return emitOpError("parent iterator should be used to extract an " |
2060 | "iteration space from a consecutive level." ); |
2061 | } |
2062 | |
2063 | return success(); |
2064 | } |
2065 | |
2066 | /// Materialize a single constant operation from a given attribute value with |
2067 | /// the desired resultant type. |
2068 | Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder, |
2069 | Attribute value, Type type, |
2070 | Location loc) { |
2071 | if (auto op = arith::ConstantOp::materialize(builder, value, type, loc)) |
2072 | return op; |
2073 | return nullptr; |
2074 | } |
2075 | |
2076 | namespace { |
2077 | struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface { |
2078 | using OpAsmDialectInterface::OpAsmDialectInterface; |
2079 | |
2080 | AliasResult getAlias(Attribute attr, raw_ostream &os) const override { |
2081 | if (isa<SparseTensorEncodingAttr>(attr)) { |
2082 | os << "sparse" ; |
2083 | return AliasResult::OverridableAlias; |
2084 | } |
2085 | return AliasResult::NoAlias; |
2086 | } |
2087 | }; |
2088 | } // namespace |
2089 | |
2090 | void SparseTensorDialect::initialize() { |
2091 | addInterface<SparseTensorAsmDialectInterface>(); |
2092 | addAttributes< |
2093 | #define GET_ATTRDEF_LIST |
2094 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" |
2095 | >(); |
2096 | addTypes< |
2097 | #define GET_TYPEDEF_LIST |
2098 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc" |
2099 | >(); |
2100 | addOperations< |
2101 | #define GET_OP_LIST |
2102 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" |
2103 | >(); |
2104 | declarePromisedInterfaces< |
2105 | bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp, |
2106 | NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp, |
2107 | ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>(); |
2108 | } |
2109 | |
2110 | #define GET_OP_CLASSES |
2111 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" |
2112 | |
2113 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc" |
2114 | |