//===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"
#include "SparseTensorDescriptor.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include <optional>

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// ExecutionEngine/SparseTensorUtils helper functions.
//===----------------------------------------------------------------------===//

OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) {
  switch (width) {
  case 64:
    return OverheadType::kU64;
  case 32:
    return OverheadType::kU32;
  case 16:
    return OverheadType::kU16;
  case 8:
    return OverheadType::kU8;
  case 0:
    return OverheadType::kIndex;
  }
  llvm_unreachable("Unsupported overhead bitwidth");
}

OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) {
  if (tp.isIndex())
    return OverheadType::kIndex;
  if (auto intTp = dyn_cast<IntegerType>(tp))
    return overheadTypeEncoding(intTp.getWidth());
  llvm_unreachable("Unknown overhead type");
}

Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
  switch (ot) {
  case OverheadType::kIndex:
    return builder.getIndexType();
  case OverheadType::kU64:
    return builder.getIntegerType(64);
  case OverheadType::kU32:
    return builder.getIntegerType(32);
  case OverheadType::kU16:
    return builder.getIntegerType(16);
  case OverheadType::kU8:
    return builder.getIntegerType(8);
  }
  llvm_unreachable("Unknown OverheadType");
}

OverheadType
mlir::sparse_tensor::posTypeEncoding(SparseTensorEncodingAttr enc) {
  return overheadTypeEncoding(enc.getPosWidth());
}

OverheadType
mlir::sparse_tensor::crdTypeEncoding(SparseTensorEncodingAttr enc) {
  return overheadTypeEncoding(enc.getCrdWidth());
}

// TODO: we ought to add some `static_assert` tests to ensure that the
// `STEA::get{Pos,Crd}Type` methods agree with `getOverheadType(builder,
// {pos,crd}OverheadTypeEncoding(enc))`

// TODO: Adjust the naming convention for the constructors of
// `OverheadType` so we can use the `MLIR_SPARSETENSOR_FOREVERY_O` x-macro
// here instead of `MLIR_SPARSETENSOR_FOREVERY_FIXED_O`; to further reduce
// the possibility of typo bugs or things getting out of sync.
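// Returns the overhead bitwidth as a string suffix ("64", "32", "16", "8"),
// with "0" denoting the native `index` type. For example,
// `OverheadType::kU64` yields "64".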
StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) {
  switch (ot) {
  case OverheadType::kIndex:
    return "0";
#define CASE(ONAME, O)                                                         \
  case OverheadType::kU##ONAME:                                                \
    return #ONAME;
    MLIR_SPARSETENSOR_FOREVERY_FIXED_O(CASE)
#undef CASE
  }
  llvm_unreachable("Unknown OverheadType");
}

StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) {
  return overheadTypeFunctionSuffix(overheadTypeEncoding(tp));
}

PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
  if (elemTp.isF64())
    return PrimaryType::kF64;
  if (elemTp.isF32())
    return PrimaryType::kF32;
  if (elemTp.isF16())
    return PrimaryType::kF16;
  if (elemTp.isBF16())
    return PrimaryType::kBF16;
  if (elemTp.isInteger(64))
    return PrimaryType::kI64;
  if (elemTp.isInteger(32))
    return PrimaryType::kI32;
  if (elemTp.isInteger(16))
    return PrimaryType::kI16;
  if (elemTp.isInteger(8))
    return PrimaryType::kI8;
  if (auto complexTp = dyn_cast<ComplexType>(elemTp)) {
    auto complexEltTp = complexTp.getElementType();
    if (complexEltTp.isF64())
      return PrimaryType::kC64;
    if (complexEltTp.isF32())
      return PrimaryType::kC32;
  }
  llvm_unreachable("Unknown primary type");
}

StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) {
  switch (pt) {
#define CASE(VNAME, V)                                                         \
  case PrimaryType::k##VNAME:                                                  \
    return #VNAME;
    MLIR_SPARSETENSOR_FOREVERY_V(CASE)
#undef CASE
  }
  llvm_unreachable("Unknown PrimaryType");
}

StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) {
  return primaryTypeFunctionSuffix(primaryTypeEncoding(elemTp));
}

//===----------------------------------------------------------------------===//
// Misc code generators.
//===----------------------------------------------------------------------===//

Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
                             Type dstTp) {
  const Type srcTp = value.getType();
  if (srcTp == dstTp)
    return value;

  // int <=> index
  if (isa<IndexType>(srcTp) || isa<IndexType>(dstTp))
    return builder.create<arith::IndexCastOp>(loc, dstTp, value);

  const auto srcIntTp = dyn_cast_or_null<IntegerType>(srcTp);
  const bool isUnsignedCast = srcIntTp ? srcIntTp.isUnsigned() : false;
  return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
}

Value sparse_tensor::genScalarToTensor(OpBuilder &builder, Location loc,
                                       Value elem, Type dstTp) {
  if (auto rtp = dyn_cast<RankedTensorType>(dstTp)) {
    // Scalars can only be converted to 0-ranked tensors.
    assert(rtp.getRank() == 0);
    elem = sparse_tensor::genCast(builder, loc, elem, rtp.getElementType());
    return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
  }
  return sparse_tensor::genCast(builder, loc, elem, dstTp);
}

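// Loads a value from memory and casts it to `index` type. Narrow overhead
// values are zero-extended (not sign-extended) first, since the sparse
// position/coordinate overhead types are unsigned.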
Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
                                  ValueRange s) {
  Value load = builder.create<memref::LoadOp>(loc, mem, s);
  if (!isa<IndexType>(load.getType())) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load);
    load =
        builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), load);
  }
  return load;
}

mlir::TypedAttr mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
  if (isa<FloatType>(tp))
    return builder.getFloatAttr(tp, 1.0);
  if (isa<IndexType>(tp))
    return builder.getIndexAttr(1);
  if (auto intTp = dyn_cast<IntegerType>(tp))
    return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1));
  if (isa<RankedTensorType, VectorType>(tp)) {
    auto shapedTp = cast<ShapedType>(tp);
    if (auto one = getOneAttr(builder, shapedTp.getElementType()))
      return DenseElementsAttr::get(shapedTp, one);
  }
  llvm_unreachable("Unsupported attribute type");
}

Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
                                        Value v) {
  Type tp = v.getType();
  Value zero = constantZero(builder, loc, tp);
  if (isa<FloatType>(tp))
    return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
                                         zero);
  if (tp.isIntOrIndex())
    return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
                                         zero);
  if (dyn_cast<ComplexType>(tp))
    return builder.create<complex::NotEqualOp>(loc, v, zero);
  llvm_unreachable("Non-numeric type");
}

void mlir::sparse_tensor::genReshapeDstShape(
    OpBuilder &builder, Location loc, SmallVectorImpl<Value> &dstShape,
    ArrayRef<Value> srcShape, ArrayRef<Size> staticDstShape,
    ArrayRef<ReassociationIndices> reassociation) {
  // Collapse shape.
  if (reassociation.size() < srcShape.size()) {
    unsigned start = 0;
    for (const auto &map : llvm::enumerate(reassociation)) {
      auto dstDim = constantIndex(builder, loc, 1);
      for (unsigned i = start; i < start + map.value().size(); i++) {
        dstDim = builder.create<arith::MulIOp>(loc, dstDim, srcShape[i]);
      }
      dstShape.push_back(dstDim);
      start = start + map.value().size();
    }
    assert(start == srcShape.size());
    return;
  }

  // Expand shape.
  assert(reassociation.size() == srcShape.size());
  unsigned start = 0;
  // Expand the i-th dimension in srcShape.
  for (unsigned i = 0, size = srcShape.size(); i < size; i++) {
    const auto &map = reassociation[i];
    auto srcDim = srcShape[i];
    // Iterate through dimensions expanded from the i-th dimension.
    for (unsigned j = start; j < start + map.size(); j++) {
      // There can be only one dynamic sized dimension among dimensions
      // expanded from the i-th dimension in srcShape.
      // For example, if srcDim = 8, then the expanded shape could be <2x?x2>,
      // but not <2x?x?>.
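      // E.g., for srcDim = 8 expanded to <2x?x2>, the static product below is
      // 2 * 2 = 4, so the dynamic size is computed as 8 / 4 = 2.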
      if (staticDstShape[j] == ShapedType::kDynamic) {
        // The expanded dimension has dynamic size. We compute the dimension
        // by dividing srcDim by the product of the static dimensions.
        Size product = 1;
        for (unsigned k = start; k < start + map.size(); k++) {
          if (staticDstShape[k] != ShapedType::kDynamic) {
            product *= staticDstShape[k];
          }
        }
        // Compute the dynamic dimension size.
        Value productVal = constantIndex(builder, loc, product);
        Value dynamicSize =
            builder.create<arith::DivUIOp>(loc, srcDim, productVal);
        dstShape.push_back(dynamicSize);
      } else {
        // The expanded dimension is statically known.
        dstShape.push_back(constantIndex(builder, loc, staticDstShape[j]));
      }
    }
    start = start + map.size();
  }
  assert(start == staticDstShape.size());
}

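// Linearizes the source coordinates within each reassociation group and then
// delinearizes with div/rem. For example, collapsing coordinates (i, j) over
// sizes (m, n) yields i * n + j, and expansion recovers (i, j) by successive
// division and remainder.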
void mlir::sparse_tensor::reshapeCvs(
    OpBuilder &builder, Location loc,
    ArrayRef<ReassociationIndices> reassociation, // NOLINT
    ValueRange srcSizes, ValueRange srcCvs,       // NOLINT
    ValueRange dstSizes, SmallVectorImpl<Value> &dstCvs) {
  const unsigned srcRank = srcSizes.size();
  const unsigned dstRank = dstSizes.size();
  assert(srcRank == srcCvs.size() && "Source rank mismatch");
  const bool isCollapse = srcRank > dstRank;
  const ValueRange sizes = isCollapse ? srcSizes : dstSizes;
  // Iterate over reassociation map.
  unsigned i = 0;
  unsigned start = 0;
  for (const auto &map : llvm::enumerate(reassociation)) {
    // Prepare strides information in dimension slice.
    Value linear = constantIndex(builder, loc, 1);
    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
      linear = builder.create<arith::MulIOp>(loc, linear, sizes[j]);
    }
    // Start expansion.
    Value val;
    if (!isCollapse)
      val = srcCvs[i];
    // Iterate over dimension slice.
    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
      linear = builder.create<arith::DivUIOp>(loc, linear, sizes[j]);
      if (isCollapse) {
        const Value mul = builder.create<arith::MulIOp>(loc, srcCvs[j], linear);
        val = val ? builder.create<arith::AddIOp>(loc, val, mul) : mul;
      } else {
        const Value old = val;
        val = builder.create<arith::DivUIOp>(loc, val, linear);
        assert(dstCvs.size() == j);
        dstCvs.push_back(val);
        val = builder.create<arith::RemUIOp>(loc, old, linear);
      }
    }
    // Finalize collapse.
    if (isCollapse) {
      assert(dstCvs.size() == i);
      dstCvs.push_back(val);
    }
    start += map.value().size();
    i++;
  }
  assert(dstCvs.size() == dstRank);
}

FlatSymbolRefAttr mlir::sparse_tensor::getFunc(ModuleOp module, StringRef name,
                                               TypeRange resultType,
                                               ValueRange operands,
                                               EmitCInterface emitCInterface) {
  MLIRContext *context = module.getContext();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<func::FuncOp>(
        module.getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (static_cast<bool>(emitCInterface))
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(context));
  }
  return result;
}

func::CallOp mlir::sparse_tensor::createFuncCall(
    OpBuilder &builder, Location loc, StringRef name, TypeRange resultType,
    ValueRange operands, EmitCInterface emitCInterface) {
  auto module = builder.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
  FlatSymbolRefAttr fn =
      getFunc(module, name, resultType, operands, emitCInterface);
  return builder.create<func::CallOp>(loc, resultType, fn, operands);
}

Type mlir::sparse_tensor::getOpaquePointerType(MLIRContext *ctx) {
  return LLVM::LLVMPointerType::get(ctx);
}

Type mlir::sparse_tensor::getOpaquePointerType(Builder &builder) {
  return getOpaquePointerType(builder.getContext());
}

Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc,
                                     unsigned sz, Type tp, bool staticShape) {
  if (staticShape) {
    auto memTp = MemRefType::get({sz}, tp);
    return builder.create<memref::AllocaOp>(loc, memTp);
  }
  return genAlloca(builder, loc, constantIndex(builder, loc, sz), tp);
}

Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc, Value sz,
                                     Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
  return builder.create<memref::AllocaOp>(loc, memTp, ValueRange{sz});
}

Value mlir::sparse_tensor::genAllocaScalar(OpBuilder &builder, Location loc,
                                           Type tp) {
  return builder.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
}

Value mlir::sparse_tensor::allocaBuffer(OpBuilder &builder, Location loc,
                                        ValueRange values) {
  const unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(builder, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(builder, loc, i);
    builder.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}

Value mlir::sparse_tensor::allocDenseTensor(OpBuilder &builder, Location loc,
                                            RankedTensorType tensorTp,
                                            ValueRange sizes) {
  Type elemTp = tensorTp.getElementType();
  auto shape = tensorTp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
    if (shape[i] == ShapedType::kDynamic)
      dynamicSizes.push_back(sizes[i]);
  }
  Value mem = builder.create<memref::AllocOp>(loc, memTp, dynamicSizes);
  Value zero = constantZero(builder, loc, elemTp);
  builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{mem});
  return mem;
}

void mlir::sparse_tensor::deallocDenseTensor(OpBuilder &builder, Location loc,
                                             Value buffer) {
  builder.create<memref::DeallocOp>(loc, buffer);
}

void mlir::sparse_tensor::sizesFromSrc(OpBuilder &builder,
                                       SmallVectorImpl<Value> &sizes,
                                       Location loc, Value src) {
  const Dimension dimRank = getSparseTensorType(src).getDimRank();
  for (Dimension d = 0; d < dimRank; d++)
    sizes.push_back(linalg::createOrFoldDimOp(builder, loc, src, d));
}

Operation *mlir::sparse_tensor::getTop(Operation *op) {
  for (; isa<scf::ForOp>(op->getParentOp()) ||
         isa<scf::WhileOp>(op->getParentOp()) ||
         isa<scf::ParallelOp>(op->getParentOp()) ||
         isa<scf::IfOp>(op->getParentOp());
       op = op->getParentOp())
    ;
  return op;
}

void sparse_tensor::foreachInSparseConstant(
    OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order,
    function_ref<void(ArrayRef<Value>, Value)> callback) {
  if (!order)
    order = builder.getMultiDimIdentityMap(attr.getType().getRank());

  auto stt = SparseTensorType(getRankedTensorType(attr));
  const Dimension dimRank = stt.getDimRank();
  const auto coordinates = attr.getIndices().getValues<IntegerAttr>();
  const auto values = attr.getValues().getValues<Attribute>();

  // This is like the `Element<V>` class in the runtime library, but for
  // MLIR attributes. In the future we may want to move this out into
  // a proper class definition to help improve code legibility (e.g.,
  // `first` -> `coords`, `second` -> `value`) as well as being able
  // to factor out analogues of `ElementLT<V>` for the sort below, etc.
  using ElementAttr = std::pair<SmallVector<IntegerAttr>, Attribute>;

  // Construct the COO from the SparseElementsAttr.
  SmallVector<ElementAttr> elems;
  for (size_t i = 0, nse = values.size(); i < nse; i++) {
    elems.emplace_back();
    elems.back().second = values[i];
    auto &coords = elems.back().first;
    coords.reserve(dimRank);
    for (Dimension d = 0; d < dimRank; d++)
      coords.push_back(coordinates[i * dimRank + d]);
  }

  // Sorts the sparse element attribute based on coordinates.
  std::sort(elems.begin(), elems.end(),
            [order](const ElementAttr &lhs, const ElementAttr &rhs) {
              if (std::addressof(lhs) == std::addressof(rhs))
                return false;

              auto lhsCoords = llvm::map_to_vector(
                  lhs.first, [](IntegerAttr i) { return i.getInt(); });
              auto rhsCoords = llvm::map_to_vector(
                  rhs.first, [](IntegerAttr i) { return i.getInt(); });

              SmallVector<int64_t, 4> lhsLvlCrds = order.compose(lhsCoords);
              SmallVector<int64_t, 4> rhsLvlCrds = order.compose(rhsCoords);
              // Sort the element based on the lvl coordinates.
              for (Level l = 0; l < order.getNumResults(); l++) {
                if (lhsLvlCrds[l] == rhsLvlCrds[l])
                  continue;
                return lhsLvlCrds[l] < rhsLvlCrds[l];
              }
              llvm_unreachable("no equal coordinate in sparse element attr");
            });

  SmallVector<Value> cvs;
  cvs.reserve(dimRank);
  for (size_t i = 0, nse = values.size(); i < nse; i++) {
    // Remap coordinates.
    cvs.clear();
    for (Dimension d = 0; d < dimRank; d++) {
      auto crd = elems[i].first[d].getInt();
      cvs.push_back(builder.create<arith::ConstantIndexOp>(loc, crd));
    }
    // Remap value.
    Value val;
    if (isa<ComplexType>(attr.getElementType())) {
      auto valAttr = cast<ArrayAttr>(elems[i].second);
      val = builder.create<complex::ConstantOp>(loc, attr.getElementType(),
                                                valAttr);
    } else {
      auto valAttr = cast<TypedAttr>(elems[i].second);
      val = builder.create<arith::ConstantOp>(loc, valAttr);
    }
    assert(val);
    callback(cvs, val);
  }
}

SmallVector<Value> sparse_tensor::loadAll(OpBuilder &builder, Location loc,
                                          size_t size, Value mem,
                                          size_t offsetIdx, Value offsetVal) {
#ifndef NDEBUG
  const auto memTp = cast<MemRefType>(mem.getType());
  assert(memTp.getRank() == 1);
  const Size memSh = memTp.getDimSize(0);
  assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(size));
  assert(offsetIdx == 0 || offsetIdx < size);
#endif // NDEBUG
  SmallVector<Value> vs;
  vs.reserve(size);
  for (unsigned i = 0; i < size; i++) {
    Value v = builder.create<memref::LoadOp>(loc, mem,
                                             constantIndex(builder, loc, i));
    if (i == offsetIdx && offsetVal)
      v = builder.create<arith::AddIOp>(loc, v, offsetVal);
    vs.push_back(v);
  }
  return vs;
}

void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem,
                             ValueRange vs, size_t offsetIdx, Value offsetVal) {
#ifndef NDEBUG
  const size_t vsize = vs.size();
  const auto memTp = cast<MemRefType>(mem.getType());
  assert(memTp.getRank() == 1);
  const Size memSh = memTp.getDimSize(0);
  assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(vsize));
  assert(offsetIdx == 0 || offsetIdx < vsize);
#endif // NDEBUG
  for (const auto &v : llvm::enumerate(vs)) {
    const Value w =
        (offsetIdx == v.index() && offsetVal)
            ? builder.create<arith::AddIOp>(loc, v.value(), offsetVal)
            : v.value();
    builder.create<memref::StoreOp>(loc, w, mem,
                                    constantIndex(builder, loc, v.index()));
  }
}

TypedValue<BaseMemRefType>
sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
  auto tTp = llvm::cast<TensorType>(tensor.getType());
  auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType());
  return builder.create<bufferization::ToMemrefOp>(loc, mTp, tensor)
      .getResult();
}

Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
                                   Value tensor) {
  return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc);
}

Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
                                               Value tensor, Dimension dim) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  assert(enc && enc.isSlice());
  std::optional<unsigned> offset = enc.getStaticDimSliceOffset(dim);
  if (offset.has_value())
    return constantIndex(builder, loc, *offset);
  return builder.create<ToSliceOffsetOp>(loc, tensor, APInt(64, dim));
}

Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder &builder, Location loc,
                                               Value tensor, Dimension dim) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  assert(enc && enc.isSlice());
  std::optional<unsigned> stride = enc.getStaticDimSliceStride(dim);
  if (stride.has_value())
    return constantIndex(builder, loc, *stride);
  return builder.create<ToSliceStrideOp>(loc, tensor, APInt(64, dim));
}

Value sparse_tensor::genReader(OpBuilder &builder, Location loc,
                               SparseTensorType stt, Value tensor,
                               /*out*/ SmallVectorImpl<Value> &dimSizesValues,
                               /*out*/ Value &dimSizesBuffer) {
  // Construct the dimension **shapes** buffer. The buffer contains the static
  // size per dimension, or otherwise a zero for a dynamic size.
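  // For example, a tensor<?x32xf32> input yields the buffer [0, 32]; the
  // actual size of the dynamic dimension is obtained from the reader below.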
  Dimension dimRank = stt.getDimRank();
  dimSizesValues.clear();
  dimSizesValues.reserve(dimRank);
  for (const Size sz : stt.getDimShape()) {
    const auto s = ShapedType::isDynamic(sz) ? 0 : sz;
    dimSizesValues.push_back(constantIndex(builder, loc, s));
  }
  Value dimShapesBuffer = allocaBuffer(builder, loc, dimSizesValues);
  // Create the `CheckedSparseTensorReader`. This reader performs a
  // consistency check on the static sizes, but accepts any size
  // of each dimension with a dynamic size.
  Type opaqueTp = getOpaquePointerType(builder);
  Type eltTp = stt.getElementType();
  Value valTp = constantPrimaryTypeEncoding(builder, loc, eltTp);
  Value reader =
      createFuncCall(builder, loc, "createCheckedSparseTensorReader", opaqueTp,
                     {tensor, dimShapesBuffer, valTp}, EmitCInterface::On)
          .getResult(0);
  // For static shapes, the shape buffer can be used right away. For dynamic
  // shapes, use the information from the reader to construct a buffer that
  // supplies the actual size for each dynamic dimension.
  dimSizesBuffer = dimShapesBuffer;
  if (stt.hasDynamicDimShape()) {
    Type indexTp = builder.getIndexType();
    auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp);
    dimSizesBuffer =
        createFuncCall(builder, loc, "getSparseTensorReaderDimSizes", memTp,
                       reader, EmitCInterface::On)
            .getResult(0);
    // Also convert the dim shapes values into dim sizes values, just in case
    // subsequent clients need the values (DCE will remove unused).
    for (Dimension d = 0; d < dimRank; d++) {
      if (stt.isDynamicDim(d))
        dimSizesValues[d] = builder.create<memref::LoadOp>(
            loc, dimSizesBuffer, constantIndex(builder, loc, d));
    }
  }
  return reader;
}

Value sparse_tensor::genMapBuffers(
    OpBuilder &builder, Location loc, SparseTensorType stt,
    ArrayRef<Value> dimSizesValues, Value dimSizesBuffer,
    /*out*/ SmallVectorImpl<Value> &lvlSizesValues,
    /*out*/ Value &dim2lvlBuffer,
    /*out*/ Value &lvl2dimBuffer) {
  const Dimension dimRank = stt.getDimRank();
  const Level lvlRank = stt.getLvlRank();
  lvlSizesValues.clear();
  lvlSizesValues.reserve(lvlRank);
  // For an identity mapping, the dim2lvl and lvl2dim mappings are
  // identical as are dimSizes and lvlSizes, so buffers are reused
  // as much as possible.
  if (stt.isIdentity()) {
    assert(dimRank == lvlRank);
    SmallVector<Value> iotaValues;
    iotaValues.reserve(lvlRank);
    for (Level l = 0; l < lvlRank; l++) {
      iotaValues.push_back(constantIndex(builder, loc, l));
      lvlSizesValues.push_back(dimSizesValues[l]);
    }
    dim2lvlBuffer = lvl2dimBuffer = allocaBuffer(builder, loc, iotaValues);
    return dimSizesBuffer; // now lvlSizesBuffer
  }
  // Otherwise, some code needs to be generated to set up the buffers.
  // This code deals with permutations as well as non-permutations that
  // arise from rank changing blocking.
  const auto dimToLvl = stt.getDimToLvl();
  const auto lvlToDim = stt.getLvlToDim();
  SmallVector<Value> dim2lvlValues(lvlRank); // for each lvl, expr in dim vars
  SmallVector<Value> lvl2dimValues(dimRank); // for each dim, expr in lvl vars
  // Generate dim2lvl.
  assert(lvlRank == dimToLvl.getNumResults());
  for (Level l = 0; l < lvlRank; l++) {
    AffineExpr exp = dimToLvl.getResult(l);
    // We expect:
    //   (1) l = d
    //   (2) l = d / c
    //   (3) l = d % c
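    // For example, a 2x2 block format with dimToLvl
    //   (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2)
    // hits case (2) for the two outer levels and case (3) for the inner ones.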
    Dimension d = 0;
    uint64_t cf = 0, cm = 0;
    switch (exp.getKind()) {
    case AffineExprKind::DimId: {
      d = cast<AffineDimExpr>(exp).getPosition();
      break;
    }
    case AffineExprKind::FloorDiv: {
      auto floor = cast<AffineBinaryOpExpr>(exp);
      d = cast<AffineDimExpr>(floor.getLHS()).getPosition();
      cf = cast<AffineConstantExpr>(floor.getRHS()).getValue();
      break;
    }
    case AffineExprKind::Mod: {
      auto mod = cast<AffineBinaryOpExpr>(exp);
      d = cast<AffineDimExpr>(mod.getLHS()).getPosition();
      cm = cast<AffineConstantExpr>(mod.getRHS()).getValue();
      break;
    }
    default:
      llvm::report_fatal_error("unsupported dim2lvl in sparse tensor type");
    }
    dim2lvlValues[l] = constantIndex(builder, loc, encodeDim(d, cf, cm));
    // Compute the level sizes.
    //   (1) l = d : size(d)
    //   (2) l = d / c : size(d) / c
    //   (3) l = d % c : c
    Value lvlSz;
    if (cm == 0) {
      lvlSz = dimSizesValues[d];
      if (cf != 0)
        lvlSz = builder.create<arith::DivUIOp>(loc, lvlSz,
                                               constantIndex(builder, loc, cf));
    } else {
      lvlSz = constantIndex(builder, loc, cm);
    }
    lvlSizesValues.push_back(lvlSz);
  }
  // Generate lvl2dim.
  assert(dimRank == lvlToDim.getNumResults());
  for (Dimension d = 0; d < dimRank; d++) {
    AffineExpr exp = lvlToDim.getResult(d);
    // We expect:
    //   (1) d = l
    //   (2) d = l' * c + l
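    // For example, the inverse of the 2x2 block mapping above is
    //   (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 2 + l3),
    // which matches case (2) with c = 2.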
    Level l = 0, ll = 0;
    uint64_t c = 0;
    switch (exp.getKind()) {
    case AffineExprKind::DimId: {
      l = cast<AffineDimExpr>(exp).getPosition();
      break;
    }
    case AffineExprKind::Add: {
      // Always mul on lhs, symbol/constant on rhs.
      auto add = cast<AffineBinaryOpExpr>(exp);
      assert(add.getLHS().getKind() == AffineExprKind::Mul);
      auto mul = cast<AffineBinaryOpExpr>(add.getLHS());
      ll = cast<AffineDimExpr>(mul.getLHS()).getPosition();
      c = cast<AffineConstantExpr>(mul.getRHS()).getValue();
      l = cast<AffineDimExpr>(add.getRHS()).getPosition();
      break;
    }
    default:
      llvm::report_fatal_error("unsupported lvl2dim in sparse tensor type");
    }
    lvl2dimValues[d] = constantIndex(builder, loc, encodeLvl(l, c, ll));
  }
  // Return buffers.
  dim2lvlBuffer = allocaBuffer(builder, loc, dim2lvlValues);
  lvl2dimBuffer = allocaBuffer(builder, loc, lvl2dimValues);
  return allocaBuffer(builder, loc, lvlSizesValues); // lvlSizesBuffer
}