| 1 | //===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "CodegenUtils.h" |
| 10 | |
| 11 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
| 12 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 13 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
| 14 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 15 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 16 | #include "mlir/IR/Types.h" |
| 17 | #include "mlir/IR/Value.h" |
| 18 | #include <optional> |
| 19 | |
| 20 | using namespace mlir; |
| 21 | using namespace mlir::sparse_tensor; |
| 22 | |
| 23 | //===----------------------------------------------------------------------===// |
| 24 | // ExecutionEngine/SparseTensorUtils helper functions. |
| 25 | //===----------------------------------------------------------------------===// |
| 26 | |
| 27 | OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) { |
| 28 | switch (width) { |
| 29 | case 64: |
| 30 | return OverheadType::kU64; |
| 31 | case 32: |
| 32 | return OverheadType::kU32; |
| 33 | case 16: |
| 34 | return OverheadType::kU16; |
| 35 | case 8: |
| 36 | return OverheadType::kU8; |
| 37 | case 0: |
| 38 | return OverheadType::kIndex; |
| 39 | } |
| 40 | llvm_unreachable("Unsupported overhead bitwidth" ); |
| 41 | } |
| 42 | |
| 43 | OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) { |
| 44 | if (tp.isIndex()) |
| 45 | return OverheadType::kIndex; |
| 46 | if (auto intTp = dyn_cast<IntegerType>(Val&: tp)) |
| 47 | return overheadTypeEncoding(width: intTp.getWidth()); |
| 48 | llvm_unreachable("Unknown overhead type" ); |
| 49 | } |
| 50 | |
| 51 | Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) { |
| 52 | switch (ot) { |
| 53 | case OverheadType::kIndex: |
| 54 | return builder.getIndexType(); |
| 55 | case OverheadType::kU64: |
| 56 | return builder.getIntegerType(width: 64); |
| 57 | case OverheadType::kU32: |
| 58 | return builder.getIntegerType(width: 32); |
| 59 | case OverheadType::kU16: |
| 60 | return builder.getIntegerType(width: 16); |
| 61 | case OverheadType::kU8: |
| 62 | return builder.getIntegerType(width: 8); |
| 63 | } |
| 64 | llvm_unreachable("Unknown OverheadType" ); |
| 65 | } |
| 66 | |
| 67 | OverheadType |
| 68 | mlir::sparse_tensor::posTypeEncoding(SparseTensorEncodingAttr enc) { |
| 69 | return overheadTypeEncoding(width: enc.getPosWidth()); |
| 70 | } |
| 71 | |
| 72 | OverheadType |
| 73 | mlir::sparse_tensor::crdTypeEncoding(SparseTensorEncodingAttr enc) { |
| 74 | return overheadTypeEncoding(width: enc.getCrdWidth()); |
| 75 | } |
| 76 | |
// TODO: we ought to add some `static_assert` tests to ensure that the
// `STEA::get{Pos,Crd}Type` methods agree with `getOverheadType(builder,
// {pos,crd}OverheadTypeEncoding(enc))`

// TODO: Adjust the naming convention for the constructors of
// `OverheadType` so we can use the `MLIR_SPARSETENSOR_FOREVERY_O` x-macro
// here instead of `MLIR_SPARSETENSOR_FOREVERY_FIXED_O`; to further reduce
// the possibility of typo bugs or things getting out of sync.
| 85 | StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) { |
| 86 | switch (ot) { |
| 87 | case OverheadType::kIndex: |
| 88 | return "0" ; |
| 89 | #define CASE(ONAME, O) \ |
| 90 | case OverheadType::kU##ONAME: \ |
| 91 | return #ONAME; |
| 92 | MLIR_SPARSETENSOR_FOREVERY_FIXED_O(CASE) |
| 93 | #undef CASE |
| 94 | } |
| 95 | llvm_unreachable("Unknown OverheadType" ); |
| 96 | } |
| 97 | |
| 98 | StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) { |
| 99 | return overheadTypeFunctionSuffix(ot: overheadTypeEncoding(tp)); |
| 100 | } |
| 101 | |
| 102 | PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) { |
| 103 | if (elemTp.isF64()) |
| 104 | return PrimaryType::kF64; |
| 105 | if (elemTp.isF32()) |
| 106 | return PrimaryType::kF32; |
| 107 | if (elemTp.isF16()) |
| 108 | return PrimaryType::kF16; |
| 109 | if (elemTp.isBF16()) |
| 110 | return PrimaryType::kBF16; |
| 111 | if (elemTp.isInteger(width: 64)) |
| 112 | return PrimaryType::kI64; |
| 113 | if (elemTp.isInteger(width: 32)) |
| 114 | return PrimaryType::kI32; |
| 115 | if (elemTp.isInteger(width: 16)) |
| 116 | return PrimaryType::kI16; |
| 117 | if (elemTp.isInteger(width: 8)) |
| 118 | return PrimaryType::kI8; |
| 119 | if (auto complexTp = dyn_cast<ComplexType>(Val&: elemTp)) { |
| 120 | auto complexEltTp = complexTp.getElementType(); |
| 121 | if (complexEltTp.isF64()) |
| 122 | return PrimaryType::kC64; |
| 123 | if (complexEltTp.isF32()) |
| 124 | return PrimaryType::kC32; |
| 125 | } |
| 126 | llvm_unreachable("Unknown primary type" ); |
| 127 | } |
| 128 | |
| 129 | StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) { |
| 130 | switch (pt) { |
| 131 | #define CASE(VNAME, V) \ |
| 132 | case PrimaryType::k##VNAME: \ |
| 133 | return #VNAME; |
| 134 | MLIR_SPARSETENSOR_FOREVERY_V(CASE) |
| 135 | #undef CASE |
| 136 | } |
| 137 | llvm_unreachable("Unknown PrimaryType" ); |
| 138 | } |
| 139 | |
| 140 | StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) { |
| 141 | return primaryTypeFunctionSuffix(pt: primaryTypeEncoding(elemTp)); |
| 142 | } |
| 143 | |
| 144 | //===----------------------------------------------------------------------===// |
| 145 | // Misc code generators. |
| 146 | //===----------------------------------------------------------------------===// |
| 147 | |
| 148 | Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value, |
| 149 | Type dstTp) { |
| 150 | const Type srcTp = value.getType(); |
| 151 | if (srcTp == dstTp) |
| 152 | return value; |
| 153 | |
| 154 | // int <=> index |
| 155 | if (isa<IndexType>(Val: srcTp) || isa<IndexType>(Val: dstTp)) |
| 156 | return builder.create<arith::IndexCastOp>(location: loc, args&: dstTp, args&: value); |
| 157 | |
| 158 | const auto srcIntTp = dyn_cast_or_null<IntegerType>(Val: srcTp); |
| 159 | const bool isUnsignedCast = srcIntTp ? srcIntTp.isUnsigned() : false; |
| 160 | return mlir::convertScalarToDtype(b&: builder, loc, operand: value, toType: dstTp, isUnsignedCast); |
| 161 | } |
| 162 | |
| 163 | Value sparse_tensor::genScalarToTensor(OpBuilder &builder, Location loc, |
| 164 | Value elem, Type dstTp) { |
| 165 | if (auto rtp = dyn_cast<RankedTensorType>(Val&: dstTp)) { |
| 166 | // Scalars can only be converted to 0-ranked tensors. |
| 167 | assert(rtp.getRank() == 0); |
| 168 | elem = sparse_tensor::genCast(builder, loc, value: elem, dstTp: rtp.getElementType()); |
| 169 | return builder.create<tensor::FromElementsOp>(location: loc, args&: rtp, args&: elem); |
| 170 | } |
| 171 | return sparse_tensor::genCast(builder, loc, value: elem, dstTp); |
| 172 | } |
| 173 | |
| 174 | Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem, |
| 175 | ValueRange s) { |
| 176 | Value load = builder.create<memref::LoadOp>(location: loc, args&: mem, args&: s); |
| 177 | if (!isa<IndexType>(Val: load.getType())) { |
| 178 | if (load.getType().getIntOrFloatBitWidth() < 64) |
| 179 | load = builder.create<arith::ExtUIOp>(location: loc, args: builder.getI64Type(), args&: load); |
| 180 | load = |
| 181 | builder.create<arith::IndexCastOp>(location: loc, args: builder.getIndexType(), args&: load); |
| 182 | } |
| 183 | return load; |
| 184 | } |
| 185 | |
| 186 | mlir::TypedAttr mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) { |
| 187 | if (isa<FloatType>(Val: tp)) |
| 188 | return builder.getFloatAttr(type: tp, value: 1.0); |
| 189 | if (isa<IndexType>(Val: tp)) |
| 190 | return builder.getIndexAttr(value: 1); |
| 191 | if (auto intTp = dyn_cast<IntegerType>(Val&: tp)) |
| 192 | return builder.getIntegerAttr(type: tp, value: APInt(intTp.getWidth(), 1)); |
| 193 | if (isa<RankedTensorType, VectorType>(Val: tp)) { |
| 194 | auto shapedTp = cast<ShapedType>(Val&: tp); |
| 195 | if (auto one = getOneAttr(builder, tp: shapedTp.getElementType())) |
| 196 | return DenseElementsAttr::get(type: shapedTp, values: one); |
| 197 | } |
| 198 | llvm_unreachable("Unsupported attribute type" ); |
| 199 | } |
| 200 | |
| 201 | Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc, |
| 202 | Value v) { |
| 203 | Type tp = v.getType(); |
| 204 | Value zero = constantZero(builder, loc, tp); |
| 205 | if (isa<FloatType>(Val: tp)) |
| 206 | return builder.create<arith::CmpFOp>(location: loc, args: arith::CmpFPredicate::UNE, args&: v, |
| 207 | args&: zero); |
| 208 | if (tp.isIntOrIndex()) |
| 209 | return builder.create<arith::CmpIOp>(location: loc, args: arith::CmpIPredicate::ne, args&: v, |
| 210 | args&: zero); |
| 211 | if (isa<ComplexType>(Val: tp)) |
| 212 | return builder.create<complex::NotEqualOp>(location: loc, args&: v, args&: zero); |
| 213 | llvm_unreachable("Non-numeric type" ); |
| 214 | } |
| 215 | |
| 216 | void mlir::sparse_tensor::genReshapeDstShape( |
| 217 | OpBuilder &builder, Location loc, SmallVectorImpl<Value> &dstShape, |
| 218 | ArrayRef<Value> srcShape, ArrayRef<Size> staticDstShape, |
| 219 | ArrayRef<ReassociationIndices> reassociation) { |
| 220 | // Collapse shape. |
| 221 | if (reassociation.size() < srcShape.size()) { |
| 222 | unsigned start = 0; |
| 223 | for (const auto &map : llvm::enumerate(First&: reassociation)) { |
| 224 | auto dstDim = constantIndex(builder, loc, i: 1); |
| 225 | for (unsigned i = start; i < start + map.value().size(); i++) { |
| 226 | dstDim = builder.create<arith::MulIOp>(location: loc, args&: dstDim, args: srcShape[i]); |
| 227 | } |
| 228 | dstShape.push_back(Elt: dstDim); |
| 229 | start = start + map.value().size(); |
| 230 | } |
| 231 | assert(start == srcShape.size()); |
| 232 | return; |
| 233 | } |
| 234 | |
| 235 | // Expand shape. |
| 236 | assert(reassociation.size() == srcShape.size()); |
| 237 | unsigned start = 0; |
| 238 | // Expand the i-th dimension in srcShape. |
| 239 | for (unsigned i = 0, size = srcShape.size(); i < size; i++) { |
| 240 | const auto &map = reassociation[i]; |
| 241 | auto srcDim = srcShape[i]; |
| 242 | // Iterate through dimensions expanded from the i-th dimension. |
| 243 | for (unsigned j = start; j < start + map.size(); j++) { |
| 244 | // There can be only one dynamic sized dimension among dimensions |
| 245 | // expanded from the i-th dimension in srcShape. |
| 246 | // For example, if srcDim = 8, then the expanded shape could be <2x?x2>, |
| 247 | // but not <2x?x?>. |
| 248 | if (staticDstShape[j] == ShapedType::kDynamic) { |
| 249 | // The expanded dimension has dynamic size. We compute the dimension |
| 250 | // by dividing srcDim by the product of the static dimensions. |
| 251 | Size product = 1; |
| 252 | for (unsigned k = start; k < start + map.size(); k++) { |
| 253 | if (staticDstShape[k] != ShapedType::kDynamic) { |
| 254 | product *= staticDstShape[k]; |
| 255 | } |
| 256 | } |
| 257 | // Compute the dynamic dimension size. |
| 258 | Value productVal = constantIndex(builder, loc, i: product); |
| 259 | Value dynamicSize = |
| 260 | builder.create<arith::DivUIOp>(location: loc, args&: srcDim, args&: productVal); |
| 261 | dstShape.push_back(Elt: dynamicSize); |
| 262 | } else { |
| 263 | // The expanded dimension is statically known. |
| 264 | dstShape.push_back(Elt: constantIndex(builder, loc, i: staticDstShape[j])); |
| 265 | } |
| 266 | } |
| 267 | start = start + map.size(); |
| 268 | } |
| 269 | assert(start == staticDstShape.size()); |
| 270 | } |
| 271 | |
| 272 | void mlir::sparse_tensor::reshapeCvs( |
| 273 | OpBuilder &builder, Location loc, |
| 274 | ArrayRef<ReassociationIndices> reassociation, // NOLINT |
| 275 | ValueRange srcSizes, ValueRange srcCvs, // NOLINT |
| 276 | ValueRange dstSizes, SmallVectorImpl<Value> &dstCvs) { |
| 277 | const unsigned srcRank = srcSizes.size(); |
| 278 | const unsigned dstRank = dstSizes.size(); |
| 279 | assert(srcRank == srcCvs.size() && "Source rank mismatch" ); |
| 280 | const bool isCollapse = srcRank > dstRank; |
| 281 | const ValueRange sizes = isCollapse ? srcSizes : dstSizes; |
| 282 | // Iterate over reassociation map. |
| 283 | unsigned i = 0; |
| 284 | unsigned start = 0; |
| 285 | for (const auto &map : llvm::enumerate(First&: reassociation)) { |
| 286 | // Prepare strides information in dimension slice. |
| 287 | Value linear = constantIndex(builder, loc, i: 1); |
| 288 | for (unsigned j = start, end = start + map.value().size(); j < end; j++) { |
| 289 | linear = builder.create<arith::MulIOp>(location: loc, args&: linear, args: sizes[j]); |
| 290 | } |
| 291 | // Start expansion. |
| 292 | Value val; |
| 293 | if (!isCollapse) |
| 294 | val = srcCvs[i]; |
| 295 | // Iterate over dimension slice. |
| 296 | for (unsigned j = start, end = start + map.value().size(); j < end; j++) { |
| 297 | linear = builder.create<arith::DivUIOp>(location: loc, args&: linear, args: sizes[j]); |
| 298 | if (isCollapse) { |
| 299 | const Value mul = builder.create<arith::MulIOp>(location: loc, args: srcCvs[j], args&: linear); |
| 300 | val = val ? builder.create<arith::AddIOp>(location: loc, args&: val, args: mul) : mul; |
| 301 | } else { |
| 302 | const Value old = val; |
| 303 | val = builder.create<arith::DivUIOp>(location: loc, args&: val, args&: linear); |
| 304 | assert(dstCvs.size() == j); |
| 305 | dstCvs.push_back(Elt: val); |
| 306 | val = builder.create<arith::RemUIOp>(location: loc, args: old, args&: linear); |
| 307 | } |
| 308 | } |
| 309 | // Finalize collapse. |
| 310 | if (isCollapse) { |
| 311 | assert(dstCvs.size() == i); |
| 312 | dstCvs.push_back(Elt: val); |
| 313 | } |
| 314 | start += map.value().size(); |
| 315 | i++; |
| 316 | } |
| 317 | assert(dstCvs.size() == dstRank); |
| 318 | } |
| 319 | |
| 320 | FlatSymbolRefAttr mlir::sparse_tensor::getFunc(ModuleOp module, StringRef name, |
| 321 | TypeRange resultType, |
| 322 | ValueRange operands, |
| 323 | EmitCInterface emitCInterface) { |
| 324 | MLIRContext *context = module.getContext(); |
| 325 | auto result = SymbolRefAttr::get(ctx: context, value: name); |
| 326 | auto func = module.lookupSymbol<func::FuncOp>(name: result.getAttr()); |
| 327 | if (!func) { |
| 328 | OpBuilder moduleBuilder(module.getBodyRegion()); |
| 329 | func = moduleBuilder.create<func::FuncOp>( |
| 330 | location: module.getLoc(), args&: name, |
| 331 | args: FunctionType::get(context, inputs: operands.getTypes(), results: resultType)); |
| 332 | func.setPrivate(); |
| 333 | if (static_cast<bool>(emitCInterface)) |
| 334 | func->setAttr(name: LLVM::LLVMDialect::getEmitCWrapperAttrName(), |
| 335 | value: UnitAttr::get(context)); |
| 336 | } |
| 337 | return result; |
| 338 | } |
| 339 | |
| 340 | func::CallOp mlir::sparse_tensor::createFuncCall( |
| 341 | OpBuilder &builder, Location loc, StringRef name, TypeRange resultType, |
| 342 | ValueRange operands, EmitCInterface emitCInterface) { |
| 343 | auto module = builder.getBlock()->getParentOp()->getParentOfType<ModuleOp>(); |
| 344 | FlatSymbolRefAttr fn = |
| 345 | getFunc(module, name, resultType, operands, emitCInterface); |
| 346 | return builder.create<func::CallOp>(location: loc, args&: resultType, args&: fn, args&: operands); |
| 347 | } |
| 348 | |
| 349 | Type mlir::sparse_tensor::getOpaquePointerType(MLIRContext *ctx) { |
| 350 | return LLVM::LLVMPointerType::get(context: ctx); |
| 351 | } |
| 352 | |
| 353 | Type mlir::sparse_tensor::getOpaquePointerType(Builder &builder) { |
| 354 | return getOpaquePointerType(ctx: builder.getContext()); |
| 355 | } |
| 356 | |
| 357 | Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc, |
| 358 | unsigned sz, Type tp, bool staticShape) { |
| 359 | if (staticShape) { |
| 360 | auto memTp = MemRefType::get(shape: {sz}, elementType: tp); |
| 361 | return builder.create<memref::AllocaOp>(location: loc, args&: memTp); |
| 362 | } |
| 363 | return genAlloca(builder, loc, sz: constantIndex(builder, loc, i: sz), tp); |
| 364 | } |
| 365 | |
| 366 | Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc, Value sz, |
| 367 | Type tp) { |
| 368 | auto memTp = MemRefType::get(shape: {ShapedType::kDynamic}, elementType: tp); |
| 369 | return builder.create<memref::AllocaOp>(location: loc, args&: memTp, args: ValueRange{sz}); |
| 370 | } |
| 371 | |
| 372 | Value mlir::sparse_tensor::genAllocaScalar(OpBuilder &builder, Location loc, |
| 373 | Type tp) { |
| 374 | return builder.create<memref::AllocaOp>(location: loc, args: MemRefType::get(shape: {}, elementType: tp)); |
| 375 | } |
| 376 | |
| 377 | Value mlir::sparse_tensor::allocaBuffer(OpBuilder &builder, Location loc, |
| 378 | ValueRange values) { |
| 379 | const unsigned sz = values.size(); |
| 380 | assert(sz >= 1); |
| 381 | Value buffer = genAlloca(builder, loc, sz, tp: values[0].getType()); |
| 382 | for (unsigned i = 0; i < sz; i++) { |
| 383 | Value idx = constantIndex(builder, loc, i); |
| 384 | builder.create<memref::StoreOp>(location: loc, args: values[i], args&: buffer, args&: idx); |
| 385 | } |
| 386 | return buffer; |
| 387 | } |
| 388 | |
| 389 | Value mlir::sparse_tensor::allocDenseTensor(OpBuilder &builder, Location loc, |
| 390 | RankedTensorType tensorTp, |
| 391 | ValueRange sizes) { |
| 392 | Type elemTp = tensorTp.getElementType(); |
| 393 | auto shape = tensorTp.getShape(); |
| 394 | auto memTp = MemRefType::get(shape, elementType: elemTp); |
| 395 | SmallVector<Value> dynamicSizes; |
| 396 | for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) { |
| 397 | if (shape[i] == ShapedType::kDynamic) |
| 398 | dynamicSizes.push_back(Elt: sizes[i]); |
| 399 | } |
| 400 | Value mem = builder.create<memref::AllocOp>(location: loc, args&: memTp, args&: dynamicSizes); |
| 401 | Value zero = constantZero(builder, loc, tp: elemTp); |
| 402 | builder.create<linalg::FillOp>(location: loc, args: ValueRange{zero}, args: ValueRange{mem}); |
| 403 | return mem; |
| 404 | } |
| 405 | |
| 406 | void mlir::sparse_tensor::deallocDenseTensor(OpBuilder &builder, Location loc, |
| 407 | Value buffer) { |
| 408 | builder.create<memref::DeallocOp>(location: loc, args&: buffer); |
| 409 | } |
| 410 | |
| 411 | void mlir::sparse_tensor::sizesFromSrc(OpBuilder &builder, |
| 412 | SmallVectorImpl<Value> &sizes, |
| 413 | Location loc, Value src) { |
| 414 | const Dimension dimRank = getSparseTensorType(val: src).getDimRank(); |
| 415 | for (Dimension d = 0; d < dimRank; d++) |
| 416 | sizes.push_back(Elt: linalg::createOrFoldDimOp(b&: builder, loc, val: src, dim: d)); |
| 417 | } |
| 418 | |
| 419 | Operation *mlir::sparse_tensor::getTop(Operation *op) { |
| 420 | for (; isa<scf::ForOp>(Val: op->getParentOp()) || |
| 421 | isa<scf::WhileOp>(Val: op->getParentOp()) || |
| 422 | isa<scf::ParallelOp>(Val: op->getParentOp()) || |
| 423 | isa<scf::IfOp>(Val: op->getParentOp()); |
| 424 | op = op->getParentOp()) |
| 425 | ; |
| 426 | return op; |
| 427 | } |
| 428 | |
| 429 | void sparse_tensor::foreachInSparseConstant( |
| 430 | OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order, |
| 431 | function_ref<void(ArrayRef<Value>, Value)> callback) { |
| 432 | if (!order) |
| 433 | order = builder.getMultiDimIdentityMap(rank: attr.getType().getRank()); |
| 434 | |
| 435 | auto stt = SparseTensorType(getRankedTensorType(t&: attr)); |
| 436 | const Dimension dimRank = stt.getDimRank(); |
| 437 | const auto coordinates = attr.getIndices().getValues<IntegerAttr>(); |
| 438 | const auto values = attr.getValues().getValues<Attribute>(); |
| 439 | |
| 440 | // This is like the `Element<V>` class in the runtime library, but for |
| 441 | // MLIR attributes. In the future we may want to move this out into |
| 442 | // a proper class definition to help improve code legibility (e.g., |
| 443 | // `first` -> `coords`, `second` -> `value`) as well as being able |
| 444 | // to factor out analogues of `ElementLT<V>` for the sort below, etc. |
| 445 | using ElementAttr = std::pair<SmallVector<IntegerAttr>, Attribute>; |
| 446 | |
| 447 | // Construct the COO from the SparseElementsAttr. |
| 448 | SmallVector<ElementAttr> elems; |
| 449 | for (size_t i = 0, nse = values.size(); i < nse; i++) { |
| 450 | elems.emplace_back(); |
| 451 | elems.back().second = values[i]; |
| 452 | auto &coords = elems.back().first; |
| 453 | coords.reserve(N: dimRank); |
| 454 | for (Dimension d = 0; d < dimRank; d++) |
| 455 | coords.push_back(Elt: coordinates[i * dimRank + d]); |
| 456 | } |
| 457 | |
| 458 | // Sorts the sparse element attribute based on coordinates. |
| 459 | llvm::sort(C&: elems, Comp: [order](const ElementAttr &lhs, const ElementAttr &rhs) { |
| 460 | if (std::addressof(r: lhs) == std::addressof(r: rhs)) |
| 461 | return false; |
| 462 | |
| 463 | auto lhsCoords = llvm::map_to_vector( |
| 464 | C: lhs.first, F: [](IntegerAttr i) { return i.getInt(); }); |
| 465 | auto rhsCoords = llvm::map_to_vector( |
| 466 | C: rhs.first, F: [](IntegerAttr i) { return i.getInt(); }); |
| 467 | |
| 468 | SmallVector<int64_t, 4> lhsLvlCrds = order.compose(values: lhsCoords); |
| 469 | SmallVector<int64_t, 4> rhsLvlCrds = order.compose(values: rhsCoords); |
| 470 | // Sort the element based on the lvl coordinates. |
| 471 | for (Level l = 0; l < order.getNumResults(); l++) { |
| 472 | if (lhsLvlCrds[l] == rhsLvlCrds[l]) |
| 473 | continue; |
| 474 | return lhsLvlCrds[l] < rhsLvlCrds[l]; |
| 475 | } |
| 476 | llvm_unreachable("no equal coordinate in sparse element attr" ); |
| 477 | }); |
| 478 | |
| 479 | SmallVector<Value> cvs; |
| 480 | cvs.reserve(N: dimRank); |
| 481 | for (size_t i = 0, nse = values.size(); i < nse; i++) { |
| 482 | // Remap coordinates. |
| 483 | cvs.clear(); |
| 484 | for (Dimension d = 0; d < dimRank; d++) { |
| 485 | auto crd = elems[i].first[d].getInt(); |
| 486 | cvs.push_back(Elt: builder.create<arith::ConstantIndexOp>(location: loc, args&: crd)); |
| 487 | } |
| 488 | // Remap value. |
| 489 | Value val; |
| 490 | if (isa<ComplexType>(Val: attr.getElementType())) { |
| 491 | auto valAttr = cast<ArrayAttr>(Val&: elems[i].second); |
| 492 | val = builder.create<complex::ConstantOp>(location: loc, args: attr.getElementType(), |
| 493 | args&: valAttr); |
| 494 | } else { |
| 495 | auto valAttr = cast<TypedAttr>(Val&: elems[i].second); |
| 496 | val = builder.create<arith::ConstantOp>(location: loc, args&: valAttr); |
| 497 | } |
| 498 | assert(val); |
| 499 | callback(cvs, val); |
| 500 | } |
| 501 | } |
| 502 | |
| 503 | SmallVector<Value> sparse_tensor::loadAll(OpBuilder &builder, Location loc, |
| 504 | size_t size, Value mem, |
| 505 | size_t offsetIdx, Value offsetVal) { |
| 506 | #ifndef NDEBUG |
| 507 | const auto memTp = cast<MemRefType>(mem.getType()); |
| 508 | assert(memTp.getRank() == 1); |
| 509 | const Size memSh = memTp.getDimSize(0); |
| 510 | assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(size)); |
| 511 | assert(offsetIdx == 0 || offsetIdx < size); |
| 512 | #endif // NDEBUG |
| 513 | SmallVector<Value> vs; |
| 514 | vs.reserve(N: size); |
| 515 | for (unsigned i = 0; i < size; i++) { |
| 516 | Value v = builder.create<memref::LoadOp>(location: loc, args&: mem, |
| 517 | args: constantIndex(builder, loc, i)); |
| 518 | if (i == offsetIdx && offsetVal) |
| 519 | v = builder.create<arith::AddIOp>(location: loc, args&: v, args&: offsetVal); |
| 520 | vs.push_back(Elt: v); |
| 521 | } |
| 522 | return vs; |
| 523 | } |
| 524 | |
| 525 | void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem, |
| 526 | ValueRange vs, size_t offsetIdx, Value offsetVal) { |
| 527 | #ifndef NDEBUG |
| 528 | const size_t vsize = vs.size(); |
| 529 | const auto memTp = cast<MemRefType>(mem.getType()); |
| 530 | assert(memTp.getRank() == 1); |
| 531 | const Size memSh = memTp.getDimSize(0); |
| 532 | assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(vsize)); |
| 533 | assert(offsetIdx == 0 || offsetIdx < vsize); |
| 534 | #endif // NDEBUG |
| 535 | for (const auto &v : llvm::enumerate(First&: vs)) { |
| 536 | const Value w = |
| 537 | (offsetIdx == v.index() && offsetVal) |
| 538 | ? builder.create<arith::AddIOp>(location: loc, args&: v.value(), args&: offsetVal) |
| 539 | : v.value(); |
| 540 | builder.create<memref::StoreOp>(location: loc, args: w, args&: mem, |
| 541 | args: constantIndex(builder, loc, i: v.index())); |
| 542 | } |
| 543 | } |
| 544 | |
| 545 | TypedValue<BaseMemRefType> |
| 546 | sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) { |
| 547 | auto tTp = llvm::cast<TensorType>(Val: tensor.getType()); |
| 548 | auto mTp = MemRefType::get(shape: tTp.getShape(), elementType: tTp.getElementType()); |
| 549 | return cast<TypedValue<BaseMemRefType>>( |
| 550 | Val: builder.create<bufferization::ToBufferOp>(location: loc, args&: mTp, args&: tensor).getResult()); |
| 551 | } |
| 552 | |
| 553 | Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, |
| 554 | Value tensor, Dimension dim) { |
| 555 | auto enc = getSparseTensorEncoding(type: tensor.getType()); |
| 556 | assert(enc && enc.isSlice()); |
| 557 | std::optional<unsigned> offset = enc.getStaticDimSliceOffset(dim); |
| 558 | if (offset.has_value()) |
| 559 | return constantIndex(builder, loc, i: *offset); |
| 560 | return builder.create<ToSliceOffsetOp>(location: loc, args&: tensor, args: APInt(64, dim)); |
| 561 | } |
| 562 | |
| 563 | Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder &builder, Location loc, |
| 564 | Value tensor, Dimension dim) { |
| 565 | auto enc = getSparseTensorEncoding(type: tensor.getType()); |
| 566 | assert(enc && enc.isSlice()); |
| 567 | std::optional<unsigned> stride = enc.getStaticDimSliceStride(dim); |
| 568 | if (stride.has_value()) |
| 569 | return constantIndex(builder, loc, i: *stride); |
| 570 | return builder.create<ToSliceStrideOp>(location: loc, args&: tensor, args: APInt(64, dim)); |
| 571 | } |
| 572 | |
| 573 | Value sparse_tensor::genReader(OpBuilder &builder, Location loc, |
| 574 | SparseTensorType stt, Value tensor, |
| 575 | /*out*/ SmallVectorImpl<Value> &dimSizesValues, |
| 576 | /*out*/ Value &dimSizesBuffer) { |
| 577 | // Construct the dimension **shapes** buffer. The buffer contains the static |
| 578 | // size per dimension, or otherwise a zero for a dynamic size. |
| 579 | Dimension dimRank = stt.getDimRank(); |
| 580 | dimSizesValues.clear(); |
| 581 | dimSizesValues.reserve(N: dimRank); |
| 582 | for (const Size sz : stt.getDimShape()) { |
| 583 | const auto s = ShapedType::isDynamic(dValue: sz) ? 0 : sz; |
| 584 | dimSizesValues.push_back(Elt: constantIndex(builder, loc, i: s)); |
| 585 | } |
| 586 | Value dimShapesBuffer = allocaBuffer(builder, loc, values: dimSizesValues); |
| 587 | // Create the `CheckedSparseTensorReader`. This reader performs a |
| 588 | // consistency check on the static sizes, but accepts any size |
| 589 | // of each dimension with a dynamic size. |
| 590 | Type opaqueTp = getOpaquePointerType(builder); |
| 591 | Type eltTp = stt.getElementType(); |
| 592 | Value valTp = constantPrimaryTypeEncoding(builder, loc, elemTp: eltTp); |
| 593 | Value reader = |
| 594 | createFuncCall(builder, loc, name: "createCheckedSparseTensorReader" , resultType: opaqueTp, |
| 595 | operands: {tensor, dimShapesBuffer, valTp}, emitCInterface: EmitCInterface::On) |
| 596 | .getResult(i: 0); |
| 597 | // For static shapes, the shape buffer can be used right away. For dynamic |
| 598 | // shapes, use the information from the reader to construct a buffer that |
| 599 | // supplies the actual size for each dynamic dimension. |
| 600 | dimSizesBuffer = dimShapesBuffer; |
| 601 | if (stt.hasDynamicDimShape()) { |
| 602 | Type indexTp = builder.getIndexType(); |
| 603 | auto memTp = MemRefType::get(shape: {ShapedType::kDynamic}, elementType: indexTp); |
| 604 | dimSizesBuffer = |
| 605 | createFuncCall(builder, loc, name: "getSparseTensorReaderDimSizes" , resultType: memTp, |
| 606 | operands: reader, emitCInterface: EmitCInterface::On) |
| 607 | .getResult(i: 0); |
| 608 | // Also convert the dim shapes values into dim sizes values, just in case |
| 609 | // subsequent clients need the values (DCE will remove unused). |
| 610 | for (Dimension d = 0; d < dimRank; d++) { |
| 611 | if (stt.isDynamicDim(d)) |
| 612 | dimSizesValues[d] = builder.create<memref::LoadOp>( |
| 613 | location: loc, args&: dimSizesBuffer, args: constantIndex(builder, loc, i: d)); |
| 614 | } |
| 615 | } |
| 616 | return reader; |
| 617 | } |
| 618 | |
| 619 | Value sparse_tensor::genMapBuffers( |
| 620 | OpBuilder &builder, Location loc, SparseTensorType stt, |
| 621 | ArrayRef<Value> dimSizesValues, Value dimSizesBuffer, |
| 622 | /*out*/ SmallVectorImpl<Value> &lvlSizesValues, |
| 623 | /*out*/ Value &dim2lvlBuffer, |
| 624 | /*out*/ Value &lvl2dimBuffer) { |
| 625 | const Dimension dimRank = stt.getDimRank(); |
| 626 | const Level lvlRank = stt.getLvlRank(); |
| 627 | lvlSizesValues.clear(); |
| 628 | lvlSizesValues.reserve(N: lvlRank); |
| 629 | // For an identity mapping, the dim2lvl and lvl2dim mappings are |
| 630 | // identical as are dimSizes and lvlSizes, so buffers are reused |
| 631 | // as much as possible. |
| 632 | if (stt.isIdentity()) { |
| 633 | assert(dimRank == lvlRank); |
| 634 | SmallVector<Value> iotaValues; |
| 635 | iotaValues.reserve(N: lvlRank); |
| 636 | for (Level l = 0; l < lvlRank; l++) { |
| 637 | iotaValues.push_back(Elt: constantIndex(builder, loc, i: l)); |
| 638 | lvlSizesValues.push_back(Elt: dimSizesValues[l]); |
| 639 | } |
| 640 | dim2lvlBuffer = lvl2dimBuffer = allocaBuffer(builder, loc, values: iotaValues); |
| 641 | return dimSizesBuffer; // now lvlSizesBuffer |
| 642 | } |
| 643 | // Otherwise, some code needs to be generated to set up the buffers. |
| 644 | // This code deals with permutations as well as non-permutations that |
| 645 | // arise from rank changing blocking. |
| 646 | const auto dimToLvl = stt.getDimToLvl(); |
| 647 | const auto lvlToDim = stt.getLvlToDim(); |
| 648 | SmallVector<Value> dim2lvlValues(lvlRank); // for each lvl, expr in dim vars |
| 649 | SmallVector<Value> lvl2dimValues(dimRank); // for each dim, expr in lvl vars |
| 650 | // Generate dim2lvl. |
| 651 | assert(lvlRank == dimToLvl.getNumResults()); |
| 652 | for (Level l = 0; l < lvlRank; l++) { |
| 653 | AffineExpr exp = dimToLvl.getResult(idx: l); |
| 654 | // We expect: |
| 655 | // (1) l = d |
| 656 | // (2) l = d / c |
| 657 | // (3) l = d % c |
| 658 | Dimension d = 0; |
| 659 | uint64_t cf = 0, cm = 0; |
| 660 | switch (exp.getKind()) { |
| 661 | case AffineExprKind::DimId: { |
| 662 | d = cast<AffineDimExpr>(Val&: exp).getPosition(); |
| 663 | break; |
| 664 | } |
| 665 | case AffineExprKind::FloorDiv: { |
| 666 | auto floor = cast<AffineBinaryOpExpr>(Val&: exp); |
| 667 | d = cast<AffineDimExpr>(Val: floor.getLHS()).getPosition(); |
| 668 | cf = cast<AffineConstantExpr>(Val: floor.getRHS()).getValue(); |
| 669 | break; |
| 670 | } |
| 671 | case AffineExprKind::Mod: { |
| 672 | auto mod = cast<AffineBinaryOpExpr>(Val&: exp); |
| 673 | d = cast<AffineDimExpr>(Val: mod.getLHS()).getPosition(); |
| 674 | cm = cast<AffineConstantExpr>(Val: mod.getRHS()).getValue(); |
| 675 | break; |
| 676 | } |
| 677 | default: |
| 678 | llvm::report_fatal_error(reason: "unsupported dim2lvl in sparse tensor type" ); |
| 679 | } |
| 680 | dim2lvlValues[l] = constantIndex(builder, loc, i: encodeDim(i: d, cf, cm)); |
| 681 | // Compute the level sizes. |
| 682 | // (1) l = d : size(d) |
| 683 | // (2) l = d / c : size(d) / c |
| 684 | // (3) l = d % c : c |
| 685 | Value lvlSz; |
| 686 | if (cm == 0) { |
| 687 | lvlSz = dimSizesValues[d]; |
| 688 | if (cf != 0) |
| 689 | lvlSz = builder.create<arith::DivUIOp>(location: loc, args&: lvlSz, |
| 690 | args: constantIndex(builder, loc, i: cf)); |
| 691 | } else { |
| 692 | lvlSz = constantIndex(builder, loc, i: cm); |
| 693 | } |
| 694 | lvlSizesValues.push_back(Elt: lvlSz); |
| 695 | } |
| 696 | // Generate lvl2dim. |
| 697 | assert(dimRank == lvlToDim.getNumResults()); |
| 698 | for (Dimension d = 0; d < dimRank; d++) { |
| 699 | AffineExpr exp = lvlToDim.getResult(idx: d); |
| 700 | // We expect: |
| 701 | // (1) d = l |
| 702 | // (2) d = l' * c + l |
| 703 | Level l = 0, ll = 0; |
| 704 | uint64_t c = 0; |
| 705 | switch (exp.getKind()) { |
| 706 | case AffineExprKind::DimId: { |
| 707 | l = cast<AffineDimExpr>(Val&: exp).getPosition(); |
| 708 | break; |
| 709 | } |
| 710 | case AffineExprKind::Add: { |
| 711 | // Always mul on lhs, symbol/constant on rhs. |
| 712 | auto add = cast<AffineBinaryOpExpr>(Val&: exp); |
| 713 | assert(add.getLHS().getKind() == AffineExprKind::Mul); |
| 714 | auto mul = cast<AffineBinaryOpExpr>(Val: add.getLHS()); |
| 715 | ll = cast<AffineDimExpr>(Val: mul.getLHS()).getPosition(); |
| 716 | c = cast<AffineConstantExpr>(Val: mul.getRHS()).getValue(); |
| 717 | l = cast<AffineDimExpr>(Val: add.getRHS()).getPosition(); |
| 718 | break; |
| 719 | } |
| 720 | default: |
| 721 | llvm::report_fatal_error(reason: "unsupported lvl2dim in sparse tensor type" ); |
| 722 | } |
| 723 | lvl2dimValues[d] = constantIndex(builder, loc, i: encodeLvl(i: l, c, ii: ll)); |
| 724 | } |
| 725 | // Return buffers. |
| 726 | dim2lvlBuffer = allocaBuffer(builder, loc, values: dim2lvlValues); |
| 727 | lvl2dimBuffer = allocaBuffer(builder, loc, values: lvl2dimValues); |
| 728 | return allocaBuffer(builder, loc, values: lvlSizesValues); // lvlSizesBuffer |
| 729 | } |
| 730 | |