//===- CodegenUtils.cpp - Utilities for generating MLIR ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"
#include "SparseTensorDescriptor.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include <optional>

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// ExecutionEngine/SparseTensorUtils helper functions.
//===----------------------------------------------------------------------===//

OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) {
  switch (width) {
  case 64:
    return OverheadType::kU64;
  case 32:
    return OverheadType::kU32;
  case 16:
    return OverheadType::kU16;
  case 8:
    return OverheadType::kU8;
  case 0:
    return OverheadType::kIndex;
  }
  llvm_unreachable("Unsupported overhead bitwidth");
}

OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) {
  if (tp.isIndex())
    return OverheadType::kIndex;
  if (auto intTp = dyn_cast<IntegerType>(tp))
    return overheadTypeEncoding(intTp.getWidth());
  llvm_unreachable("Unknown overhead type");
}

Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
  switch (ot) {
  case OverheadType::kIndex:
    return builder.getIndexType();
  case OverheadType::kU64:
    return builder.getIntegerType(64);
  case OverheadType::kU32:
    return builder.getIntegerType(32);
  case OverheadType::kU16:
    return builder.getIntegerType(16);
  case OverheadType::kU8:
    return builder.getIntegerType(8);
  }
  llvm_unreachable("Unknown OverheadType");
}

OverheadType
mlir::sparse_tensor::posTypeEncoding(SparseTensorEncodingAttr enc) {
  return overheadTypeEncoding(enc.getPosWidth());
}

OverheadType
mlir::sparse_tensor::crdTypeEncoding(SparseTensorEncodingAttr enc) {
  return overheadTypeEncoding(enc.getCrdWidth());
}

// TODO: we ought to add some `static_assert` tests to ensure that the
// `STEA::get{Pos,Crd}Type` methods agree with `getOverheadType(builder,
// {pos,crd}OverheadTypeEncoding(enc))`

// TODO: Adjust the naming convention for the constructors of
// `OverheadType` so we can use the `MLIR_SPARSETENSOR_FOREVERY_O` x-macro
// here instead of `MLIR_SPARSETENSOR_FOREVERY_FIXED_O`; to further reduce
// the possibility of typo bugs or things getting out of sync.
StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) {
  switch (ot) {
  case OverheadType::kIndex:
    return "0";
#define CASE(ONAME, O)                                                         \
  case OverheadType::kU##ONAME:                                                \
    return #ONAME;
    MLIR_SPARSETENSOR_FOREVERY_FIXED_O(CASE)
#undef CASE
  }
  llvm_unreachable("Unknown OverheadType");
}

StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) {
  return overheadTypeFunctionSuffix(overheadTypeEncoding(tp));
}

PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
  if (elemTp.isF64())
    return PrimaryType::kF64;
  if (elemTp.isF32())
    return PrimaryType::kF32;
  if (elemTp.isF16())
    return PrimaryType::kF16;
  if (elemTp.isBF16())
    return PrimaryType::kBF16;
  if (elemTp.isInteger(64))
    return PrimaryType::kI64;
  if (elemTp.isInteger(32))
    return PrimaryType::kI32;
  if (elemTp.isInteger(16))
    return PrimaryType::kI16;
  if (elemTp.isInteger(8))
    return PrimaryType::kI8;
  if (auto complexTp = dyn_cast<ComplexType>(elemTp)) {
    auto complexEltTp = complexTp.getElementType();
    if (complexEltTp.isF64())
      return PrimaryType::kC64;
    if (complexEltTp.isF32())
      return PrimaryType::kC32;
  }
  llvm_unreachable("Unknown primary type");
}

StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) {
  switch (pt) {
#define CASE(VNAME, V)                                                         \
  case PrimaryType::k##VNAME:                                                  \
    return #VNAME;
    MLIR_SPARSETENSOR_FOREVERY_V(CASE)
#undef CASE
  }
  llvm_unreachable("Unknown PrimaryType");
}

StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) {
  return primaryTypeFunctionSuffix(primaryTypeEncoding(elemTp));
}

//===----------------------------------------------------------------------===//
// Misc code generators.
//===----------------------------------------------------------------------===//

Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
                             Type dstTp) {
  const Type srcTp = value.getType();
  if (srcTp == dstTp)
    return value;

  // int <=> index
  if (isa<IndexType>(srcTp) || isa<IndexType>(dstTp))
    return builder.create<arith::IndexCastOp>(loc, dstTp, value);

  const auto srcIntTp = dyn_cast_or_null<IntegerType>(srcTp);
  const bool isUnsignedCast = srcIntTp ? srcIntTp.isUnsigned() : false;
  return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
}

Value sparse_tensor::genScalarToTensor(OpBuilder &builder, Location loc,
                                       Value elem, Type dstTp) {
  if (auto rtp = dyn_cast<RankedTensorType>(dstTp)) {
    // Scalars can only be converted to 0-ranked tensors.
    assert(rtp.getRank() == 0);
    elem = sparse_tensor::genCast(builder, loc, elem, rtp.getElementType());
    return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
  }
  return sparse_tensor::genCast(builder, loc, elem, dstTp);
}

Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
                                  ValueRange s) {
  Value load = builder.create<memref::LoadOp>(loc, mem, s);
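  // Overhead (position/coordinate) values are stored as unsigned integers,
  // so a value narrower than 64 bits is zero-extended first; casting the
  // narrow value directly with arith.index_cast would sign-extend it.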
  if (!isa<IndexType>(load.getType())) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load);
    load =
        builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), load);
  }
  return load;
}

mlir::TypedAttr mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
  if (isa<FloatType>(tp))
    return builder.getFloatAttr(tp, 1.0);
  if (isa<IndexType>(tp))
    return builder.getIndexAttr(1);
  if (auto intTp = dyn_cast<IntegerType>(tp))
    return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1));
  if (isa<RankedTensorType, VectorType>(tp)) {
    auto shapedTp = cast<ShapedType>(tp);
    if (auto one = getOneAttr(builder, shapedTp.getElementType()))
      return DenseElementsAttr::get(shapedTp, one);
  }
  llvm_unreachable("Unsupported attribute type");
}

Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
                                        Value v) {
  Type tp = v.getType();
  Value zero = constantZero(builder, loc, tp);
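  // For example, an f64 operand yields arith.cmpf with predicate "une",
  // an integer or index operand yields arith.cmpi "ne", and a complex
  // operand yields complex.neq, each comparing against the zero constant.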
  if (isa<FloatType>(tp))
    return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
                                          zero);
  if (tp.isIntOrIndex())
    return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
                                         zero);
  if (isa<ComplexType>(tp))
    return builder.create<complex::NotEqualOp>(loc, v, zero);
  llvm_unreachable("Non-numeric type");
}

void mlir::sparse_tensor::genReshapeDstShape(
    OpBuilder &builder, Location loc, SmallVectorImpl<Value> &dstShape,
    ArrayRef<Value> srcShape, ArrayRef<Size> staticDstShape,
    ArrayRef<ReassociationIndices> reassociation) {
  // Collapse shape.
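  // For example, collapsing a <2x3x4> source with reassociation [[0, 1], [2]]
  // produces dstShape = [2 * 3, 4] = [6, 4]; the products are emitted as
  // arith.muli so that dynamic sizes are handled the same way.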
  if (reassociation.size() < srcShape.size()) {
    unsigned start = 0;
    for (const auto &map : llvm::enumerate(reassociation)) {
      auto dstDim = constantIndex(builder, loc, 1);
      for (unsigned i = start; i < start + map.value().size(); i++) {
        dstDim = builder.create<arith::MulIOp>(loc, dstDim, srcShape[i]);
      }
      dstShape.push_back(dstDim);
      start = start + map.value().size();
    }
    assert(start == srcShape.size());
    return;
  }

  // Expand shape.
  assert(reassociation.size() == srcShape.size());
  unsigned start = 0;
  // Expand the i-th dimension in srcShape.
  for (unsigned i = 0, size = srcShape.size(); i < size; i++) {
    const auto &map = reassociation[i];
    auto srcDim = srcShape[i];
    // Iterate through dimensions expanded from the i-th dimension.
    for (unsigned j = start; j < start + map.size(); j++) {
      // There can be only one dynamic sized dimension among dimensions
      // expanded from the i-th dimension in srcShape.
      // For example, if srcDim = 8, then the expanded shape could be <2x?x2>,
      // but not <2x?x?>.
      if (staticDstShape[j] == ShapedType::kDynamic) {
        // The expanded dimension has dynamic size. We compute the dimension
        // by dividing srcDim by the product of the static dimensions.
        Size product = 1;
        for (unsigned k = start; k < start + map.size(); k++) {
          if (staticDstShape[k] != ShapedType::kDynamic) {
            product *= staticDstShape[k];
          }
        }
        // Compute the dynamic dimension size.
        Value productVal = constantIndex(builder, loc, product);
        Value dynamicSize =
            builder.create<arith::DivUIOp>(loc, srcDim, productVal);
        dstShape.push_back(dynamicSize);
      } else {
        // The expanded dimension is statically known.
        dstShape.push_back(constantIndex(builder, loc, staticDstShape[j]));
      }
    }
    start = start + map.size();
  }
  assert(start == staticDstShape.size());
}

void mlir::sparse_tensor::reshapeCvs(
    OpBuilder &builder, Location loc,
    ArrayRef<ReassociationIndices> reassociation, // NOLINT
    ValueRange srcSizes, ValueRange srcCvs,       // NOLINT
    ValueRange dstSizes, SmallVectorImpl<Value> &dstCvs) {
  const unsigned srcRank = srcSizes.size();
  const unsigned dstRank = dstSizes.size();
  assert(srcRank == srcCvs.size() && "Source rank mismatch");
  const bool isCollapse = srcRank > dstRank;
  const ValueRange sizes = isCollapse ? srcSizes : dstSizes;
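  // For example, collapsing the coordinates (i, j) of a 3x4 tensor into one
  // dimension yields the linearized coordinate i * 4 + j, and expanding that
  // coordinate back recovers (i, j) through division and remainder by 4. The
  // loop below generalizes this scheme to arbitrary reassociation groups.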
  // Iterate over reassociation map.
  unsigned i = 0;
  unsigned start = 0;
  for (const auto &map : llvm::enumerate(reassociation)) {
    // Prepare strides information in dimension slice.
    Value linear = constantIndex(builder, loc, 1);
    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
      linear = builder.create<arith::MulIOp>(loc, linear, sizes[j]);
    }
    // Start expansion.
    Value val;
    if (!isCollapse)
      val = srcCvs[i];
    // Iterate over dimension slice.
    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
      linear = builder.create<arith::DivUIOp>(loc, linear, sizes[j]);
      if (isCollapse) {
        const Value mul = builder.create<arith::MulIOp>(loc, srcCvs[j], linear);
        val = val ? builder.create<arith::AddIOp>(loc, val, mul) : mul;
      } else {
        const Value old = val;
        val = builder.create<arith::DivUIOp>(loc, val, linear);
        assert(dstCvs.size() == j);
        dstCvs.push_back(val);
        val = builder.create<arith::RemUIOp>(loc, old, linear);
      }
    }
    // Finalize collapse.
    if (isCollapse) {
      assert(dstCvs.size() == i);
      dstCvs.push_back(val);
    }
    start += map.value().size();
    i++;
  }
  assert(dstCvs.size() == dstRank);
}

FlatSymbolRefAttr mlir::sparse_tensor::getFunc(ModuleOp module, StringRef name,
                                               TypeRange resultType,
                                               ValueRange operands,
                                               EmitCInterface emitCInterface) {
  MLIRContext *context = module.getContext();
  auto result = SymbolRefAttr::get(context, name);
  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
  if (!func) {
    OpBuilder moduleBuilder(module.getBodyRegion());
    func = moduleBuilder.create<func::FuncOp>(
        module.getLoc(), name,
        FunctionType::get(context, operands.getTypes(), resultType));
    func.setPrivate();
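    // Requesting the C interface attaches the llvm.emit_c_interface attribute
    // so that lowering to LLVM also emits the _mlir_ciface_<name> wrapper
    // expected by the runtime support library.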
    if (static_cast<bool>(emitCInterface))
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(context));
  }
  return result;
}

func::CallOp mlir::sparse_tensor::createFuncCall(
    OpBuilder &builder, Location loc, StringRef name, TypeRange resultType,
    ValueRange operands, EmitCInterface emitCInterface) {
  auto module = builder.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
  FlatSymbolRefAttr fn =
      getFunc(module, name, resultType, operands, emitCInterface);
  return builder.create<func::CallOp>(loc, resultType, fn, operands);
}

Type mlir::sparse_tensor::getOpaquePointerType(MLIRContext *ctx) {
  return LLVM::LLVMPointerType::get(ctx);
}

Type mlir::sparse_tensor::getOpaquePointerType(Builder &builder) {
  return getOpaquePointerType(builder.getContext());
}

Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc,
                                     unsigned sz, Type tp, bool staticShape) {
  if (staticShape) {
    auto memTp = MemRefType::get({sz}, tp);
    return builder.create<memref::AllocaOp>(loc, memTp);
  }
  return genAlloca(builder, loc, constantIndex(builder, loc, sz), tp);
}

Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc, Value sz,
                                     Type tp) {
  auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
  return builder.create<memref::AllocaOp>(loc, memTp, ValueRange{sz});
}

Value mlir::sparse_tensor::genAllocaScalar(OpBuilder &builder, Location loc,
                                           Type tp) {
  return builder.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
}

Value mlir::sparse_tensor::allocaBuffer(OpBuilder &builder, Location loc,
                                        ValueRange values) {
  const unsigned sz = values.size();
  assert(sz >= 1);
  Value buffer = genAlloca(builder, loc, sz, values[0].getType());
  for (unsigned i = 0; i < sz; i++) {
    Value idx = constantIndex(builder, loc, i);
    builder.create<memref::StoreOp>(loc, values[i], buffer, idx);
  }
  return buffer;
}

Value mlir::sparse_tensor::allocDenseTensor(OpBuilder &builder, Location loc,
                                            RankedTensorType tensorTp,
                                            ValueRange sizes) {
  Type elemTp = tensorTp.getElementType();
  auto shape = tensorTp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
    if (shape[i] == ShapedType::kDynamic)
      dynamicSizes.push_back(sizes[i]);
  }
  Value mem = builder.create<memref::AllocOp>(loc, memTp, dynamicSizes);
  Value zero = constantZero(builder, loc, elemTp);
  builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{mem});
  return mem;
}

void mlir::sparse_tensor::deallocDenseTensor(OpBuilder &builder, Location loc,
                                             Value buffer) {
  builder.create<memref::DeallocOp>(loc, buffer);
}

void mlir::sparse_tensor::sizesFromSrc(OpBuilder &builder,
                                       SmallVectorImpl<Value> &sizes,
                                       Location loc, Value src) {
  const Dimension dimRank = getSparseTensorType(src).getDimRank();
  for (Dimension d = 0; d < dimRank; d++)
    sizes.push_back(linalg::createOrFoldDimOp(builder, loc, src, d));
}

Operation *mlir::sparse_tensor::getTop(Operation *op) {
  for (; isa<scf::ForOp>(op->getParentOp()) ||
         isa<scf::WhileOp>(op->getParentOp()) ||
         isa<scf::ParallelOp>(op->getParentOp()) ||
         isa<scf::IfOp>(op->getParentOp());
       op = op->getParentOp())
    ;
  return op;
}

void sparse_tensor::foreachInSparseConstant(
    OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order,
    function_ref<void(ArrayRef<Value>, Value)> callback) {
  if (!order)
    order = builder.getMultiDimIdentityMap(attr.getType().getRank());

  auto stt = SparseTensorType(getRankedTensorType(attr));
  const Dimension dimRank = stt.getDimRank();
  const auto coordinates = attr.getIndices().getValues<IntegerAttr>();
  const auto values = attr.getValues().getValues<Attribute>();

  // This is like the `Element<V>` class in the runtime library, but for
  // MLIR attributes. In the future we may want to move this out into
  // a proper class definition to help improve code legibility (e.g.,
  // `first` -> `coords`, `second` -> `value`) as well as being able
  // to factor out analogues of `ElementLT<V>` for the sort below, etc.
  using ElementAttr = std::pair<SmallVector<IntegerAttr>, Attribute>;

  // Construct the COO from the SparseElementsAttr.
  SmallVector<ElementAttr> elems;
  for (size_t i = 0, nse = values.size(); i < nse; i++) {
    elems.emplace_back();
    elems.back().second = values[i];
    auto &coords = elems.back().first;
    coords.reserve(dimRank);
    for (Dimension d = 0; d < dimRank; d++)
      coords.push_back(coordinates[i * dimRank + d]);
  }

  // Sorts the sparse element attribute based on coordinates.
  llvm::sort(elems, [order](const ElementAttr &lhs, const ElementAttr &rhs) {
    if (std::addressof(lhs) == std::addressof(rhs))
      return false;

    auto lhsCoords = llvm::map_to_vector(
        lhs.first, [](IntegerAttr i) { return i.getInt(); });
    auto rhsCoords = llvm::map_to_vector(
        rhs.first, [](IntegerAttr i) { return i.getInt(); });

    SmallVector<int64_t, 4> lhsLvlCrds = order.compose(lhsCoords);
    SmallVector<int64_t, 4> rhsLvlCrds = order.compose(rhsCoords);
    // Sort the element based on the lvl coordinates.
    for (Level l = 0; l < order.getNumResults(); l++) {
      if (lhsLvlCrds[l] == rhsLvlCrds[l])
        continue;
      return lhsLvlCrds[l] < rhsLvlCrds[l];
    }
    llvm_unreachable("no equal coordinate in sparse element attr");
  });

  SmallVector<Value> cvs;
  cvs.reserve(dimRank);
  for (size_t i = 0, nse = values.size(); i < nse; i++) {
    // Remap coordinates.
    cvs.clear();
    for (Dimension d = 0; d < dimRank; d++) {
      auto crd = elems[i].first[d].getInt();
      cvs.push_back(builder.create<arith::ConstantIndexOp>(loc, crd));
    }
    // Remap value.
    Value val;
    if (isa<ComplexType>(attr.getElementType())) {
      auto valAttr = cast<ArrayAttr>(elems[i].second);
      val = builder.create<complex::ConstantOp>(loc, attr.getElementType(),
                                                valAttr);
    } else {
      auto valAttr = cast<TypedAttr>(elems[i].second);
      val = builder.create<arith::ConstantOp>(loc, valAttr);
    }
    assert(val);
    callback(cvs, val);
  }
}

SmallVector<Value> sparse_tensor::loadAll(OpBuilder &builder, Location loc,
                                          size_t size, Value mem,
                                          size_t offsetIdx, Value offsetVal) {
#ifndef NDEBUG
  const auto memTp = cast<MemRefType>(mem.getType());
  assert(memTp.getRank() == 1);
  const Size memSh = memTp.getDimSize(0);
  assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(size));
  assert(offsetIdx == 0 || offsetIdx < size);
#endif // NDEBUG
  SmallVector<Value> vs;
  vs.reserve(size);
  for (unsigned i = 0; i < size; i++) {
    Value v = builder.create<memref::LoadOp>(loc, mem,
                                             constantIndex(builder, loc, i));
    if (i == offsetIdx && offsetVal)
      v = builder.create<arith::AddIOp>(loc, v, offsetVal);
    vs.push_back(v);
  }
  return vs;
}

void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem,
                             ValueRange vs, size_t offsetIdx, Value offsetVal) {
#ifndef NDEBUG
  const size_t vsize = vs.size();
  const auto memTp = cast<MemRefType>(mem.getType());
  assert(memTp.getRank() == 1);
  const Size memSh = memTp.getDimSize(0);
  assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(vsize));
  assert(offsetIdx == 0 || offsetIdx < vsize);
#endif // NDEBUG
  for (const auto &v : llvm::enumerate(vs)) {
    const Value w =
        (offsetIdx == v.index() && offsetVal)
            ? builder.create<arith::AddIOp>(loc, v.value(), offsetVal)
            : v.value();
    builder.create<memref::StoreOp>(loc, w, mem,
                                    constantIndex(builder, loc, v.index()));
  }
}

TypedValue<BaseMemRefType>
sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
  auto tTp = llvm::cast<TensorType>(tensor.getType());
  auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType());
  return builder.create<bufferization::ToBufferOp>(loc, mTp, tensor)
      .getResult();
}

Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
                                               Value tensor, Dimension dim) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  assert(enc && enc.isSlice());
  std::optional<unsigned> offset = enc.getStaticDimSliceOffset(dim);
  if (offset.has_value())
    return constantIndex(builder, loc, *offset);
  return builder.create<ToSliceOffsetOp>(loc, tensor, APInt(64, dim));
}

Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder &builder, Location loc,
                                               Value tensor, Dimension dim) {
  auto enc = getSparseTensorEncoding(tensor.getType());
  assert(enc && enc.isSlice());
  std::optional<unsigned> stride = enc.getStaticDimSliceStride(dim);
  if (stride.has_value())
    return constantIndex(builder, loc, *stride);
  return builder.create<ToSliceStrideOp>(loc, tensor, APInt(64, dim));
}

Value sparse_tensor::genReader(OpBuilder &builder, Location loc,
                               SparseTensorType stt, Value tensor,
                               /*out*/ SmallVectorImpl<Value> &dimSizesValues,
                               /*out*/ Value &dimSizesBuffer) {
  // Construct the dimension **shapes** buffer. The buffer contains the static
  // size per dimension, or otherwise a zero for a dynamic size.
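  // For example, a tensor type tensor<?x32xf64> yields the buffer [0, 32].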
  Dimension dimRank = stt.getDimRank();
  dimSizesValues.clear();
  dimSizesValues.reserve(dimRank);
  for (const Size sz : stt.getDimShape()) {
    const auto s = ShapedType::isDynamic(sz) ? 0 : sz;
    dimSizesValues.push_back(constantIndex(builder, loc, s));
  }
  Value dimShapesBuffer = allocaBuffer(builder, loc, dimSizesValues);
  // Create the `CheckedSparseTensorReader`. This reader performs a
  // consistency check on the static sizes, but accepts any size
  // of each dimension with a dynamic size.
  Type opaqueTp = getOpaquePointerType(builder);
  Type eltTp = stt.getElementType();
  Value valTp = constantPrimaryTypeEncoding(builder, loc, eltTp);
  Value reader =
      createFuncCall(builder, loc, "createCheckedSparseTensorReader", opaqueTp,
                     {tensor, dimShapesBuffer, valTp}, EmitCInterface::On)
          .getResult(0);
  // For static shapes, the shape buffer can be used right away. For dynamic
  // shapes, use the information from the reader to construct a buffer that
  // supplies the actual size for each dynamic dimension.
  dimSizesBuffer = dimShapesBuffer;
  if (stt.hasDynamicDimShape()) {
    Type indexTp = builder.getIndexType();
    auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp);
    dimSizesBuffer =
        createFuncCall(builder, loc, "getSparseTensorReaderDimSizes", memTp,
                       reader, EmitCInterface::On)
            .getResult(0);
    // Also convert the dim shapes values into dim sizes values, just in case
    // subsequent clients need the values (DCE will remove unused).
    for (Dimension d = 0; d < dimRank; d++) {
      if (stt.isDynamicDim(d))
        dimSizesValues[d] = builder.create<memref::LoadOp>(
            loc, dimSizesBuffer, constantIndex(builder, loc, d));
    }
  }
  return reader;
}

Value sparse_tensor::genMapBuffers(
    OpBuilder &builder, Location loc, SparseTensorType stt,
    ArrayRef<Value> dimSizesValues, Value dimSizesBuffer,
    /*out*/ SmallVectorImpl<Value> &lvlSizesValues,
    /*out*/ Value &dim2lvlBuffer,
    /*out*/ Value &lvl2dimBuffer) {
  const Dimension dimRank = stt.getDimRank();
  const Level lvlRank = stt.getLvlRank();
  lvlSizesValues.clear();
  lvlSizesValues.reserve(lvlRank);
  // For an identity mapping, the dim2lvl and lvl2dim mappings are
  // identical as are dimSizes and lvlSizes, so buffers are reused
  // as much as possible.
  if (stt.isIdentity()) {
    assert(dimRank == lvlRank);
    SmallVector<Value> iotaValues;
    iotaValues.reserve(lvlRank);
    for (Level l = 0; l < lvlRank; l++) {
      iotaValues.push_back(constantIndex(builder, loc, l));
      lvlSizesValues.push_back(dimSizesValues[l]);
    }
    dim2lvlBuffer = lvl2dimBuffer = allocaBuffer(builder, loc, iotaValues);
    return dimSizesBuffer; // now lvlSizesBuffer
  }
  // Otherwise, some code needs to be generated to set up the buffers.
  // This code deals with permutations as well as non-permutations that
  // arise from rank changing blocking.
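  // For example, a 2-dimensional BSR format with 2x2 blocks uses the dim2lvl
  // mapping (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2)
  // and the lvl2dim mapping (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 2 + l3).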
  const auto dimToLvl = stt.getDimToLvl();
  const auto lvlToDim = stt.getLvlToDim();
  SmallVector<Value> dim2lvlValues(lvlRank); // for each lvl, expr in dim vars
  SmallVector<Value> lvl2dimValues(dimRank); // for each dim, expr in lvl vars
  // Generate dim2lvl.
  assert(lvlRank == dimToLvl.getNumResults());
  for (Level l = 0; l < lvlRank; l++) {
    AffineExpr exp = dimToLvl.getResult(l);
    // We expect:
    //   (1) l = d
    //   (2) l = d / c
    //   (3) l = d % c
    Dimension d = 0;
    uint64_t cf = 0, cm = 0;
    switch (exp.getKind()) {
    case AffineExprKind::DimId: {
      d = cast<AffineDimExpr>(exp).getPosition();
      break;
    }
    case AffineExprKind::FloorDiv: {
      auto floor = cast<AffineBinaryOpExpr>(exp);
      d = cast<AffineDimExpr>(floor.getLHS()).getPosition();
      cf = cast<AffineConstantExpr>(floor.getRHS()).getValue();
      break;
    }
    case AffineExprKind::Mod: {
      auto mod = cast<AffineBinaryOpExpr>(exp);
      d = cast<AffineDimExpr>(mod.getLHS()).getPosition();
      cm = cast<AffineConstantExpr>(mod.getRHS()).getValue();
      break;
    }
    default:
      llvm::report_fatal_error("unsupported dim2lvl in sparse tensor type");
    }
    dim2lvlValues[l] = constantIndex(builder, loc, encodeDim(d, cf, cm));
    // Compute the level sizes.
    //   (1) l = d : size(d)
    //   (2) l = d / c : size(d) / c
    //   (3) l = d % c : c
    Value lvlSz;
    if (cm == 0) {
      lvlSz = dimSizesValues[d];
      if (cf != 0)
        lvlSz = builder.create<arith::DivUIOp>(loc, lvlSz,
                                               constantIndex(builder, loc, cf));
    } else {
      lvlSz = constantIndex(builder, loc, cm);
    }
    lvlSizesValues.push_back(lvlSz);
  }
  // Generate lvl2dim.
  assert(dimRank == lvlToDim.getNumResults());
  for (Dimension d = 0; d < dimRank; d++) {
    AffineExpr exp = lvlToDim.getResult(d);
    // We expect:
    //   (1) d = l
    //   (2) d = l' * c + l
    Level l = 0, ll = 0;
    uint64_t c = 0;
    switch (exp.getKind()) {
    case AffineExprKind::DimId: {
      l = cast<AffineDimExpr>(exp).getPosition();
      break;
    }
    case AffineExprKind::Add: {
      // Always mul on lhs, symbol/constant on rhs.
      auto add = cast<AffineBinaryOpExpr>(exp);
      assert(add.getLHS().getKind() == AffineExprKind::Mul);
      auto mul = cast<AffineBinaryOpExpr>(add.getLHS());
      ll = cast<AffineDimExpr>(mul.getLHS()).getPosition();
      c = cast<AffineConstantExpr>(mul.getRHS()).getValue();
      l = cast<AffineDimExpr>(add.getRHS()).getPosition();
      break;
    }
    default:
      llvm::report_fatal_error("unsupported lvl2dim in sparse tensor type");
    }
    lvl2dimValues[d] = constantIndex(builder, loc, encodeLvl(l, c, ll));
  }
  // Return buffers.
  dim2lvlBuffer = allocaBuffer(builder, loc, dim2lvlValues);
  lvl2dimBuffer = allocaBuffer(builder, loc, lvl2dimValues);
  return allocaBuffer(builder, loc, lvlSizesValues); // lvlSizesBuffer
}