//===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"
#include "SparseTensorDescriptor.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include <optional>

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// ExecutionEngine/SparseTensorUtils helper functions.
//===----------------------------------------------------------------------===//

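// Maps an overhead storage bitwidth to its internal type-encoding; a zero
// bitwidth selects the architecture-dependent `index` type.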
OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) {
  switch (width) {
  case 64:
    return OverheadType::kU64;
  case 32:
    return OverheadType::kU32;
  case 16:
    return OverheadType::kU16;
  case 8:
    return OverheadType::kU8;
  case 0:
    return OverheadType::kIndex;
  }
  llvm_unreachable("Unsupported overhead bitwidth");
}

OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) {
  if (tp.isIndex())
    return OverheadType::kIndex;
  if (auto intTp = dyn_cast<IntegerType>(tp))
    return overheadTypeEncoding(intTp.getWidth());
  llvm_unreachable("Unknown overhead type");
}

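// Converts the internal type-encoding for overhead storage to an mlir::Type.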
Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
  switch (ot) {
  case OverheadType::kIndex:
    return builder.getIndexType();
  case OverheadType::kU64:
    return builder.getIntegerType(64);
  case OverheadType::kU32:
    return builder.getIntegerType(32);
  case OverheadType::kU16:
    return builder.getIntegerType(16);
  case OverheadType::kU8:
    return builder.getIntegerType(8);
  }
  llvm_unreachable("Unknown OverheadType");
}

OverheadType
mlir::sparse_tensor::posTypeEncoding(SparseTensorEncodingAttr enc) {
  return overheadTypeEncoding(enc.getPosWidth());
}

OverheadType
mlir::sparse_tensor::crdTypeEncoding(SparseTensorEncodingAttr enc) {
  return overheadTypeEncoding(enc.getCrdWidth());
}

// TODO: we ought to add some `static_assert` tests to ensure that the
// `STEA::get{Pos,Crd}Type` methods agree with `getOverheadType(builder,
// {pos,crd}OverheadTypeEncoding(enc))`

// TODO: Adjust the naming convention for the constructors of
// `OverheadType` so we can use the `MLIR_SPARSETENSOR_FOREVERY_O` x-macro
// here instead of `MLIR_SPARSETENSOR_FOREVERY_FIXED_O`; to further reduce
// the possibility of typo bugs or things getting out of sync.
StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) {
  switch (ot) {
  case OverheadType::kIndex:
    return "0";
#define CASE(ONAME, O)                                                         \
  case OverheadType::kU##ONAME:                                                \
    return #ONAME;
  MLIR_SPARSETENSOR_FOREVERY_FIXED_O(CASE)
#undef CASE
  }
  llvm_unreachable("Unknown OverheadType");
}

StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) {
  return overheadTypeFunctionSuffix(overheadTypeEncoding(tp));
}

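// Converts a primary (value) element type to its internal type-encoding.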
PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
  if (elemTp.isF64())
    return PrimaryType::kF64;
  if (elemTp.isF32())
    return PrimaryType::kF32;
  if (elemTp.isF16())
    return PrimaryType::kF16;
  if (elemTp.isBF16())
    return PrimaryType::kBF16;
  if (elemTp.isInteger(64))
    return PrimaryType::kI64;
  if (elemTp.isInteger(32))
    return PrimaryType::kI32;
  if (elemTp.isInteger(16))
    return PrimaryType::kI16;
  if (elemTp.isInteger(8))
    return PrimaryType::kI8;
  if (auto complexTp = dyn_cast<ComplexType>(elemTp)) {
    auto complexEltTp = complexTp.getElementType();
    if (complexEltTp.isF64())
      return PrimaryType::kC64;
    if (complexEltTp.isF32())
      return PrimaryType::kC32;
  }
  llvm_unreachable("Unknown primary type");
}

StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) {
  switch (pt) {
#define CASE(VNAME, V)                                                         \
  case PrimaryType::k##VNAME:                                                  \
    return #VNAME;
  MLIR_SPARSETENSOR_FOREVERY_V(CASE)
#undef CASE
  }
  llvm_unreachable("Unknown PrimaryType");
}

StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) {
  return primaryTypeFunctionSuffix(primaryTypeEncoding(elemTp));
}

//===----------------------------------------------------------------------===//
// Misc code generators.
//===----------------------------------------------------------------------===//

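// Generates a cast between two scalar types: an index cast when either side
// is `index`, and otherwise a value conversion that honors the signedness of
// the source integer type.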
Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
                             Type dstTp) {
  const Type srcTp = value.getType();
  if (srcTp == dstTp)
    return value;

  // int <=> index
  if (isa<IndexType>(srcTp) || isa<IndexType>(dstTp))
    return builder.create<arith::IndexCastOp>(loc, dstTp, value);

  const auto srcIntTp = dyn_cast_or_null<IntegerType>(srcTp);
  const bool isUnsignedCast = srcIntTp ? srcIntTp.isUnsigned() : false;
  return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
}

Value sparse_tensor::genScalarToTensor(OpBuilder &builder, Location loc,
                                       Value elem, Type dstTp) {
  if (auto rtp = dyn_cast<RankedTensorType>(dstTp)) {
    // Scalars can only be converted to 0-ranked tensors.
    assert(rtp.getRank() == 0);
    elem = sparse_tensor::genCast(builder, loc, elem, rtp.getElementType());
    return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
  }
  return sparse_tensor::genCast(builder, loc, elem, dstTp);
}

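// Loads a value from memory and converts it to `index` type, zero-extending
// sub-64-bit integers first so that the index cast does not sign-extend.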
Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
                                  ValueRange s) {
  Value load = builder.create<memref::LoadOp>(loc, mem, s);
  if (!isa<IndexType>(load.getType())) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load);
    load =
        builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), load);
  }
  return load;
}

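// Returns an attribute encoding the value one for the given float, index,
// integer, tensor, or vector type.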
mlir::TypedAttr mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
  if (isa<FloatType>(tp))
    return builder.getFloatAttr(tp, 1.0);
  if (isa<IndexType>(tp))
    return builder.getIndexAttr(1);
  if (auto intTp = dyn_cast<IntegerType>(tp))
    return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1));
  if (isa<RankedTensorType, VectorType>(tp)) {
    auto shapedTp = cast<ShapedType>(tp);
    if (auto one = getOneAttr(builder, shapedTp.getElementType()))
      return DenseElementsAttr::get(shapedTp, one);
  }
  llvm_unreachable("Unsupported attribute type");
}

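// Generates a comparison of the given value against zero, using the proper
// predicate for float, integer/index, and complex types.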
Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
                                        Value v) {
  Type tp = v.getType();
  Value zero = constantZero(builder, loc, tp);
  if (isa<FloatType>(tp))
    return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
                                         zero);
  if (tp.isIntOrIndex())
    return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
                                         zero);
  if (isa<ComplexType>(tp))
    return builder.create<complex::NotEqualOp>(loc, v, zero);
  llvm_unreachable("Non-numeric type");
}

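// Computes the destination shape (as SSA values) of a reshape, handling both
// the collapse and the expand direction of the reassociation map.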
void mlir::sparse_tensor::genReshapeDstShape(
    OpBuilder &builder, Location loc, SmallVectorImpl<Value> &dstShape,
    ArrayRef<Value> srcShape, ArrayRef<Size> staticDstShape,
    ArrayRef<ReassociationIndices> reassociation) {
  // Collapse shape.
  if (reassociation.size() < srcShape.size()) {
    unsigned start = 0;
    for (const auto &map : llvm::enumerate(reassociation)) {
      auto dstDim = constantIndex(builder, loc, 1);
      for (unsigned i = start; i < start + map.value().size(); i++) {
        dstDim = builder.create<arith::MulIOp>(loc, dstDim, srcShape[i]);
      }
      dstShape.push_back(dstDim);
      start = start + map.value().size();
    }
    assert(start == srcShape.size());
    return;
  }

  // Expand shape.
  assert(reassociation.size() == srcShape.size());
  unsigned start = 0;
  // Expand the i-th dimension in srcShape.
  for (unsigned i = 0, size = srcShape.size(); i < size; i++) {
    const auto &map = reassociation[i];
    auto srcDim = srcShape[i];
    // Iterate through dimensions expanded from the i-th dimension.
    for (unsigned j = start; j < start + map.size(); j++) {
      // There can be only one dynamic sized dimension among dimensions
      // expanded from the i-th dimension in srcShape.
      // For example, if srcDim = 8, then the expanded shape could be <2x?x2>,
      // but not <2x?x?>.
      if (staticDstShape[j] == ShapedType::kDynamic) {
        // The expanded dimension has dynamic size. We compute the dimension
        // by dividing srcDim by the product of the static dimensions.
        Size product = 1;
        for (unsigned k = start; k < start + map.size(); k++) {
          if (staticDstShape[k] != ShapedType::kDynamic) {
            product *= staticDstShape[k];
          }
        }
        // Compute the dynamic dimension size.
        Value productVal = constantIndex(builder, loc, product);
        Value dynamicSize =
            builder.create<arith::DivUIOp>(loc, srcDim, productVal);
        dstShape.push_back(dynamicSize);
      } else {
        // The expanded dimension is statically known.
        dstShape.push_back(constantIndex(builder, loc, staticDstShape[j]));
      }
    }
    start = start + map.size();
  }
  assert(start == staticDstShape.size());
}

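// Remaps coordinate values through the reassociation of a reshape,
// linearizing the coordinates on a collapse and delinearizing them on an
// expansion.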
void mlir::sparse_tensor::reshapeCvs(
    OpBuilder &builder, Location loc,
    ArrayRef<ReassociationIndices> reassociation, // NOLINT
    ValueRange srcSizes, ValueRange srcCvs,       // NOLINT
    ValueRange dstSizes, SmallVectorImpl<Value> &dstCvs) {
  const unsigned srcRank = srcSizes.size();
  const unsigned dstRank = dstSizes.size();
  assert(srcRank == srcCvs.size() && "Source rank mismatch");
  const bool isCollapse = srcRank > dstRank;
  const ValueRange sizes = isCollapse ? srcSizes : dstSizes;
  // Iterate over reassociation map.
  unsigned i = 0;
  unsigned start = 0;
  for (const auto &map : llvm::enumerate(reassociation)) {
    // Prepare strides information in dimension slice.
    Value linear = constantIndex(builder, loc, 1);
    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
      linear = builder.create<arith::MulIOp>(loc, linear, sizes[j]);
    }
    // Start expansion.
    Value val;
    if (!isCollapse)
      val = srcCvs[i];
    // Iterate over dimension slice.
    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
      linear = builder.create<arith::DivUIOp>(loc, linear, sizes[j]);
      if (isCollapse) {
        const Value mul = builder.create<arith::MulIOp>(loc, srcCvs[j], linear);
        val = val ? builder.create<arith::AddIOp>(loc, val, mul) : mul;
      } else {
        const Value old = val;
        val = builder.create<arith::DivUIOp>(loc, val, linear);
        assert(dstCvs.size() == j);
        dstCvs.push_back(val);
        val = builder.create<arith::RemUIOp>(loc, old, linear);
      }
    }
    // Finalize collapse.
    if (isCollapse) {
      assert(dstCvs.size() == i);
      dstCvs.push_back(val);
    }
    start += map.value().size();
    i++;
  }
  assert(dstCvs.size() == dstRank);
}

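// Looks up the runtime support function with the given name, creating a
// private declaration (optionally with the C-interface attribute) when it
// does not exist yet.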
FlatSymbolRefAttr mlir::sparse_tensor::getFunc(ModuleOp module, StringRef name,
                                               TypeRange resultType,
                                               ValueRange operands,
                                               EmitCInterface emitCInterface) {
  MLIRContext *context = module.getContext();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<func::FuncOp>(
        module.getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
    if (static_cast<bool>(emitCInterface))
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(context));
  }
  return result;
}

func::CallOp mlir::sparse_tensor::createFuncCall(
    OpBuilder &builder, Location loc, StringRef name, TypeRange resultType,
    ValueRange operands, EmitCInterface emitCInterface) {
  auto module = builder.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
  FlatSymbolRefAttr fn =
      getFunc(module, name, resultType, operands, emitCInterface);
  return builder.create<func::CallOp>(loc, resultType, fn, operands);
}

Type mlir::sparse_tensor::getOpaquePointerType(MLIRContext *ctx) {
  return LLVM::LLVMPointerType::get(ctx);
}

Type mlir::sparse_tensor::getOpaquePointerType(Builder &builder) {
  return getOpaquePointerType(builder.getContext());
}

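// Generates an uninitialized stack allocation of `sz` elements of type `tp`,
// either with a static memref shape or with a dynamic size operand.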
Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc,
                                     unsigned sz, Type tp, bool staticShape) {
  if (staticShape) {
    auto memTp = MemRefType::get({sz}, tp);
    return builder.create<memref::AllocaOp>(loc, memTp);
  }
  return genAlloca(builder, loc, constantIndex(builder, loc, sz), tp);
}

Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc, Value sz,
                                     Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
  return builder.create<memref::AllocaOp>(loc, memTp, ValueRange{sz});
}

Value mlir::sparse_tensor::genAllocaScalar(OpBuilder &builder, Location loc,
                                           Type tp) {
  return builder.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
}

Value mlir::sparse_tensor::allocaBuffer(OpBuilder &builder, Location loc,
                                        ValueRange values) {
  const unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(builder, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(builder, loc, i);
    builder.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}

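// Allocates a dense buffer for the given ranked tensor type, using `sizes`
// for the dynamic dimensions, and fills it with zeros.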
Value mlir::sparse_tensor::allocDenseTensor(OpBuilder &builder, Location loc,
                                            RankedTensorType tensorTp,
                                            ValueRange sizes) {
  Type elemTp = tensorTp.getElementType();
  auto shape = tensorTp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
    if (shape[i] == ShapedType::kDynamic)
      dynamicSizes.push_back(sizes[i]);
  }
  Value mem = builder.create<memref::AllocOp>(loc, memTp, dynamicSizes);
  Value zero = constantZero(builder, loc, elemTp);
  builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{mem});
  return mem;
}

void mlir::sparse_tensor::deallocDenseTensor(OpBuilder &builder, Location loc,
                                             Value buffer) {
  builder.create<memref::DeallocOp>(loc, buffer);
}

void mlir::sparse_tensor::sizesFromSrc(OpBuilder &builder,
                                       SmallVectorImpl<Value> &sizes,
                                       Location loc, Value src) {
  const Dimension dimRank = getSparseTensorType(src).getDimRank();
  for (Dimension d = 0; d < dimRank; d++)
    sizes.push_back(linalg::createOrFoldDimOp(builder, loc, src, d));
}

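// Walks up through enclosing scf.for/while/parallel/if operations and returns
// the first ancestor (possibly `op` itself) whose parent is none of them.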
Operation *mlir::sparse_tensor::getTop(Operation *op) {
  for (; isa<scf::ForOp>(op->getParentOp()) ||
         isa<scf::WhileOp>(op->getParentOp()) ||
         isa<scf::ParallelOp>(op->getParentOp()) ||
         isa<scf::IfOp>(op->getParentOp());
       op = op->getParentOp())
    ;
  return op;
}

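// Invokes the callback for every element of a sparse constant, visiting the
// elements in the level-coordinate order induced by `order` (the identity
// permutation by default).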
void sparse_tensor::foreachInSparseConstant(
    OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order,
    function_ref<void(ArrayRef<Value>, Value)> callback) {
  if (!order)
    order = builder.getMultiDimIdentityMap(attr.getType().getRank());

  auto stt = SparseTensorType(getRankedTensorType(attr));
  const Dimension dimRank = stt.getDimRank();
  const auto coordinates = attr.getIndices().getValues<IntegerAttr>();
  const auto values = attr.getValues().getValues<Attribute>();

  // This is like the `Element<V>` class in the runtime library, but for
  // MLIR attributes. In the future we may want to move this out into
  // a proper class definition to help improve code legibility (e.g.,
  // `first` -> `coords`, `second` -> `value`) as well as being able
  // to factor out analogues of `ElementLT<V>` for the sort below, etc.
  using ElementAttr = std::pair<SmallVector<IntegerAttr>, Attribute>;

  // Construct the COO from the SparseElementsAttr.
  SmallVector<ElementAttr> elems;
  for (size_t i = 0, nse = values.size(); i < nse; i++) {
    elems.emplace_back();
    elems.back().second = values[i];
    auto &coords = elems.back().first;
    coords.reserve(dimRank);
    for (Dimension d = 0; d < dimRank; d++)
      coords.push_back(coordinates[i * dimRank + d]);
  }

  // Sorts the sparse element attribute based on coordinates.
  std::sort(elems.begin(), elems.end(),
            [order](const ElementAttr &lhs, const ElementAttr &rhs) {
              if (std::addressof(lhs) == std::addressof(rhs))
                return false;

              auto lhsCoords = llvm::map_to_vector(
                  lhs.first, [](IntegerAttr i) { return i.getInt(); });
              auto rhsCoords = llvm::map_to_vector(
                  rhs.first, [](IntegerAttr i) { return i.getInt(); });

              SmallVector<int64_t, 4> lhsLvlCrds = order.compose(lhsCoords);
              SmallVector<int64_t, 4> rhsLvlCrds = order.compose(rhsCoords);
              // Sort the element based on the lvl coordinates.
              for (Level l = 0; l < order.getNumResults(); l++) {
                if (lhsLvlCrds[l] == rhsLvlCrds[l])
                  continue;
                return lhsLvlCrds[l] < rhsLvlCrds[l];
              }
              llvm_unreachable("no equal coordinate in sparse element attr");
            });

  SmallVector<Value> cvs;
  cvs.reserve(dimRank);
  for (size_t i = 0, nse = values.size(); i < nse; i++) {
    // Remap coordinates.
    cvs.clear();
    for (Dimension d = 0; d < dimRank; d++) {
      auto crd = elems[i].first[d].getInt();
      cvs.push_back(builder.create<arith::ConstantIndexOp>(loc, crd));
    }
    // Remap value.
    Value val;
    if (isa<ComplexType>(attr.getElementType())) {
      auto valAttr = cast<ArrayAttr>(elems[i].second);
      val = builder.create<complex::ConstantOp>(loc, attr.getElementType(),
                                                valAttr);
    } else {
      auto valAttr = cast<TypedAttr>(elems[i].second);
      val = builder.create<arith::ConstantOp>(loc, valAttr);
    }
    assert(val);
    callback(cvs, val);
  }
}

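// Loads `size` scalars from a rank-1 memref, optionally adding `offsetVal`
// to the entry loaded at position `offsetIdx`.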
SmallVector<Value> sparse_tensor::loadAll(OpBuilder &builder, Location loc,
                                          size_t size, Value mem,
                                          size_t offsetIdx, Value offsetVal) {
#ifndef NDEBUG
  const auto memTp = cast<MemRefType>(mem.getType());
  assert(memTp.getRank() == 1);
  const Size memSh = memTp.getDimSize(0);
  assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(size));
  assert(offsetIdx == 0 || offsetIdx < size);
#endif // NDEBUG
  SmallVector<Value> vs;
  vs.reserve(size);
  for (unsigned i = 0; i < size; i++) {
    Value v = builder.create<memref::LoadOp>(loc, mem,
                                             constantIndex(builder, loc, i));
    if (i == offsetIdx && offsetVal)
      v = builder.create<arith::AddIOp>(loc, v, offsetVal);
    vs.push_back(v);
  }
  return vs;
}

void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem,
                             ValueRange vs, size_t offsetIdx, Value offsetVal) {
#ifndef NDEBUG
  const size_t vsize = vs.size();
  const auto memTp = cast<MemRefType>(mem.getType());
  assert(memTp.getRank() == 1);
  const Size memSh = memTp.getDimSize(0);
  assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(vsize));
  assert(offsetIdx == 0 || offsetIdx < vsize);
#endif // NDEBUG
  for (const auto &v : llvm::enumerate(vs)) {
    const Value w =
        (offsetIdx == v.index() && offsetVal)
            ? builder.create<arith::AddIOp>(loc, v.value(), offsetVal)
            : v.value();
    builder.create<memref::StoreOp>(loc, w, mem,
                                    constantIndex(builder, loc, v.index()));
  }
}

TypedValue<BaseMemRefType>
sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
  auto tTp = llvm::cast<TensorType>(tensor.getType());
  auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType());
  return builder.create<bufferization::ToMemrefOp>(loc, mTp, tensor)
      .getResult();
}

Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
                                   Value tensor) {
  return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc);
}

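// Returns the offset of a sparse tensor slice along `dim`, folding to a
// constant when the encoding provides a static offset (the stride variant
// below is analogous).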
Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
                                               Value tensor, Dimension dim) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  assert(enc && enc.isSlice());
  std::optional<unsigned> offset = enc.getStaticDimSliceOffset(dim);
  if (offset.has_value())
    return constantIndex(builder, loc, *offset);
  return builder.create<ToSliceOffsetOp>(loc, tensor, APInt(64, dim));
}

Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder &builder, Location loc,
                                               Value tensor, Dimension dim) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  assert(enc && enc.isSlice());
  std::optional<unsigned> stride = enc.getStaticDimSliceStride(dim);
  if (stride.has_value())
    return constantIndex(builder, loc, *stride);
  return builder.create<ToSliceStrideOp>(loc, tensor, APInt(64, dim));
}

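// Creates a `CheckedSparseTensorReader` for the given source and returns the
// reader handle; the dimension sizes are returned through the two output
// parameters, both as SSA values and as a buffer.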
Value sparse_tensor::genReader(OpBuilder &builder, Location loc,
                               SparseTensorType stt, Value tensor,
                               /*out*/ SmallVectorImpl<Value> &dimSizesValues,
                               /*out*/ Value &dimSizesBuffer) {
  // Construct the dimension **shapes** buffer. The buffer contains the static
  // size per dimension, or otherwise a zero for a dynamic size.
  Dimension dimRank = stt.getDimRank();
  dimSizesValues.clear();
  dimSizesValues.reserve(dimRank);
  for (const Size sz : stt.getDimShape()) {
    const auto s = ShapedType::isDynamic(sz) ? 0 : sz;
    dimSizesValues.push_back(constantIndex(builder, loc, s));
  }
  Value dimShapesBuffer = allocaBuffer(builder, loc, dimSizesValues);
  // Create the `CheckedSparseTensorReader`. This reader performs a
  // consistency check on the static sizes, but accepts any size
  // of each dimension with a dynamic size.
  Type opaqueTp = getOpaquePointerType(builder);
  Type eltTp = stt.getElementType();
  Value valTp = constantPrimaryTypeEncoding(builder, loc, eltTp);
  Value reader =
      createFuncCall(builder, loc, "createCheckedSparseTensorReader", opaqueTp,
                     {tensor, dimShapesBuffer, valTp}, EmitCInterface::On)
          .getResult(0);
  // For static shapes, the shape buffer can be used right away. For dynamic
  // shapes, use the information from the reader to construct a buffer that
  // supplies the actual size for each dynamic dimension.
  dimSizesBuffer = dimShapesBuffer;
  if (stt.hasDynamicDimShape()) {
    Type indexTp = builder.getIndexType();
    auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp);
    dimSizesBuffer =
        createFuncCall(builder, loc, "getSparseTensorReaderDimSizes", memTp,
                       reader, EmitCInterface::On)
            .getResult(0);
    // Also convert the dim shapes values into dim sizes values, just in case
    // subsequent clients need the values (DCE will remove unused).
    for (Dimension d = 0; d < dimRank; d++) {
      if (stt.isDynamicDim(d))
        dimSizesValues[d] = builder.create<memref::LoadOp>(
            loc, dimSizesBuffer, constantIndex(builder, loc, d));
    }
  }
  return reader;
}

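// Generates the dim2lvl and lvl2dim mapping buffers (and the level sizes)
// needed by the runtime support library, reusing the dimension buffers
// whenever the dim-to-lvl mapping is the identity.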
Value sparse_tensor::genMapBuffers(
    OpBuilder &builder, Location loc, SparseTensorType stt,
    ArrayRef<Value> dimSizesValues, Value dimSizesBuffer,
    /*out*/ SmallVectorImpl<Value> &lvlSizesValues,
    /*out*/ Value &dim2lvlBuffer,
    /*out*/ Value &lvl2dimBuffer) {
  const Dimension dimRank = stt.getDimRank();
  const Level lvlRank = stt.getLvlRank();
  lvlSizesValues.clear();
  lvlSizesValues.reserve(lvlRank);
  // For an identity mapping, the dim2lvl and lvl2dim mappings are
  // identical as are dimSizes and lvlSizes, so buffers are reused
  // as much as possible.
  if (stt.isIdentity()) {
    assert(dimRank == lvlRank);
    SmallVector<Value> iotaValues;
    iotaValues.reserve(lvlRank);
    for (Level l = 0; l < lvlRank; l++) {
      iotaValues.push_back(constantIndex(builder, loc, l));
      lvlSizesValues.push_back(dimSizesValues[l]);
    }
    dim2lvlBuffer = lvl2dimBuffer = allocaBuffer(builder, loc, iotaValues);
    return dimSizesBuffer; // now lvlSizesBuffer
  }
  // Otherwise, some code needs to be generated to set up the buffers.
  // This code deals with permutations as well as non-permutations that
  // arise from rank changing blocking.
  const auto dimToLvl = stt.getDimToLvl();
  const auto lvlToDim = stt.getLvlToDim();
  SmallVector<Value> dim2lvlValues(lvlRank); // for each lvl, expr in dim vars
  SmallVector<Value> lvl2dimValues(dimRank); // for each dim, expr in lvl vars
  // Generate dim2lvl.
  assert(lvlRank == dimToLvl.getNumResults());
  for (Level l = 0; l < lvlRank; l++) {
    AffineExpr exp = dimToLvl.getResult(l);
    // We expect:
    //   (1) l = d
    //   (2) l = d / c
    //   (3) l = d % c
    Dimension d = 0;
    uint64_t cf = 0, cm = 0;
    switch (exp.getKind()) {
    case AffineExprKind::DimId: {
      d = cast<AffineDimExpr>(exp).getPosition();
      break;
    }
    case AffineExprKind::FloorDiv: {
      auto floor = cast<AffineBinaryOpExpr>(exp);
      d = cast<AffineDimExpr>(floor.getLHS()).getPosition();
      cf = cast<AffineConstantExpr>(floor.getRHS()).getValue();
      break;
    }
    case AffineExprKind::Mod: {
      auto mod = cast<AffineBinaryOpExpr>(exp);
      d = cast<AffineDimExpr>(mod.getLHS()).getPosition();
      cm = cast<AffineConstantExpr>(mod.getRHS()).getValue();
      break;
    }
    default:
      llvm::report_fatal_error("unsupported dim2lvl in sparse tensor type");
    }
    dim2lvlValues[l] = constantIndex(builder, loc, encodeDim(d, cf, cm));
    // Compute the level sizes.
    //   (1) l = d     : size(d)
    //   (2) l = d / c : size(d) / c
    //   (3) l = d % c : c
    Value lvlSz;
    if (cm == 0) {
      lvlSz = dimSizesValues[d];
      if (cf != 0)
        lvlSz = builder.create<arith::DivUIOp>(loc, lvlSz,
                                               constantIndex(builder, loc, cf));
    } else {
      lvlSz = constantIndex(builder, loc, cm);
    }
    lvlSizesValues.push_back(lvlSz);
  }
  // Generate lvl2dim.
  assert(dimRank == lvlToDim.getNumResults());
  for (Dimension d = 0; d < dimRank; d++) {
    AffineExpr exp = lvlToDim.getResult(d);
    // We expect:
    //   (1) d = l
    //   (2) d = l' * c + l
    Level l = 0, ll = 0;
    uint64_t c = 0;
    switch (exp.getKind()) {
    case AffineExprKind::DimId: {
      l = cast<AffineDimExpr>(exp).getPosition();
      break;
    }
    case AffineExprKind::Add: {
      // Always mul on lhs, symbol/constant on rhs.
      auto add = cast<AffineBinaryOpExpr>(exp);
      assert(add.getLHS().getKind() == AffineExprKind::Mul);
      auto mul = cast<AffineBinaryOpExpr>(add.getLHS());
      ll = cast<AffineDimExpr>(mul.getLHS()).getPosition();
      c = cast<AffineConstantExpr>(mul.getRHS()).getValue();
      l = cast<AffineDimExpr>(add.getRHS()).getPosition();
      break;
    }
    default:
      llvm::report_fatal_error("unsupported lvl2dim in sparse tensor type");
    }
    lvl2dimValues[d] = constantIndex(builder, loc, encodeLvl(l, c, ll));
  }
  // Return buffers.
  dim2lvlBuffer = allocaBuffer(builder, loc, dim2lvlValues);
  lvl2dimBuffer = allocaBuffer(builder, loc, lvl2dimValues);
  return allocaBuffer(builder, loc, lvlSizesValues); // lvlSizesBuffer
}