1//===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "CodegenUtils.h"
10
11#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12#include "mlir/Dialect/Linalg/IR/Linalg.h"
13#include "mlir/Dialect/Linalg/Utils/Utils.h"
14#include "mlir/Dialect/MemRef/IR/MemRef.h"
15#include "mlir/Dialect/Tensor/IR/Tensor.h"
16#include "mlir/IR/Types.h"
17#include "mlir/IR/Value.h"
18#include <optional>
19
20using namespace mlir;
21using namespace mlir::sparse_tensor;
22
23//===----------------------------------------------------------------------===//
24// ExecutionEngine/SparseTensorUtils helper functions.
25//===----------------------------------------------------------------------===//
26
27OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) {
28 switch (width) {
29 case 64:
30 return OverheadType::kU64;
31 case 32:
32 return OverheadType::kU32;
33 case 16:
34 return OverheadType::kU16;
35 case 8:
36 return OverheadType::kU8;
37 case 0:
38 return OverheadType::kIndex;
39 }
40 llvm_unreachable("Unsupported overhead bitwidth");
41}
42
43OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) {
44 if (tp.isIndex())
45 return OverheadType::kIndex;
46 if (auto intTp = dyn_cast<IntegerType>(Val&: tp))
47 return overheadTypeEncoding(width: intTp.getWidth());
48 llvm_unreachable("Unknown overhead type");
49}
50
51Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
52 switch (ot) {
53 case OverheadType::kIndex:
54 return builder.getIndexType();
55 case OverheadType::kU64:
56 return builder.getIntegerType(width: 64);
57 case OverheadType::kU32:
58 return builder.getIntegerType(width: 32);
59 case OverheadType::kU16:
60 return builder.getIntegerType(width: 16);
61 case OverheadType::kU8:
62 return builder.getIntegerType(width: 8);
63 }
64 llvm_unreachable("Unknown OverheadType");
65}
66
67OverheadType
68mlir::sparse_tensor::posTypeEncoding(SparseTensorEncodingAttr enc) {
69 return overheadTypeEncoding(width: enc.getPosWidth());
70}
71
72OverheadType
73mlir::sparse_tensor::crdTypeEncoding(SparseTensorEncodingAttr enc) {
74 return overheadTypeEncoding(width: enc.getCrdWidth());
75}
76
// TODO: we ought to add some `static_assert` tests to ensure that the
// `STEA::get{Pos,Crd}Type` methods agree with `getOverheadType(builder,
// {pos,crd}TypeEncoding(enc))`.
80
81// TODO: Adjust the naming convention for the constructors of
82// `OverheadType` so we can use the `MLIR_SPARSETENSOR_FOREVERY_O` x-macro
83// here instead of `MLIR_SPARSETENSOR_FOREVERY_FIXED_O`; to further reduce
84// the possibility of typo bugs or things getting out of sync.
85StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) {
86 switch (ot) {
87 case OverheadType::kIndex:
88 return "0";
89#define CASE(ONAME, O) \
90 case OverheadType::kU##ONAME: \
91 return #ONAME;
92 MLIR_SPARSETENSOR_FOREVERY_FIXED_O(CASE)
93#undef CASE
94 }
95 llvm_unreachable("Unknown OverheadType");
96}
97
98StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) {
99 return overheadTypeFunctionSuffix(ot: overheadTypeEncoding(tp));
100}
101
102PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
103 if (elemTp.isF64())
104 return PrimaryType::kF64;
105 if (elemTp.isF32())
106 return PrimaryType::kF32;
107 if (elemTp.isF16())
108 return PrimaryType::kF16;
109 if (elemTp.isBF16())
110 return PrimaryType::kBF16;
111 if (elemTp.isInteger(width: 64))
112 return PrimaryType::kI64;
113 if (elemTp.isInteger(width: 32))
114 return PrimaryType::kI32;
115 if (elemTp.isInteger(width: 16))
116 return PrimaryType::kI16;
117 if (elemTp.isInteger(width: 8))
118 return PrimaryType::kI8;
119 if (auto complexTp = dyn_cast<ComplexType>(Val&: elemTp)) {
120 auto complexEltTp = complexTp.getElementType();
121 if (complexEltTp.isF64())
122 return PrimaryType::kC64;
123 if (complexEltTp.isF32())
124 return PrimaryType::kC32;
125 }
126 llvm_unreachable("Unknown primary type");
127}
128
129StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) {
130 switch (pt) {
131#define CASE(VNAME, V) \
132 case PrimaryType::k##VNAME: \
133 return #VNAME;
134 MLIR_SPARSETENSOR_FOREVERY_V(CASE)
135#undef CASE
136 }
137 llvm_unreachable("Unknown PrimaryType");
138}
139
140StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) {
141 return primaryTypeFunctionSuffix(pt: primaryTypeEncoding(elemTp));
142}
143
144//===----------------------------------------------------------------------===//
145// Misc code generators.
146//===----------------------------------------------------------------------===//
147
148Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
149 Type dstTp) {
150 const Type srcTp = value.getType();
151 if (srcTp == dstTp)
152 return value;
153
154 // int <=> index
155 if (isa<IndexType>(Val: srcTp) || isa<IndexType>(Val: dstTp))
156 return builder.create<arith::IndexCastOp>(location: loc, args&: dstTp, args&: value);
157
158 const auto srcIntTp = dyn_cast_or_null<IntegerType>(Val: srcTp);
159 const bool isUnsignedCast = srcIntTp ? srcIntTp.isUnsigned() : false;
160 return mlir::convertScalarToDtype(b&: builder, loc, operand: value, toType: dstTp, isUnsignedCast);
161}
162
163Value sparse_tensor::genScalarToTensor(OpBuilder &builder, Location loc,
164 Value elem, Type dstTp) {
165 if (auto rtp = dyn_cast<RankedTensorType>(Val&: dstTp)) {
166 // Scalars can only be converted to 0-ranked tensors.
167 assert(rtp.getRank() == 0);
168 elem = sparse_tensor::genCast(builder, loc, value: elem, dstTp: rtp.getElementType());
169 return builder.create<tensor::FromElementsOp>(location: loc, args&: rtp, args&: elem);
170 }
171 return sparse_tensor::genCast(builder, loc, value: elem, dstTp);
172}
173
174Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
175 ValueRange s) {
176 Value load = builder.create<memref::LoadOp>(location: loc, args&: mem, args&: s);
177 if (!isa<IndexType>(Val: load.getType())) {
178 if (load.getType().getIntOrFloatBitWidth() < 64)
179 load = builder.create<arith::ExtUIOp>(location: loc, args: builder.getI64Type(), args&: load);
180 load =
181 builder.create<arith::IndexCastOp>(location: loc, args: builder.getIndexType(), args&: load);
182 }
183 return load;
184}
185
186mlir::TypedAttr mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
187 if (isa<FloatType>(Val: tp))
188 return builder.getFloatAttr(type: tp, value: 1.0);
189 if (isa<IndexType>(Val: tp))
190 return builder.getIndexAttr(value: 1);
191 if (auto intTp = dyn_cast<IntegerType>(Val&: tp))
192 return builder.getIntegerAttr(type: tp, value: APInt(intTp.getWidth(), 1));
193 if (isa<RankedTensorType, VectorType>(Val: tp)) {
194 auto shapedTp = cast<ShapedType>(Val&: tp);
195 if (auto one = getOneAttr(builder, tp: shapedTp.getElementType()))
196 return DenseElementsAttr::get(type: shapedTp, values: one);
197 }
198 llvm_unreachable("Unsupported attribute type");
199}
200
201Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
202 Value v) {
203 Type tp = v.getType();
204 Value zero = constantZero(builder, loc, tp);
205 if (isa<FloatType>(Val: tp))
206 return builder.create<arith::CmpFOp>(location: loc, args: arith::CmpFPredicate::UNE, args&: v,
207 args&: zero);
208 if (tp.isIntOrIndex())
209 return builder.create<arith::CmpIOp>(location: loc, args: arith::CmpIPredicate::ne, args&: v,
210 args&: zero);
211 if (isa<ComplexType>(Val: tp))
212 return builder.create<complex::NotEqualOp>(location: loc, args&: v, args&: zero);
213 llvm_unreachable("Non-numeric type");
214}
215
216void mlir::sparse_tensor::genReshapeDstShape(
217 OpBuilder &builder, Location loc, SmallVectorImpl<Value> &dstShape,
218 ArrayRef<Value> srcShape, ArrayRef<Size> staticDstShape,
219 ArrayRef<ReassociationIndices> reassociation) {
220 // Collapse shape.
221 if (reassociation.size() < srcShape.size()) {
222 unsigned start = 0;
223 for (const auto &map : llvm::enumerate(First&: reassociation)) {
224 auto dstDim = constantIndex(builder, loc, i: 1);
225 for (unsigned i = start; i < start + map.value().size(); i++) {
226 dstDim = builder.create<arith::MulIOp>(location: loc, args&: dstDim, args: srcShape[i]);
227 }
228 dstShape.push_back(Elt: dstDim);
229 start = start + map.value().size();
230 }
231 assert(start == srcShape.size());
232 return;
233 }
234
235 // Expand shape.
236 assert(reassociation.size() == srcShape.size());
237 unsigned start = 0;
238 // Expand the i-th dimension in srcShape.
239 for (unsigned i = 0, size = srcShape.size(); i < size; i++) {
240 const auto &map = reassociation[i];
241 auto srcDim = srcShape[i];
242 // Iterate through dimensions expanded from the i-th dimension.
243 for (unsigned j = start; j < start + map.size(); j++) {
244 // There can be only one dynamic sized dimension among dimensions
245 // expanded from the i-th dimension in srcShape.
246 // For example, if srcDim = 8, then the expanded shape could be <2x?x2>,
247 // but not <2x?x?>.
248 if (staticDstShape[j] == ShapedType::kDynamic) {
249 // The expanded dimension has dynamic size. We compute the dimension
250 // by dividing srcDim by the product of the static dimensions.
251 Size product = 1;
252 for (unsigned k = start; k < start + map.size(); k++) {
253 if (staticDstShape[k] != ShapedType::kDynamic) {
254 product *= staticDstShape[k];
255 }
256 }
257 // Compute the dynamic dimension size.
258 Value productVal = constantIndex(builder, loc, i: product);
259 Value dynamicSize =
260 builder.create<arith::DivUIOp>(location: loc, args&: srcDim, args&: productVal);
261 dstShape.push_back(Elt: dynamicSize);
262 } else {
263 // The expanded dimension is statically known.
264 dstShape.push_back(Elt: constantIndex(builder, loc, i: staticDstShape[j]));
265 }
266 }
267 start = start + map.size();
268 }
269 assert(start == staticDstShape.size());
270}
271
272void mlir::sparse_tensor::reshapeCvs(
273 OpBuilder &builder, Location loc,
274 ArrayRef<ReassociationIndices> reassociation, // NOLINT
275 ValueRange srcSizes, ValueRange srcCvs, // NOLINT
276 ValueRange dstSizes, SmallVectorImpl<Value> &dstCvs) {
277 const unsigned srcRank = srcSizes.size();
278 const unsigned dstRank = dstSizes.size();
279 assert(srcRank == srcCvs.size() && "Source rank mismatch");
280 const bool isCollapse = srcRank > dstRank;
281 const ValueRange sizes = isCollapse ? srcSizes : dstSizes;
282 // Iterate over reassociation map.
283 unsigned i = 0;
284 unsigned start = 0;
285 for (const auto &map : llvm::enumerate(First&: reassociation)) {
286 // Prepare strides information in dimension slice.
287 Value linear = constantIndex(builder, loc, i: 1);
288 for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
289 linear = builder.create<arith::MulIOp>(location: loc, args&: linear, args: sizes[j]);
290 }
291 // Start expansion.
292 Value val;
293 if (!isCollapse)
294 val = srcCvs[i];
295 // Iterate over dimension slice.
296 for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
297 linear = builder.create<arith::DivUIOp>(location: loc, args&: linear, args: sizes[j]);
298 if (isCollapse) {
299 const Value mul = builder.create<arith::MulIOp>(location: loc, args: srcCvs[j], args&: linear);
300 val = val ? builder.create<arith::AddIOp>(location: loc, args&: val, args: mul) : mul;
301 } else {
302 const Value old = val;
303 val = builder.create<arith::DivUIOp>(location: loc, args&: val, args&: linear);
304 assert(dstCvs.size() == j);
305 dstCvs.push_back(Elt: val);
306 val = builder.create<arith::RemUIOp>(location: loc, args: old, args&: linear);
307 }
308 }
309 // Finalize collapse.
310 if (isCollapse) {
311 assert(dstCvs.size() == i);
312 dstCvs.push_back(Elt: val);
313 }
314 start += map.value().size();
315 i++;
316 }
317 assert(dstCvs.size() == dstRank);
318}
319
320FlatSymbolRefAttr mlir::sparse_tensor::getFunc(ModuleOp module, StringRef name,
321 TypeRange resultType,
322 ValueRange operands,
323 EmitCInterface emitCInterface) {
324 MLIRContext *context = module.getContext();
325 auto result = SymbolRefAttr::get(ctx: context, value: name);
326 auto func = module.lookupSymbol<func::FuncOp>(name: result.getAttr());
327 if (!func) {
328 OpBuilder moduleBuilder(module.getBodyRegion());
329 func = moduleBuilder.create<func::FuncOp>(
330 location: module.getLoc(), args&: name,
331 args: FunctionType::get(context, inputs: operands.getTypes(), results: resultType));
332 func.setPrivate();
333 if (static_cast<bool>(emitCInterface))
334 func->setAttr(name: LLVM::LLVMDialect::getEmitCWrapperAttrName(),
335 value: UnitAttr::get(context));
336 }
337 return result;
338}
339
340func::CallOp mlir::sparse_tensor::createFuncCall(
341 OpBuilder &builder, Location loc, StringRef name, TypeRange resultType,
342 ValueRange operands, EmitCInterface emitCInterface) {
343 auto module = builder.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
344 FlatSymbolRefAttr fn =
345 getFunc(module, name, resultType, operands, emitCInterface);
346 return builder.create<func::CallOp>(location: loc, args&: resultType, args&: fn, args&: operands);
347}
348
349Type mlir::sparse_tensor::getOpaquePointerType(MLIRContext *ctx) {
350 return LLVM::LLVMPointerType::get(context: ctx);
351}
352
353Type mlir::sparse_tensor::getOpaquePointerType(Builder &builder) {
354 return getOpaquePointerType(ctx: builder.getContext());
355}
356
357Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc,
358 unsigned sz, Type tp, bool staticShape) {
359 if (staticShape) {
360 auto memTp = MemRefType::get(shape: {sz}, elementType: tp);
361 return builder.create<memref::AllocaOp>(location: loc, args&: memTp);
362 }
363 return genAlloca(builder, loc, sz: constantIndex(builder, loc, i: sz), tp);
364}
365
366Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc, Value sz,
367 Type tp) {
368 auto memTp = MemRefType::get(shape: {ShapedType::kDynamic}, elementType: tp);
369 return builder.create<memref::AllocaOp>(location: loc, args&: memTp, args: ValueRange{sz});
370}
371
372Value mlir::sparse_tensor::genAllocaScalar(OpBuilder &builder, Location loc,
373 Type tp) {
374 return builder.create<memref::AllocaOp>(location: loc, args: MemRefType::get(shape: {}, elementType: tp));
375}
376
377Value mlir::sparse_tensor::allocaBuffer(OpBuilder &builder, Location loc,
378 ValueRange values) {
379 const unsigned sz = values.size();
380 assert(sz >= 1);
381 Value buffer = genAlloca(builder, loc, sz, tp: values[0].getType());
382 for (unsigned i = 0; i < sz; i++) {
383 Value idx = constantIndex(builder, loc, i);
384 builder.create<memref::StoreOp>(location: loc, args: values[i], args&: buffer, args&: idx);
385 }
386 return buffer;
387}
388
389Value mlir::sparse_tensor::allocDenseTensor(OpBuilder &builder, Location loc,
390 RankedTensorType tensorTp,
391 ValueRange sizes) {
392 Type elemTp = tensorTp.getElementType();
393 auto shape = tensorTp.getShape();
394 auto memTp = MemRefType::get(shape, elementType: elemTp);
395 SmallVector<Value> dynamicSizes;
396 for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
397 if (shape[i] == ShapedType::kDynamic)
398 dynamicSizes.push_back(Elt: sizes[i]);
399 }
400 Value mem = builder.create<memref::AllocOp>(location: loc, args&: memTp, args&: dynamicSizes);
401 Value zero = constantZero(builder, loc, tp: elemTp);
402 builder.create<linalg::FillOp>(location: loc, args: ValueRange{zero}, args: ValueRange{mem});
403 return mem;
404}
405
406void mlir::sparse_tensor::deallocDenseTensor(OpBuilder &builder, Location loc,
407 Value buffer) {
408 builder.create<memref::DeallocOp>(location: loc, args&: buffer);
409}
410
411void mlir::sparse_tensor::sizesFromSrc(OpBuilder &builder,
412 SmallVectorImpl<Value> &sizes,
413 Location loc, Value src) {
414 const Dimension dimRank = getSparseTensorType(val: src).getDimRank();
415 for (Dimension d = 0; d < dimRank; d++)
416 sizes.push_back(Elt: linalg::createOrFoldDimOp(b&: builder, loc, val: src, dim: d));
417}
418
419Operation *mlir::sparse_tensor::getTop(Operation *op) {
420 for (; isa<scf::ForOp>(Val: op->getParentOp()) ||
421 isa<scf::WhileOp>(Val: op->getParentOp()) ||
422 isa<scf::ParallelOp>(Val: op->getParentOp()) ||
423 isa<scf::IfOp>(Val: op->getParentOp());
424 op = op->getParentOp())
425 ;
426 return op;
427}
428
429void sparse_tensor::foreachInSparseConstant(
430 OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order,
431 function_ref<void(ArrayRef<Value>, Value)> callback) {
432 if (!order)
433 order = builder.getMultiDimIdentityMap(rank: attr.getType().getRank());
434
435 auto stt = SparseTensorType(getRankedTensorType(t&: attr));
436 const Dimension dimRank = stt.getDimRank();
437 const auto coordinates = attr.getIndices().getValues<IntegerAttr>();
438 const auto values = attr.getValues().getValues<Attribute>();
439
440 // This is like the `Element<V>` class in the runtime library, but for
441 // MLIR attributes. In the future we may want to move this out into
442 // a proper class definition to help improve code legibility (e.g.,
443 // `first` -> `coords`, `second` -> `value`) as well as being able
444 // to factor out analogues of `ElementLT<V>` for the sort below, etc.
445 using ElementAttr = std::pair<SmallVector<IntegerAttr>, Attribute>;
446
447 // Construct the COO from the SparseElementsAttr.
448 SmallVector<ElementAttr> elems;
449 for (size_t i = 0, nse = values.size(); i < nse; i++) {
450 elems.emplace_back();
451 elems.back().second = values[i];
452 auto &coords = elems.back().first;
453 coords.reserve(N: dimRank);
454 for (Dimension d = 0; d < dimRank; d++)
455 coords.push_back(Elt: coordinates[i * dimRank + d]);
456 }
457
458 // Sorts the sparse element attribute based on coordinates.
459 llvm::sort(C&: elems, Comp: [order](const ElementAttr &lhs, const ElementAttr &rhs) {
460 if (std::addressof(r: lhs) == std::addressof(r: rhs))
461 return false;
462
463 auto lhsCoords = llvm::map_to_vector(
464 C: lhs.first, F: [](IntegerAttr i) { return i.getInt(); });
465 auto rhsCoords = llvm::map_to_vector(
466 C: rhs.first, F: [](IntegerAttr i) { return i.getInt(); });
467
468 SmallVector<int64_t, 4> lhsLvlCrds = order.compose(values: lhsCoords);
469 SmallVector<int64_t, 4> rhsLvlCrds = order.compose(values: rhsCoords);
470 // Sort the element based on the lvl coordinates.
471 for (Level l = 0; l < order.getNumResults(); l++) {
472 if (lhsLvlCrds[l] == rhsLvlCrds[l])
473 continue;
474 return lhsLvlCrds[l] < rhsLvlCrds[l];
475 }
476 llvm_unreachable("no equal coordinate in sparse element attr");
477 });
478
479 SmallVector<Value> cvs;
480 cvs.reserve(N: dimRank);
481 for (size_t i = 0, nse = values.size(); i < nse; i++) {
482 // Remap coordinates.
483 cvs.clear();
484 for (Dimension d = 0; d < dimRank; d++) {
485 auto crd = elems[i].first[d].getInt();
486 cvs.push_back(Elt: builder.create<arith::ConstantIndexOp>(location: loc, args&: crd));
487 }
488 // Remap value.
489 Value val;
490 if (isa<ComplexType>(Val: attr.getElementType())) {
491 auto valAttr = cast<ArrayAttr>(Val&: elems[i].second);
492 val = builder.create<complex::ConstantOp>(location: loc, args: attr.getElementType(),
493 args&: valAttr);
494 } else {
495 auto valAttr = cast<TypedAttr>(Val&: elems[i].second);
496 val = builder.create<arith::ConstantOp>(location: loc, args&: valAttr);
497 }
498 assert(val);
499 callback(cvs, val);
500 }
501}
502
503SmallVector<Value> sparse_tensor::loadAll(OpBuilder &builder, Location loc,
504 size_t size, Value mem,
505 size_t offsetIdx, Value offsetVal) {
506#ifndef NDEBUG
507 const auto memTp = cast<MemRefType>(mem.getType());
508 assert(memTp.getRank() == 1);
509 const Size memSh = memTp.getDimSize(0);
510 assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(size));
511 assert(offsetIdx == 0 || offsetIdx < size);
512#endif // NDEBUG
513 SmallVector<Value> vs;
514 vs.reserve(N: size);
515 for (unsigned i = 0; i < size; i++) {
516 Value v = builder.create<memref::LoadOp>(location: loc, args&: mem,
517 args: constantIndex(builder, loc, i));
518 if (i == offsetIdx && offsetVal)
519 v = builder.create<arith::AddIOp>(location: loc, args&: v, args&: offsetVal);
520 vs.push_back(Elt: v);
521 }
522 return vs;
523}
524
525void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem,
526 ValueRange vs, size_t offsetIdx, Value offsetVal) {
527#ifndef NDEBUG
528 const size_t vsize = vs.size();
529 const auto memTp = cast<MemRefType>(mem.getType());
530 assert(memTp.getRank() == 1);
531 const Size memSh = memTp.getDimSize(0);
532 assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(vsize));
533 assert(offsetIdx == 0 || offsetIdx < vsize);
534#endif // NDEBUG
535 for (const auto &v : llvm::enumerate(First&: vs)) {
536 const Value w =
537 (offsetIdx == v.index() && offsetVal)
538 ? builder.create<arith::AddIOp>(location: loc, args&: v.value(), args&: offsetVal)
539 : v.value();
540 builder.create<memref::StoreOp>(location: loc, args: w, args&: mem,
541 args: constantIndex(builder, loc, i: v.index()));
542 }
543}
544
545TypedValue<BaseMemRefType>
546sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
547 auto tTp = llvm::cast<TensorType>(Val: tensor.getType());
548 auto mTp = MemRefType::get(shape: tTp.getShape(), elementType: tTp.getElementType());
549 return cast<TypedValue<BaseMemRefType>>(
550 Val: builder.create<bufferization::ToBufferOp>(location: loc, args&: mTp, args&: tensor).getResult());
551}
552
553Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
554 Value tensor, Dimension dim) {
555 auto enc = getSparseTensorEncoding(type: tensor.getType());
556 assert(enc && enc.isSlice());
557 std::optional<unsigned> offset = enc.getStaticDimSliceOffset(dim);
558 if (offset.has_value())
559 return constantIndex(builder, loc, i: *offset);
560 return builder.create<ToSliceOffsetOp>(location: loc, args&: tensor, args: APInt(64, dim));
561}
562
563Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder &builder, Location loc,
564 Value tensor, Dimension dim) {
565 auto enc = getSparseTensorEncoding(type: tensor.getType());
566 assert(enc && enc.isSlice());
567 std::optional<unsigned> stride = enc.getStaticDimSliceStride(dim);
568 if (stride.has_value())
569 return constantIndex(builder, loc, i: *stride);
570 return builder.create<ToSliceStrideOp>(location: loc, args&: tensor, args: APInt(64, dim));
571}
572
573Value sparse_tensor::genReader(OpBuilder &builder, Location loc,
574 SparseTensorType stt, Value tensor,
575 /*out*/ SmallVectorImpl<Value> &dimSizesValues,
576 /*out*/ Value &dimSizesBuffer) {
577 // Construct the dimension **shapes** buffer. The buffer contains the static
578 // size per dimension, or otherwise a zero for a dynamic size.
579 Dimension dimRank = stt.getDimRank();
580 dimSizesValues.clear();
581 dimSizesValues.reserve(N: dimRank);
582 for (const Size sz : stt.getDimShape()) {
583 const auto s = ShapedType::isDynamic(dValue: sz) ? 0 : sz;
584 dimSizesValues.push_back(Elt: constantIndex(builder, loc, i: s));
585 }
586 Value dimShapesBuffer = allocaBuffer(builder, loc, values: dimSizesValues);
587 // Create the `CheckedSparseTensorReader`. This reader performs a
588 // consistency check on the static sizes, but accepts any size
589 // of each dimension with a dynamic size.
590 Type opaqueTp = getOpaquePointerType(builder);
591 Type eltTp = stt.getElementType();
592 Value valTp = constantPrimaryTypeEncoding(builder, loc, elemTp: eltTp);
593 Value reader =
594 createFuncCall(builder, loc, name: "createCheckedSparseTensorReader", resultType: opaqueTp,
595 operands: {tensor, dimShapesBuffer, valTp}, emitCInterface: EmitCInterface::On)
596 .getResult(i: 0);
597 // For static shapes, the shape buffer can be used right away. For dynamic
598 // shapes, use the information from the reader to construct a buffer that
599 // supplies the actual size for each dynamic dimension.
600 dimSizesBuffer = dimShapesBuffer;
601 if (stt.hasDynamicDimShape()) {
602 Type indexTp = builder.getIndexType();
603 auto memTp = MemRefType::get(shape: {ShapedType::kDynamic}, elementType: indexTp);
604 dimSizesBuffer =
605 createFuncCall(builder, loc, name: "getSparseTensorReaderDimSizes", resultType: memTp,
606 operands: reader, emitCInterface: EmitCInterface::On)
607 .getResult(i: 0);
608 // Also convert the dim shapes values into dim sizes values, just in case
609 // subsequent clients need the values (DCE will remove unused).
610 for (Dimension d = 0; d < dimRank; d++) {
611 if (stt.isDynamicDim(d))
612 dimSizesValues[d] = builder.create<memref::LoadOp>(
613 location: loc, args&: dimSizesBuffer, args: constantIndex(builder, loc, i: d));
614 }
615 }
616 return reader;
617}
618
619Value sparse_tensor::genMapBuffers(
620 OpBuilder &builder, Location loc, SparseTensorType stt,
621 ArrayRef<Value> dimSizesValues, Value dimSizesBuffer,
622 /*out*/ SmallVectorImpl<Value> &lvlSizesValues,
623 /*out*/ Value &dim2lvlBuffer,
624 /*out*/ Value &lvl2dimBuffer) {
625 const Dimension dimRank = stt.getDimRank();
626 const Level lvlRank = stt.getLvlRank();
627 lvlSizesValues.clear();
628 lvlSizesValues.reserve(N: lvlRank);
629 // For an identity mapping, the dim2lvl and lvl2dim mappings are
630 // identical as are dimSizes and lvlSizes, so buffers are reused
631 // as much as possible.
632 if (stt.isIdentity()) {
633 assert(dimRank == lvlRank);
634 SmallVector<Value> iotaValues;
635 iotaValues.reserve(N: lvlRank);
636 for (Level l = 0; l < lvlRank; l++) {
637 iotaValues.push_back(Elt: constantIndex(builder, loc, i: l));
638 lvlSizesValues.push_back(Elt: dimSizesValues[l]);
639 }
640 dim2lvlBuffer = lvl2dimBuffer = allocaBuffer(builder, loc, values: iotaValues);
641 return dimSizesBuffer; // now lvlSizesBuffer
642 }
643 // Otherwise, some code needs to be generated to set up the buffers.
644 // This code deals with permutations as well as non-permutations that
645 // arise from rank changing blocking.
646 const auto dimToLvl = stt.getDimToLvl();
647 const auto lvlToDim = stt.getLvlToDim();
648 SmallVector<Value> dim2lvlValues(lvlRank); // for each lvl, expr in dim vars
649 SmallVector<Value> lvl2dimValues(dimRank); // for each dim, expr in lvl vars
650 // Generate dim2lvl.
651 assert(lvlRank == dimToLvl.getNumResults());
652 for (Level l = 0; l < lvlRank; l++) {
653 AffineExpr exp = dimToLvl.getResult(idx: l);
654 // We expect:
655 // (1) l = d
656 // (2) l = d / c
657 // (3) l = d % c
658 Dimension d = 0;
659 uint64_t cf = 0, cm = 0;
660 switch (exp.getKind()) {
661 case AffineExprKind::DimId: {
662 d = cast<AffineDimExpr>(Val&: exp).getPosition();
663 break;
664 }
665 case AffineExprKind::FloorDiv: {
666 auto floor = cast<AffineBinaryOpExpr>(Val&: exp);
667 d = cast<AffineDimExpr>(Val: floor.getLHS()).getPosition();
668 cf = cast<AffineConstantExpr>(Val: floor.getRHS()).getValue();
669 break;
670 }
671 case AffineExprKind::Mod: {
672 auto mod = cast<AffineBinaryOpExpr>(Val&: exp);
673 d = cast<AffineDimExpr>(Val: mod.getLHS()).getPosition();
674 cm = cast<AffineConstantExpr>(Val: mod.getRHS()).getValue();
675 break;
676 }
677 default:
678 llvm::report_fatal_error(reason: "unsupported dim2lvl in sparse tensor type");
679 }
680 dim2lvlValues[l] = constantIndex(builder, loc, i: encodeDim(i: d, cf, cm));
681 // Compute the level sizes.
682 // (1) l = d : size(d)
683 // (2) l = d / c : size(d) / c
684 // (3) l = d % c : c
685 Value lvlSz;
686 if (cm == 0) {
687 lvlSz = dimSizesValues[d];
688 if (cf != 0)
689 lvlSz = builder.create<arith::DivUIOp>(location: loc, args&: lvlSz,
690 args: constantIndex(builder, loc, i: cf));
691 } else {
692 lvlSz = constantIndex(builder, loc, i: cm);
693 }
694 lvlSizesValues.push_back(Elt: lvlSz);
695 }
696 // Generate lvl2dim.
697 assert(dimRank == lvlToDim.getNumResults());
698 for (Dimension d = 0; d < dimRank; d++) {
699 AffineExpr exp = lvlToDim.getResult(idx: d);
700 // We expect:
701 // (1) d = l
702 // (2) d = l' * c + l
703 Level l = 0, ll = 0;
704 uint64_t c = 0;
705 switch (exp.getKind()) {
706 case AffineExprKind::DimId: {
707 l = cast<AffineDimExpr>(Val&: exp).getPosition();
708 break;
709 }
710 case AffineExprKind::Add: {
711 // Always mul on lhs, symbol/constant on rhs.
712 auto add = cast<AffineBinaryOpExpr>(Val&: exp);
713 assert(add.getLHS().getKind() == AffineExprKind::Mul);
714 auto mul = cast<AffineBinaryOpExpr>(Val: add.getLHS());
715 ll = cast<AffineDimExpr>(Val: mul.getLHS()).getPosition();
716 c = cast<AffineConstantExpr>(Val: mul.getRHS()).getValue();
717 l = cast<AffineDimExpr>(Val: add.getRHS()).getPosition();
718 break;
719 }
720 default:
721 llvm::report_fatal_error(reason: "unsupported lvl2dim in sparse tensor type");
722 }
723 lvl2dimValues[d] = constantIndex(builder, loc, i: encodeLvl(i: l, c, ii: ll));
724 }
725 // Return buffers.
726 dim2lvlBuffer = allocaBuffer(builder, loc, values: dim2lvlValues);
727 lvl2dimBuffer = allocaBuffer(builder, loc, values: lvl2dimValues);
728 return allocaBuffer(builder, loc, values: lvlSizesValues); // lvlSizesBuffer
729}
730

source code of mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp