| 1 | //===- SparseGPUCodegen.cpp - Generates GPU code --------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This is a prototype GPU codegenerator for the sparsifier. |
| 10 | // The objective is to eventually use the right combination of |
| 11 | // direct code generation and libary calls into vendor-specific |
| 12 | // highly optimized sparse libraries (e.g. cuSparse for CUDA). |
| 13 | // |
| 14 | //===----------------------------------------------------------------------===// |
| 15 | |
| 16 | #include "Utils/CodegenUtils.h" |
| 17 | #include "Utils/LoopEmitter.h" |
| 18 | |
| 19 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
| 20 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| 21 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 22 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
| 23 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 24 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 25 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
| 26 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" |
| 27 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
| 28 | #include "mlir/IR/IRMapping.h" |
| 29 | #include "mlir/IR/Matchers.h" |
| 30 | |
| 31 | using namespace mlir; |
| 32 | using namespace mlir::sparse_tensor; |
| 33 | |
| 34 | namespace { |
| 35 | |
/// Sparse storage formats known to cuSPARSE. `kNone` indicates that no
/// suitable cuSPARSE format is available for the given operand types.
enum class CuSparseFormat {
  kNone, // no admissible format
  kCOO,  // coordinate format (SoA via runtime lib, AoS via direct IR)
  kCSR,  // compressed sparse row
  kCSC,  // compressed sparse column
  kBSR,  // block sparse row (square blocks only)
};
| 44 | |
| 45 | //===----------------------------------------------------------------------===// |
| 46 | // Helper methods. |
| 47 | //===----------------------------------------------------------------------===// |
| 48 | |
| 49 | /// Marks the given top module as a GPU container module. |
| 50 | static void markAsGPUContainer(ModuleOp topModule) { |
| 51 | topModule->setAttr(name: gpu::GPUDialect::getContainerModuleAttrName(), |
| 52 | value: UnitAttr::get(context: topModule->getContext())); |
| 53 | } |
| 54 | |
| 55 | /// Constructs a new GPU module (for GPU kernels) inside the given top module, |
| 56 | /// or returns an existing GPU module if one was built previously. |
| 57 | static gpu::GPUModuleOp genGPUModule(OpBuilder &builder, ModuleOp topModule) { |
| 58 | for (auto op : topModule.getBodyRegion().getOps<gpu::GPUModuleOp>()) |
| 59 | return op; // existing |
| 60 | markAsGPUContainer(topModule); |
| 61 | builder.setInsertionPointToStart(topModule.getBody()); |
| 62 | return builder.create<gpu::GPUModuleOp>(location: topModule->getLoc(), |
| 63 | args: "sparse_kernels" ); |
| 64 | } |
| 65 | |
| 66 | /// Constructs a new GPU kernel in the given GPU module. |
| 67 | static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule, |
| 68 | SmallVectorImpl<Value> &args) { |
| 69 | // Get a unique kernel name. Not very creative, |
| 70 | // but we simply try kernel0, kernel1, etc. |
| 71 | unsigned kernelNumber = 0; |
| 72 | SmallString<16> kernelName; |
| 73 | do { |
| 74 | kernelName.clear(); |
| 75 | ("kernel" + Twine(kernelNumber++)).toStringRef(Out&: kernelName); |
| 76 | } while (gpuModule.lookupSymbol(name: kernelName)); |
| 77 | // Then we insert a new kernel with given arguments into the module. |
| 78 | builder.setInsertionPointToStart(gpuModule.getBody()); |
| 79 | SmallVector<Type> argsTp; |
| 80 | for (auto arg : args) |
| 81 | argsTp.push_back(Elt: arg.getType()); |
| 82 | FunctionType type = FunctionType::get(context: gpuModule->getContext(), inputs: argsTp, results: {}); |
| 83 | auto gpuFunc = |
| 84 | builder.create<gpu::GPUFuncOp>(location: gpuModule->getLoc(), args&: kernelName, args&: type); |
| 85 | gpuFunc->setAttr(name: gpu::GPUDialect::getKernelFuncAttrName(), |
| 86 | value: builder.getUnitAttr()); |
| 87 | return gpuFunc; |
| 88 | } |
| 89 | |
| 90 | /// Constructs code to launch GPU kernel. |
| 91 | static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc, |
| 92 | SmallVectorImpl<Value> &args, |
| 93 | SmallVectorImpl<Value> &tokens, |
| 94 | unsigned numThreads) { |
| 95 | Location loc = gpuFunc->getLoc(); |
| 96 | Value none = TypedValue<::mlir::IntegerType>{}; |
| 97 | Value one = constantIndex(builder, loc, i: 1); |
| 98 | Value numT = constantIndex(builder, loc, i: numThreads); |
| 99 | gpu::KernelDim3 gridSize = {.x: one, .y: one, .z: one}; |
| 100 | gpu::KernelDim3 blckSize = {.x: numT, .y: one, .z: one}; |
| 101 | return builder |
| 102 | .create<gpu::LaunchFuncOp>(location: loc, args&: gpuFunc, args&: gridSize, args&: blckSize, |
| 103 | /*dynSharedMemSz*/ args&: none, args, |
| 104 | args: builder.getType<gpu::AsyncTokenType>(), args&: tokens) |
| 105 | .getAsyncToken(); |
| 106 | } |
| 107 | |
| 108 | /// Maps the provided ranked host buffer into the device address space. |
| 109 | /// Writes from the host are guaranteed to be visible to device kernels |
| 110 | /// that are launched afterwards. Writes from the device are guaranteed |
| 111 | /// to be visible on the host after synchronizing with the device kernel |
| 112 | /// completion. Needs to cast the buffer to a unranked buffer. |
| 113 | static Value genHostRegisterMemref(OpBuilder &builder, Location loc, |
| 114 | Value mem) { |
| 115 | MemRefType memTp = cast<MemRefType>(Val: mem.getType()); |
| 116 | UnrankedMemRefType resTp = |
| 117 | UnrankedMemRefType::get(elementType: memTp.getElementType(), /*memorySpace=*/0); |
| 118 | Value cast = builder.create<memref::CastOp>(location: loc, args&: resTp, args&: mem); |
| 119 | builder.create<gpu::HostRegisterOp>(location: loc, args&: cast); |
| 120 | return cast; |
| 121 | } |
| 122 | |
| 123 | /// Unmaps the provided buffer, expecting the casted buffer. |
| 124 | static void genHostUnregisterMemref(OpBuilder &builder, Location loc, |
| 125 | Value cast) { |
| 126 | builder.create<gpu::HostUnregisterOp>(location: loc, args&: cast); |
| 127 | } |
| 128 | |
| 129 | /// Generates first wait in an asynchronous chain. |
| 130 | static Value genFirstWait(OpBuilder &builder, Location loc) { |
| 131 | Type tokenType = builder.getType<gpu::AsyncTokenType>(); |
| 132 | return builder.create<gpu::WaitOp>(location: loc, args&: tokenType, args: ValueRange()) |
| 133 | .getAsyncToken(); |
| 134 | } |
| 135 | |
| 136 | /// Generates last, blocking wait in an asynchronous chain. |
| 137 | static void genBlockingWait(OpBuilder &builder, Location loc, |
| 138 | ValueRange operands) { |
| 139 | builder.create<gpu::WaitOp>(location: loc, args: Type(), args&: operands); |
| 140 | } |
| 141 | |
| 142 | /// Allocates memory on the device. |
| 143 | /// TODO: A `host_shared` attribute could be used to indicate that |
| 144 | /// the buffer is visible by both host and device, but lowering |
| 145 | /// that feature does not seem to be fully supported yet. |
| 146 | static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem, |
| 147 | Value token) { |
| 148 | auto tp = cast<ShapedType>(Val: mem.getType()); |
| 149 | auto elemTp = tp.getElementType(); |
| 150 | auto shape = tp.getShape(); |
| 151 | auto memTp = MemRefType::get(shape, elementType: elemTp); |
| 152 | SmallVector<Value> dynamicSizes; |
| 153 | for (unsigned r = 0, rank = tp.getRank(); r < rank; r++) { |
| 154 | if (shape[r] == ShapedType::kDynamic) { |
| 155 | Value dimOp = linalg::createOrFoldDimOp(b&: builder, loc, val: mem, dim: r); |
| 156 | dynamicSizes.push_back(Elt: dimOp); |
| 157 | } |
| 158 | } |
| 159 | return builder.create<gpu::AllocOp>(location: loc, args: TypeRange({memTp, token.getType()}), |
| 160 | args&: token, args&: dynamicSizes, args: ValueRange()); |
| 161 | } |
| 162 | |
| 163 | // Allocates a typed buffer on the host with given size. |
| 164 | static Value genHostBuffer(OpBuilder &builder, Location loc, Type type, |
| 165 | Value size) { |
| 166 | const auto memTp = MemRefType::get(shape: {ShapedType::kDynamic}, elementType: type); |
| 167 | return builder.create<memref::AllocOp>(location: loc, args: memTp, args&: size).getResult(); |
| 168 | } |
| 169 | |
| 170 | // Allocates a typed buffer on the device with given size. |
| 171 | static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Type type, |
| 172 | Value size, Value token) { |
| 173 | const auto memTp = MemRefType::get(shape: {ShapedType::kDynamic}, elementType: type); |
| 174 | return builder.create<gpu::AllocOp>(location: loc, args: TypeRange({memTp, token.getType()}), |
| 175 | args&: token, args&: size, args: ValueRange()); |
| 176 | } |
| 177 | |
| 178 | // Allocates a void buffer on the device with given size. |
| 179 | static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size, |
| 180 | Value token) { |
| 181 | return genAllocBuffer(builder, loc, type: builder.getI8Type(), size, token); |
| 182 | } |
| 183 | |
| 184 | /// Deallocates memory from the device. |
| 185 | static Value genDeallocMemRef(OpBuilder &builder, Location loc, Value mem, |
| 186 | Value token) { |
| 187 | return builder.create<gpu::DeallocOp>(location: loc, args: token.getType(), args&: token, args&: mem) |
| 188 | .getAsyncToken(); |
| 189 | } |
| 190 | |
| 191 | /// Copies memory between host and device (direction is implicit). |
| 192 | static Value genCopyMemRef(OpBuilder &builder, Location loc, Value dst, |
| 193 | Value src, Value token) { |
| 194 | return builder.create<gpu::MemcpyOp>(location: loc, args: token.getType(), args&: token, args&: dst, args&: src) |
| 195 | .getAsyncToken(); |
| 196 | } |
| 197 | |
| 198 | /// Generates an alloc/copy pair. |
| 199 | static Value genAllocCopy(OpBuilder &builder, Location loc, Value b, |
| 200 | SmallVectorImpl<Value> &tokens) { |
| 201 | Value firstToken = genFirstWait(builder, loc); |
| 202 | auto alloc = genAllocMemRef(builder, loc, mem: b, token: firstToken); |
| 203 | Value devMem = alloc.getResult(i: 0); |
| 204 | Value depToken = alloc.getAsyncToken(); // copy-after-alloc |
| 205 | tokens.push_back(Elt: genCopyMemRef(builder, loc, dst: devMem, src: b, token: depToken)); |
| 206 | return devMem; |
| 207 | } |
| 208 | |
| 209 | /// Generates a memref from tensor operation. |
| 210 | static Value genTensorToMemref(PatternRewriter &rewriter, Location loc, |
| 211 | Value tensor) { |
| 212 | auto tensorType = llvm::cast<ShapedType>(Val: tensor.getType()); |
| 213 | auto memrefType = |
| 214 | MemRefType::get(shape: tensorType.getShape(), elementType: tensorType.getElementType()); |
| 215 | return rewriter.create<bufferization::ToBufferOp>(location: loc, args&: memrefType, args&: tensor); |
| 216 | } |
| 217 | |
| 218 | /// Prepares the outlined arguments, passing scalars and buffers in. Here we |
| 219 | /// assume that the first buffer is the one allocated for output. We create |
| 220 | /// a set of properly chained asynchronous allocation/copy pairs to increase |
| 221 | /// overlap before launching the kernel. |
| 222 | static Value genParametersIn(OpBuilder &builder, Location loc, |
| 223 | SmallVectorImpl<Value> &scalars, |
| 224 | SmallVectorImpl<Value> &buffers, |
| 225 | SmallVectorImpl<Value> &args, |
| 226 | SmallVectorImpl<Value> &tokens, |
| 227 | bool useHostRegistrationForOut) { |
| 228 | Value out; |
| 229 | // Scalars are passed by value. |
| 230 | for (Value s : scalars) |
| 231 | args.push_back(Elt: s); |
| 232 | // Buffers are need to be made visible on device. |
| 233 | for (Value b : buffers) { |
| 234 | if (useHostRegistrationForOut) { |
| 235 | out = genHostRegisterMemref(builder, loc, mem: b); |
| 236 | args.push_back(Elt: b); |
| 237 | useHostRegistrationForOut = false; |
| 238 | continue; |
| 239 | } |
| 240 | args.push_back(Elt: genAllocCopy(builder, loc, b, tokens)); |
| 241 | } |
| 242 | return out; |
| 243 | } |
| 244 | |
| 245 | /// Finalizes the outlined arguments. The output buffer is copied depending |
| 246 | /// on the kernel token and then deallocated. All other buffers are simply |
| 247 | /// deallocated. Then we wait for all operations to complete. |
| 248 | static void genParametersOut(OpBuilder &builder, Location loc, Value out, |
| 249 | Value kernelToken, SmallVectorImpl<Value> &scalars, |
| 250 | SmallVectorImpl<Value> &buffers, |
| 251 | SmallVectorImpl<Value> &args, |
| 252 | SmallVectorImpl<Value> &tokens) { |
| 253 | unsigned base = scalars.size(); |
| 254 | for (unsigned i = base, e = args.size(); i < e; i++) { |
| 255 | Value firstToken; |
| 256 | if (i == base) { |
| 257 | // Assumed output parameter: unregister or copy-out. |
| 258 | if (out) { |
| 259 | genHostUnregisterMemref(builder, loc, cast: out); |
| 260 | out = Value(); |
| 261 | continue; |
| 262 | } |
| 263 | firstToken = |
| 264 | genCopyMemRef(builder, loc, dst: buffers[0], src: args[i], token: kernelToken); |
| 265 | } else { |
| 266 | firstToken = genFirstWait(builder, loc); |
| 267 | } |
| 268 | tokens.push_back(Elt: genDeallocMemRef(builder, loc, mem: args[i], token: firstToken)); |
| 269 | } |
| 270 | } |
| 271 | |
| 272 | /// Constructs code for new GPU kernel. |
| 273 | static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc, |
| 274 | scf::ParallelOp forallOp, |
| 275 | SmallVectorImpl<Value> &constants, |
| 276 | SmallVectorImpl<Value> &scalars, |
| 277 | SmallVectorImpl<Value> &buffers) { |
| 278 | Location loc = gpuFunc->getLoc(); |
| 279 | Block &block = gpuFunc.getBody().front(); |
| 280 | rewriter.setInsertionPointToStart(&block); |
| 281 | |
| 282 | // Re-generate the constants, recapture all arguments. |
| 283 | unsigned arg = 0; |
| 284 | IRMapping irMap; |
| 285 | for (Value c : constants) |
| 286 | irMap.map(from: c, to: rewriter.clone(op&: *c.getDefiningOp())->getResult(idx: 0)); |
| 287 | for (Value s : scalars) |
| 288 | irMap.map(from: s, to: block.getArgument(i: arg++)); |
| 289 | for (Value b : buffers) |
| 290 | irMap.map(from: b, to: block.getArgument(i: arg++)); |
| 291 | |
| 292 | // Assume 1-dimensional grid/block configuration (only x dimension), |
| 293 | // so that: |
| 294 | // row = blockIdx.x * blockDim.x + threadIdx.x |
| 295 | // inc = blockDim.x * gridDim.x |
| 296 | Value bid = rewriter.create<gpu::BlockIdOp>(location: loc, args: gpu::Dimension::x); |
| 297 | Value bsz = rewriter.create<gpu::BlockDimOp>(location: loc, args: gpu::Dimension::x); |
| 298 | Value tid = rewriter.create<gpu::ThreadIdOp>(location: loc, args: gpu::Dimension::x); |
| 299 | Value gsz = rewriter.create<gpu::GridDimOp>(location: loc, args: gpu::Dimension::x); |
| 300 | Value mul = rewriter.create<arith::MulIOp>(location: loc, args&: bid, args&: bsz); |
| 301 | Value row = rewriter.create<arith::AddIOp>(location: loc, args&: mul, args&: tid); |
| 302 | Value inc = rewriter.create<arith::MulIOp>(location: loc, args&: bsz, args&: gsz); |
| 303 | |
| 304 | // Construct the iteration over the computational space that |
| 305 | // accounts for the fact that the total number of threads and |
| 306 | // the amount of work to be done usually do not match precisely. |
| 307 | // for (r = row; r < N; r += inc) { |
| 308 | // <loop-body> |
| 309 | // } |
| 310 | Value upper = irMap.lookup(from: forallOp.getUpperBound()[0]); |
| 311 | scf::ForOp forOp = rewriter.create<scf::ForOp>(location: loc, args&: row, args&: upper, args&: inc); |
| 312 | // The scf.for builder creates an empty block. scf.for does not allow multiple |
| 313 | // blocks in its region, so delete the block before `cloneRegionBefore` adds |
| 314 | // an additional block. |
| 315 | rewriter.eraseBlock(block: forOp.getBody()); |
| 316 | rewriter.cloneRegionBefore(region&: forallOp.getRegion(), parent&: forOp.getRegion(), |
| 317 | before: forOp.getRegion().begin(), mapping&: irMap); |
| 318 | // Replace the scf.reduce terminator. |
| 319 | rewriter.setInsertionPoint(forOp.getBody()->getTerminator()); |
| 320 | rewriter.replaceOpWithNewOp<scf::YieldOp>(op: forOp.getBody()->getTerminator()); |
| 321 | |
| 322 | // Done. |
| 323 | rewriter.setInsertionPointAfter(forOp); |
| 324 | rewriter.create<gpu::ReturnOp>(location: gpuFunc->getLoc()); |
| 325 | } |
| 326 | |
| 327 | //===----------------------------------------------------------------------===// |
| 328 | // Library helper methods. |
| 329 | //===----------------------------------------------------------------------===// |
| 330 | |
| 331 | /// Helper to detect a + b with arguments taken from given block. |
| 332 | static bool matchAddOfArgs(Block *block, Value val) { |
| 333 | if (auto *def = val.getDefiningOp()) { |
| 334 | if (isa<arith::AddFOp, arith::AddIOp>(Val: def)) { |
| 335 | Value a = block->getArguments()[0]; |
| 336 | Value b = block->getArguments()[1]; |
| 337 | return (def->getOperand(idx: 0) == a && def->getOperand(idx: 1) == b) || |
| 338 | (def->getOperand(idx: 0) == b && def->getOperand(idx: 1) == a); |
| 339 | } |
| 340 | } |
| 341 | return false; |
| 342 | } |
| 343 | |
| 344 | /// Helper to detect a * b with arguments taken from given block. |
| 345 | static bool matchMulOfArgs(Block *block, Value val) { |
| 346 | if (auto *def = val.getDefiningOp()) { |
| 347 | if (isa<arith::MulFOp, arith::MulIOp>(Val: def)) { |
| 348 | Value a = block->getArguments()[0]; |
| 349 | Value b = block->getArguments()[1]; |
| 350 | return (def->getOperand(idx: 0) == a && def->getOperand(idx: 1) == b) || |
| 351 | (def->getOperand(idx: 0) == b && def->getOperand(idx: 1) == a); |
| 352 | } |
| 353 | } |
| 354 | return false; |
| 355 | } |
| 356 | |
| 357 | /// Helper to detect x = x + a * b |
| 358 | static bool matchSumOfMultOfArgs(linalg::GenericOp op) { |
| 359 | auto yieldOp = cast<linalg::YieldOp>(Val: op.getRegion().front().getTerminator()); |
| 360 | if (auto *def = yieldOp.getOperand(i: 0).getDefiningOp()) { |
| 361 | if (isa<arith::AddFOp, arith::AddIOp>(Val: def)) { |
| 362 | Value x = op.getBlock()->getArguments()[2]; |
| 363 | return (def->getOperand(idx: 0) == x && |
| 364 | matchMulOfArgs(block: op.getBlock(), val: def->getOperand(idx: 1))) || |
| 365 | (def->getOperand(idx: 1) == x && |
| 366 | matchMulOfArgs(block: op.getBlock(), val: def->getOperand(idx: 0))); |
| 367 | } |
| 368 | } |
| 369 | return false; |
| 370 | } |
| 371 | |
| 372 | // Helper to detect c += spy(s) x (a * b) |
| 373 | static bool matchSumReductionOfMulUnary(linalg::GenericOp op) { |
| 374 | auto yieldOp = cast<linalg::YieldOp>(Val: op.getRegion().front().getTerminator()); |
| 375 | // The linalg yields a custom reduce result. |
| 376 | Value s_out = op.getBlock()->getArguments()[2]; |
| 377 | if (auto redOp = |
| 378 | yieldOp.getOperand(i: 0).getDefiningOp<sparse_tensor::ReduceOp>()) { |
| 379 | // The reduce consumes the output. |
| 380 | Value other; |
| 381 | if (s_out == redOp->getOperand(idx: 0)) |
| 382 | other = redOp->getOperand(idx: 1); |
| 383 | else if (s_out == redOp->getOperand(idx: 1)) |
| 384 | other = redOp->getOperand(idx: 0); |
| 385 | else |
| 386 | return false; |
| 387 | // The reduce op also consumes an unary which also consumes the output |
| 388 | // and does not define an absent value. |
| 389 | if (auto unOp = other.getDefiningOp<sparse_tensor::UnaryOp>()) { |
| 390 | if (s_out != unOp->getOperand(idx: 0) || !unOp.getAbsentRegion().empty()) |
| 391 | return false; |
| 392 | // And the bodies are as expected. |
| 393 | auto yieldUn = cast<sparse_tensor::YieldOp>( |
| 394 | Val: unOp.getRegion(i: 0).front().getTerminator()); |
| 395 | auto yieldRed = cast<sparse_tensor::YieldOp>( |
| 396 | Val: redOp.getRegion().front().getTerminator()); |
| 397 | return matchMulOfArgs(block: op.getBlock(), val: yieldUn.getOperand(i: 0)) && |
| 398 | matchAddOfArgs(block: &redOp.getRegion().front(), val: yieldRed.getOperand(i: 0)); |
| 399 | } |
| 400 | } |
| 401 | return false; |
| 402 | } |
| 403 | |
| 404 | /// Test for dense tensor. |
| 405 | static bool isDenseTensor(Value v) { |
| 406 | auto sTp = getSparseTensorType(val: v); |
| 407 | return sTp.getDimRank() == sTp.getLvlRank() && sTp.isAllDense(); |
| 408 | } |
| 409 | |
| 410 | /// Test for suitable positions/coordinates width. |
| 411 | static bool isAdmissibleMetaData(SparseTensorType &aTp) { |
| 412 | return (aTp.getPosWidth() == 0 || aTp.getPosWidth() >= 16) && |
| 413 | (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() >= 16); |
| 414 | } |
| 415 | |
| 416 | /// Test for sorted COO matrix with suitable metadata. |
| 417 | static bool isAdmissibleCOO(SparseTensorType &aTp) { |
| 418 | return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() && |
| 419 | aTp.isCompressedLvl(l: 0) && aTp.isOrderedLvl(l: 0) && !aTp.isUniqueLvl(l: 0) && |
| 420 | aTp.isSingletonLvl(l: 1) && aTp.isOrderedLvl(l: 1) && aTp.isUniqueLvl(l: 1) && |
| 421 | isAdmissibleMetaData(aTp); |
| 422 | } |
| 423 | |
| 424 | /// Test for CSR matrix with suitable metadata. |
| 425 | static bool isAdmissibleCSR(SparseTensorType &aTp) { |
| 426 | return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() && |
| 427 | aTp.isDenseLvl(l: 0) && aTp.isCompressedLvl(l: 1) && aTp.isOrderedLvl(l: 1) && |
| 428 | aTp.isUniqueLvl(l: 1) && isAdmissibleMetaData(aTp); |
| 429 | } |
| 430 | |
| 431 | /// Test for CSC matrix with suitable metadata. |
| 432 | static bool isAdmissibleCSC(SparseTensorType &aTp) { |
| 433 | return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && !aTp.isIdentity() && |
| 434 | aTp.isPermutation() && aTp.isDenseLvl(l: 0) && aTp.isCompressedLvl(l: 1) && |
| 435 | aTp.isOrderedLvl(l: 1) && aTp.isUniqueLvl(l: 1) && isAdmissibleMetaData(aTp); |
| 436 | } |
| 437 | |
| 438 | /// Test for BSR matrix with suitable metadata. |
| 439 | static bool isAdmissibleBSR(SparseTensorType &aTp) { |
| 440 | if (aTp.getDimRank() == 2 && aTp.getLvlRank() == 4 && aTp.isDenseLvl(l: 0) && |
| 441 | aTp.isCompressedLvl(l: 1) && aTp.isOrderedLvl(l: 1) && aTp.isUniqueLvl(l: 1) && |
| 442 | aTp.isDenseLvl(l: 2) && aTp.isDenseLvl(l: 3) && isAdmissibleMetaData(aTp)) { |
| 443 | // CuSparse only supports "square" blocks currently. |
| 444 | SmallVector<unsigned> dims = getBlockSize(dimToLvl: aTp.getDimToLvl()); |
| 445 | assert(dims.size() == 2); |
| 446 | return dims[0] == dims[1] && dims[0] > 1; |
| 447 | } |
| 448 | return false; |
| 449 | } |
| 450 | |
| 451 | /// Test for 2:4 matrix with suitable metadata. |
| 452 | static bool isAdmissible24(SparseTensorType &aTp) { |
| 453 | return aTp.getDimRank() == 2 && aTp.getLvlRank() == 3 && aTp.isDenseLvl(l: 0) && |
| 454 | aTp.isDenseLvl(l: 1) && aTp.isNOutOfMLvl(l: 2) && isAdmissibleMetaData(aTp); |
| 455 | } |
| 456 | |
| 457 | /// Test for conversion into 2:4 matrix. |
| 458 | static bool isConversionInto24(Value v) { |
| 459 | if (auto cnv = v.getDefiningOp<ConvertOp>()) { |
| 460 | Value a = cnv.getResult(); |
| 461 | Value d = cnv.getSource(); |
| 462 | SparseTensorType aTp = getSparseTensorType(val: a); |
| 463 | return isDenseTensor(v: d) && isAdmissible24(aTp); |
| 464 | } |
| 465 | return false; |
| 466 | } |
| 467 | |
| 468 | /// Returns a suitable sparse format for the operation and given operand |
| 469 | /// types with cuSparse, or kNone if none is available. |
| 470 | static CuSparseFormat getCuSparseFormat(SparseTensorType aTp, |
| 471 | SparseTensorType bTp, |
| 472 | SparseTensorType cTp, bool enableRT, |
| 473 | bool isMatVec) { |
| 474 | // The other operands have a dense type. |
| 475 | if (bTp.hasEncoding() || cTp.hasEncoding()) |
| 476 | return CuSparseFormat::kNone; |
| 477 | // Now check for suitable operand type for the main operand. |
| 478 | if (isAdmissibleCOO(aTp)) |
| 479 | #ifdef CUSPARSE_COO_AOS |
| 480 | return isMatVec ? CuSparseFormat::kCOO : CuSparseFormat::kNone; |
| 481 | #else |
| 482 | return enableRT ? CuSparseFormat::kCOO : CuSparseFormat::kNone; |
| 483 | #endif |
| 484 | if (isAdmissibleCSR(aTp)) |
| 485 | return CuSparseFormat::kCSR; |
| 486 | if (isAdmissibleCSC(aTp)) |
| 487 | return CuSparseFormat::kCSC; |
| 488 | if (isAdmissibleBSR(aTp)) |
| 489 | return CuSparseFormat::kBSR; |
| 490 | return CuSparseFormat::kNone; |
| 491 | } |
| 492 | |
| 493 | /// Generates the first positions/coordinates of a sparse matrix. |
| 494 | static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a, |
| 495 | CuSparseFormat format, bool enableRT) { |
| 496 | if (format == CuSparseFormat::kCOO) { |
| 497 | // Library uses SoA COO, direct IR uses AoS COO. |
| 498 | if (enableRT) |
| 499 | return builder.create<ToCoordinatesOp>(location: loc, args&: a, args: 0); |
| 500 | return builder.create<ToCoordinatesBufferOp>(location: loc, args&: a); |
| 501 | } |
| 502 | // Formats CSR/CSC and BSR use positions at 1. |
| 503 | return builder.create<ToPositionsOp>(location: loc, args&: a, args: 1); |
| 504 | } |
| 505 | |
| 506 | /// Generates the second coordinates of a sparse matrix. |
| 507 | static Value genSecondCrds(OpBuilder &builder, Location loc, Value a, |
| 508 | CuSparseFormat format, bool enableRT) { |
| 509 | bool isCOO = format == CuSparseFormat::kCOO; |
| 510 | if (isCOO && !enableRT) |
| 511 | return Value(); // nothing needed |
| 512 | // Formats CSR/CSC and BSR use coordinates at 1. |
| 513 | return builder.create<ToCoordinatesOp>(location: loc, args&: a, args: 1); |
| 514 | } |
| 515 | |
| 516 | /// Generates the sparse matrix handle. |
| 517 | static Operation *genSpMat(OpBuilder &builder, Location loc, |
| 518 | SparseTensorType &aTp, Type handleTp, Type tokenTp, |
| 519 | Value token, Value sz1, Value sz2, Value nseA, |
| 520 | Value rowA, Value colA, Value valA, |
| 521 | CuSparseFormat format, bool enableRT) { |
| 522 | if (format == CuSparseFormat::kCOO) { |
| 523 | // Library uses SoA COO, direct IR uses AoS COO. |
| 524 | if (enableRT) { |
| 525 | assert(colA); |
| 526 | return builder.create<gpu::CreateCooOp>(location: loc, args&: handleTp, args&: tokenTp, args&: token, |
| 527 | args&: sz1, args&: sz2, args&: nseA, args&: rowA, args&: colA, args&: valA); |
| 528 | } |
| 529 | #ifdef CUSPARSE_COO_AOS |
| 530 | assert(!colA); |
| 531 | return builder.create<gpu::CreateCooAoSOp>(loc, handleTp, tokenTp, token, |
| 532 | sz1, sz2, nseA, rowA, valA); |
| 533 | #else |
| 534 | llvm_unreachable("gpu::CreateCooAoSOp is deprecated" ); |
| 535 | #endif |
| 536 | } |
| 537 | assert(colA); |
| 538 | if (format == CuSparseFormat::kCSR) |
| 539 | return builder.create<gpu::CreateCsrOp>(location: loc, args&: handleTp, args&: tokenTp, args&: token, args&: sz1, |
| 540 | args&: sz2, args&: nseA, args&: rowA, args&: colA, args&: valA); |
| 541 | if (format == CuSparseFormat::kCSC) |
| 542 | return builder.create<gpu::CreateCscOp>(location: loc, args&: handleTp, args&: tokenTp, args&: token, args&: sz1, |
| 543 | args&: sz2, args&: nseA, args&: rowA, args&: colA, args&: valA); |
| 544 | // BSR requires a bit more work since we need to pass in the block size |
| 545 | // and all others sizes in terms of blocks (#block-rows, #block-cols, |
| 546 | // #nonzero-blocks). |
| 547 | assert(format == CuSparseFormat::kBSR); |
| 548 | SmallVector<unsigned> dims = getBlockSize(dimToLvl: aTp.getDimToLvl()); |
| 549 | assert(dims.size() == 2 && dims[0] == dims[1]); |
| 550 | uint64_t b = dims[0]; |
| 551 | Value bSz = constantIndex(builder, loc, i: b); |
| 552 | Value bRows = builder.create<arith::DivUIOp>(location: loc, args&: sz1, args&: bSz); |
| 553 | Value bCols = builder.create<arith::DivUIOp>(location: loc, args&: sz2, args&: bSz); |
| 554 | Value bNum = builder.create<arith::DivUIOp>( |
| 555 | location: loc, args&: nseA, args: constantIndex(builder, loc, i: b * b)); |
| 556 | return builder.create<gpu::CreateBsrOp>(location: loc, args&: handleTp, args&: tokenTp, args&: token, args&: bRows, |
| 557 | args&: bCols, args&: bNum, args&: bSz, args&: bSz, args&: rowA, args&: colA, |
| 558 | args&: valA); |
| 559 | } |
| 560 | |
| 561 | /// Match and rewrite SpMV kernel. |
| 562 | static LogicalResult rewriteSpMV(PatternRewriter &rewriter, |
| 563 | linalg::GenericOp op, bool enableRT) { |
| 564 | Location loc = op.getLoc(); |
| 565 | Value a = op.getOperand(i: 0); |
| 566 | Value x = op.getOperand(i: 1); |
| 567 | Value y = op.getOperand(i: 2); // we have y = Ax |
| 568 | SmallVector<Value> tokens; |
| 569 | |
| 570 | // Only admissible sparse matrix format and dense vectors (no BSR). |
| 571 | SparseTensorType aTp = getSparseTensorType(val: a); |
| 572 | SparseTensorType xTp = getSparseTensorType(val: x); |
| 573 | SparseTensorType yTp = getSparseTensorType(val: y); |
| 574 | auto format = getCuSparseFormat(aTp, bTp: xTp, cTp: yTp, enableRT, /*isMatVec=*/true); |
| 575 | if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR) |
| 576 | return failure(); |
| 577 | |
| 578 | // Start sparse kernel and copy data from host to device. |
| 579 | // a : memR/memC/memV -> rowA,colA,valA |
| 580 | // x : memX -> vecX |
| 581 | // y : memY -> vecY |
| 582 | Value nseA = rewriter.create<NumberOfEntriesOp>(location: loc, args&: a); |
| 583 | Value szY = linalg::createOrFoldDimOp(b&: rewriter, loc, val: a, dim: 0); |
| 584 | Value szX = linalg::createOrFoldDimOp(b&: rewriter, loc, val: a, dim: 1); |
| 585 | Value memR = genFirstPosOrCrds(builder&: rewriter, loc, a, format, enableRT); |
| 586 | Value memC = genSecondCrds(builder&: rewriter, loc, a, format, enableRT); // or empty |
| 587 | Value memV = rewriter.create<ToValuesOp>(location: loc, args&: a); |
| 588 | Value rowA = genAllocCopy(builder&: rewriter, loc, b: memR, tokens); |
| 589 | Value colA = memC ? genAllocCopy(builder&: rewriter, loc, b: memC, tokens) : Value(); |
| 590 | Value valA = genAllocCopy(builder&: rewriter, loc, b: memV, tokens); |
| 591 | Value memX = genTensorToMemref(rewriter, loc, tensor: x); |
| 592 | Value vecX = genAllocCopy(builder&: rewriter, loc, b: memX, tokens); |
| 593 | Value memY = genTensorToMemref(rewriter, loc, tensor: y); |
| 594 | Value vecY = genAllocCopy(builder&: rewriter, loc, b: memY, tokens); |
| 595 | genBlockingWait(builder&: rewriter, loc, operands: tokens); |
| 596 | tokens.clear(); |
| 597 | |
| 598 | // Create sparse environment and sparse matrix/dense vector handles. |
| 599 | Type indexTp = rewriter.getIndexType(); |
| 600 | Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>(); |
| 601 | Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); |
| 602 | Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); |
| 603 | Value token = genFirstWait(builder&: rewriter, loc); |
| 604 | Operation *spGenA = |
| 605 | genSpMat(builder&: rewriter, loc, aTp, handleTp: spmatHandleTp, tokenTp, token, sz1: szY, sz2: szX, |
| 606 | nseA, rowA, colA, valA, format, enableRT); |
| 607 | Value spMatA = spGenA->getResult(idx: 0); |
| 608 | token = spGenA->getResult(idx: 1); |
| 609 | auto dvecX = rewriter.create<gpu::CreateDnTensorOp>( |
| 610 | location: loc, args&: dnTensorHandleTp, args&: tokenTp, args&: token, args&: vecX, args&: szX); |
| 611 | Value dnX = dvecX.getResult(i: 0); |
| 612 | token = dvecX.getAsyncToken(); |
| 613 | auto dvecY = rewriter.create<gpu::CreateDnTensorOp>( |
| 614 | location: loc, args&: dnTensorHandleTp, args&: tokenTp, args&: token, args&: vecY, args&: szY); |
| 615 | Value dnY = dvecY.getResult(i: 0); |
| 616 | token = dvecY.getAsyncToken(); |
| 617 | auto dnYType = llvm::cast<ShapedType>(Val: y.getType()).getElementType(); |
| 618 | |
| 619 | // Precompute buffersize for SpMV. |
| 620 | auto bufferComp = rewriter.create<gpu::SpMVBufferSizeOp>( |
| 621 | location: loc, args&: indexTp, args&: tokenTp, args&: token, args&: spMatA, args&: dnX, args&: dnY, |
| 622 | /*computeType=*/args&: dnYType); |
| 623 | Value bufferSz = bufferComp.getResult(i: 0); |
| 624 | token = bufferComp.getAsyncToken(); |
| 625 | auto buf = genAllocBuffer(builder&: rewriter, loc, size: bufferSz, token); |
| 626 | Value buffer = buf.getResult(i: 0); |
| 627 | token = buf.getAsyncToken(); |
| 628 | |
| 629 | // Perform the SpMV. |
| 630 | auto spmvComp = rewriter.create<gpu::SpMVOp>( |
| 631 | location: loc, args&: tokenTp, args&: token, args&: spMatA, args&: dnX, args&: dnY, /*computeType=*/args&: dnYType, args&: buffer); |
| 632 | token = spmvComp.getAsyncToken(); |
| 633 | |
| 634 | // Copy data back to host and free all the resoures. |
| 635 | token = rewriter.create<gpu::DestroySpMatOp>(location: loc, args&: tokenTp, args&: token, args&: spMatA) |
| 636 | .getAsyncToken(); |
| 637 | token = rewriter.create<gpu::DestroyDnTensorOp>(location: loc, args&: tokenTp, args&: token, args&: dnX) |
| 638 | .getAsyncToken(); |
| 639 | token = rewriter.create<gpu::DestroyDnTensorOp>(location: loc, args&: tokenTp, args&: token, args&: dnY) |
| 640 | .getAsyncToken(); |
| 641 | token = genDeallocMemRef(builder&: rewriter, loc, mem: rowA, token); |
| 642 | if (colA) |
| 643 | token = genDeallocMemRef(builder&: rewriter, loc, mem: colA, token); |
| 644 | token = genDeallocMemRef(builder&: rewriter, loc, mem: valA, token); |
| 645 | token = genDeallocMemRef(builder&: rewriter, loc, mem: buffer, token); |
| 646 | token = genDeallocMemRef(builder&: rewriter, loc, mem: vecX, token); |
| 647 | token = genCopyMemRef(builder&: rewriter, loc, dst: memY, src: vecY, token); |
| 648 | token = genDeallocMemRef(builder&: rewriter, loc, mem: vecY, token); |
| 649 | tokens.push_back(Elt: token); |
| 650 | genBlockingWait(builder&: rewriter, loc, operands: tokens); |
| 651 | tokens.clear(); |
| 652 | |
| 653 | // Done. |
| 654 | rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, args: y.getType(), args&: memY); |
| 655 | return success(); |
| 656 | } |
| 657 | |
| 658 | /// Match and rewrite SpMM kernel. |
| 659 | static LogicalResult rewriteSpMM(PatternRewriter &rewriter, |
| 660 | linalg::GenericOp op, bool enableRT) { |
| 661 | Location loc = op.getLoc(); |
| 662 | Value a = op.getOperand(i: 0); |
| 663 | Value b = op.getOperand(i: 1); |
| 664 | Value c = op.getOperand(i: 2); // we have C = AB |
| 665 | SmallVector<Value> tokens; |
| 666 | |
| 667 | // Only admissible sparse matrix format and dense matrices (no BSR). |
| 668 | SparseTensorType aTp = getSparseTensorType(val: a); |
| 669 | SparseTensorType bTp = getSparseTensorType(val: b); |
| 670 | SparseTensorType cTp = getSparseTensorType(val: c); |
| 671 | auto format = getCuSparseFormat(aTp, bTp, cTp, enableRT, /*isMatVec=*/false); |
| 672 | if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR) |
| 673 | return failure(); |
| 674 | |
| 675 | // Start sparse kernel and copy data from host to device. |
| 676 | // a : memR/memC/memV -> rowA,colA,valA |
| 677 | // b : bufB -> matB |
| 678 | // c : bufC -> matC |
| 679 | Value nseA = rewriter.create<NumberOfEntriesOp>(location: loc, args&: a); |
| 680 | Value szm = linalg::createOrFoldDimOp(b&: rewriter, loc, val: a, dim: 0); |
| 681 | Value szk = linalg::createOrFoldDimOp(b&: rewriter, loc, val: a, dim: 1); |
| 682 | Value szn = linalg::createOrFoldDimOp(b&: rewriter, loc, val: b, dim: 1); |
| 683 | Value memR = genFirstPosOrCrds(builder&: rewriter, loc, a, format, enableRT); |
| 684 | Value memC = genSecondCrds(builder&: rewriter, loc, a, format, enableRT); // or empty |
| 685 | Value memV = rewriter.create<ToValuesOp>(location: loc, args&: a); |
| 686 | Value rowA = genAllocCopy(builder&: rewriter, loc, b: memR, tokens); |
| 687 | Value colA = memC ? genAllocCopy(builder&: rewriter, loc, b: memC, tokens) : Value(); |
| 688 | Value valA = genAllocCopy(builder&: rewriter, loc, b: memV, tokens); |
| 689 | Value bufB = genTensorToMemref(rewriter, loc, tensor: b); |
| 690 | Value matB = genAllocCopy(builder&: rewriter, loc, b: bufB, tokens); |
| 691 | Value bufC = genTensorToMemref(rewriter, loc, tensor: c); |
| 692 | Value matC = genAllocCopy(builder&: rewriter, loc, b: bufC, tokens); |
| 693 | genBlockingWait(builder&: rewriter, loc, operands: tokens); |
| 694 | tokens.clear(); |
| 695 | |
| 696 | // Create sparse environment and sparse matrix/dense matrix handles. |
| 697 | Type indexTp = rewriter.getIndexType(); |
| 698 | Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>(); |
| 699 | Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); |
| 700 | Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); |
| 701 | Value token = genFirstWait(builder&: rewriter, loc); |
| 702 | Operation *spGenA = |
| 703 | genSpMat(builder&: rewriter, loc, aTp, handleTp: spMatHandleTp, tokenTp, token, sz1: szm, sz2: szk, |
| 704 | nseA, rowA, colA, valA, format, enableRT); |
| 705 | Value spMatA = spGenA->getResult(idx: 0); |
| 706 | token = spGenA->getResult(idx: 1); |
| 707 | auto dmatB = rewriter.create<gpu::CreateDnTensorOp>( |
| 708 | location: loc, args&: dnTensorHandleTp, args&: tokenTp, args&: token, args&: matB, |
| 709 | args: SmallVector<Value>{szk, szn}); |
| 710 | Value dnB = dmatB.getResult(i: 0); |
| 711 | token = dmatB.getAsyncToken(); |
| 712 | auto dmatC = rewriter.create<gpu::CreateDnTensorOp>( |
| 713 | location: loc, args&: dnTensorHandleTp, args&: tokenTp, args&: token, args&: matC, |
| 714 | args: SmallVector<Value>{szm, szn}); |
| 715 | Value dnC = dmatC.getResult(i: 0); |
| 716 | token = dmatC.getAsyncToken(); |
| 717 | auto dmatCType = llvm::cast<ShapedType>(Val: c.getType()).getElementType(); |
| 718 | |
| 719 | // Precompute buffersize for SpMM. |
| 720 | auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>( |
| 721 | location: loc, args&: indexTp, args&: tokenTp, args&: token, args&: spMatA, args&: dnB, args&: dnC, |
| 722 | /*computeType=*/args&: dmatCType); |
| 723 | Value bufferSz = bufferComp.getResult(i: 0); |
| 724 | token = bufferComp.getAsyncToken(); |
| 725 | auto buf = genAllocBuffer(builder&: rewriter, loc, size: bufferSz, token); |
| 726 | Value buffer = buf.getResult(i: 0); |
| 727 | token = buf.getAsyncToken(); |
| 728 | auto dnCType = llvm::cast<ShapedType>(Val: c.getType()).getElementType(); |
| 729 | |
| 730 | // Perform the SpMM. |
| 731 | auto spmmComp = rewriter.create<gpu::SpMMOp>( |
| 732 | location: loc, args&: tokenTp, args&: token, args&: spMatA, args&: dnB, args&: dnC, /*computeType=*/args&: dnCType, args&: buffer); |
| 733 | token = spmmComp.getAsyncToken(); |
| 734 | |
| 735 | // Copy data back to host and free all the resoures. |
| 736 | token = rewriter.create<gpu::DestroySpMatOp>(location: loc, args&: tokenTp, args&: token, args&: spMatA) |
| 737 | .getAsyncToken(); |
| 738 | token = rewriter.create<gpu::DestroyDnTensorOp>(location: loc, args&: tokenTp, args&: token, args&: dnB) |
| 739 | .getAsyncToken(); |
| 740 | token = rewriter.create<gpu::DestroyDnTensorOp>(location: loc, args&: tokenTp, args&: token, args&: dnC) |
| 741 | .getAsyncToken(); |
| 742 | token = genDeallocMemRef(builder&: rewriter, loc, mem: rowA, token); |
| 743 | if (colA) |
| 744 | token = genDeallocMemRef(builder&: rewriter, loc, mem: colA, token); |
| 745 | token = genDeallocMemRef(builder&: rewriter, loc, mem: valA, token); |
| 746 | token = genDeallocMemRef(builder&: rewriter, loc, mem: buffer, token); |
| 747 | token = genDeallocMemRef(builder&: rewriter, loc, mem: matB, token); |
| 748 | token = genCopyMemRef(builder&: rewriter, loc, dst: bufC, src: matC, token); |
| 749 | token = genDeallocMemRef(builder&: rewriter, loc, mem: matC, token); |
| 750 | tokens.push_back(Elt: token); |
| 751 | genBlockingWait(builder&: rewriter, loc, operands: tokens); |
| 752 | tokens.clear(); |
| 753 | |
| 754 | // Done. |
| 755 | rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, args: c.getType(), args&: bufC); |
| 756 | return success(); |
| 757 | } |
| 758 | |
/// Match and rewrite SpGEMM kernel, i.e. sparse C = A * B where all three
/// matrices are CSR. Uses the two-phase SpGEMM API (work estimation, then
/// compute), queries the resulting number of nonzeros, and reassembles the
/// sparse result from host buffers.
static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
                                   linalg::GenericOp op, bool enableRT) {
  Location loc = op.getLoc();
  Value a = op.getOperand(i: 0);
  Value b = op.getOperand(i: 1);
  Value c = op.getOperand(i: 2); // we have C = AB
  SmallVector<Value> tokens; // async tokens joined at each blocking wait

  // Only CSR <- CSR x CSR supported.
  auto format = CuSparseFormat::kCSR;
  SparseTensorType aTp = getSparseTensorType(val: a);
  SparseTensorType bTp = getSparseTensorType(val: b);
  SparseTensorType cTp = getSparseTensorType(val: c);
  if (!isAdmissibleCSR(aTp) || !isAdmissibleCSR(aTp&: bTp) || !isAdmissibleCSR(aTp&: cTp))
    return failure();

  // Start sparse kernel and copy data from host to device.
  //   a : amemR/amemC/amemV -> rowA,colA,valA
  //   b : bmemR/bmemC/bmemV -> rowB,colB,valB
  //   c : materializes (buffers allocated only once nnz is known)
  auto dnCType = cTp.getElementType();
  Value nseA = rewriter.create<NumberOfEntriesOp>(location: loc, args&: a);
  Value nseB = rewriter.create<NumberOfEntriesOp>(location: loc, args&: b);
  Value szm = linalg::createOrFoldDimOp(b&: rewriter, loc, val: a, dim: 0);
  Value szk = linalg::createOrFoldDimOp(b&: rewriter, loc, val: a, dim: 1);
  Value szn = linalg::createOrFoldDimOp(b&: rewriter, loc, val: b, dim: 1);
  Value amemR = genFirstPosOrCrds(builder&: rewriter, loc, a, format, enableRT);
  Value amemC = genSecondCrds(builder&: rewriter, loc, a, format, enableRT); // not empty
  Value amemV = rewriter.create<ToValuesOp>(location: loc, args&: a);
  Value bmemR = genFirstPosOrCrds(builder&: rewriter, loc, a: b, format, enableRT);
  Value bmemC = genSecondCrds(builder&: rewriter, loc, a: b, format, enableRT); // not empty
  Value bmemV = rewriter.create<ToValuesOp>(location: loc, args&: b);
  Value rowA = genAllocCopy(builder&: rewriter, loc, b: amemR, tokens);
  Value colA = genAllocCopy(builder&: rewriter, loc, b: amemC, tokens);
  Value valA = genAllocCopy(builder&: rewriter, loc, b: amemV, tokens);
  Value rowB = genAllocCopy(builder&: rewriter, loc, b: bmemR, tokens);
  Value colB = genAllocCopy(builder&: rewriter, loc, b: bmemC, tokens);
  Value valB = genAllocCopy(builder&: rewriter, loc, b: bmemV, tokens);
  genBlockingWait(builder&: rewriter, loc, operands: tokens);
  tokens.clear();

  // Create sparse environment and sparse matrix/dense vector handles.
  Type indexTp = rewriter.getIndexType();
  Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type descTp = rewriter.getType<gpu::SparseSpGEMMOpHandleType>(); // SpGEMM descriptor
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(builder&: rewriter, loc);
  Operation *spGenA =
      genSpMat(builder&: rewriter, loc, aTp, handleTp: spmatHandleTp, tokenTp, token, sz1: szm, sz2: szk,
               nseA, rowA, colA, valA, format, enableRT);
  Value spMatA = spGenA->getResult(idx: 0);
  token = spGenA->getResult(idx: 1);
  Operation *spGenB =
      genSpMat(builder&: rewriter, loc, aTp&: bTp, handleTp: spmatHandleTp, tokenTp, token, sz1: szk, sz2: szn,
               nseA: nseB, rowA: rowB, colA: colB, valA: valB, format, enableRT);
  Value spMatB = spGenB->getResult(idx: 0);
  token = spGenB->getResult(idx: 1);

  // Sparse matrix C materializes (also assumes beta == 0).
  // Only the m+1 position buffer is allocated up front; coordinate and
  // value buffers start zero-sized and are allocated once nnz is known.
  Value zero = constantIndex(builder&: rewriter, loc, i: 0);
  Value one = constantIndex(builder&: rewriter, loc, i: 1);
  Value mplus1 = rewriter.create<arith::AddIOp>(location: loc, args&: szm, args&: one);
  auto e1 = genAllocBuffer(builder&: rewriter, loc, type: cTp.getPosType(), size: mplus1, token);
  Value rowC = e1.getResult(i: 0);
  token = e1.getAsyncToken();
  auto e2 = genAllocBuffer(builder&: rewriter, loc, type: cTp.getCrdType(), size: zero, token);
  Value colC = e2.getResult(i: 0); // no free needed
  token = e2.getAsyncToken();
  auto e3 = genAllocBuffer(builder&: rewriter, loc, type: dnCType, size: zero, token);
  Value valC = e3.getResult(i: 0); // no free needed
  token = e3.getAsyncToken();
  Operation *spGenC =
      genSpMat(builder&: rewriter, loc, aTp&: cTp, handleTp: spmatHandleTp, tokenTp, token, sz1: szm, sz2: szn,
               nseA: zero, rowA: rowC, colA: colC, valA: valC, format, enableRT);
  Value spMatC = spGenC->getResult(idx: 0);
  token = spGenC->getResult(idx: 1);

  // Precompute buffersizes for SpGEMM.
  Operation *descOp =
      rewriter.create<gpu::SpGEMMCreateDescrOp>(location: loc, args&: descTp, args&: tokenTp, args&: token);
  Value desc = descOp->getResult(idx: 0);
  token = descOp->getResult(idx: 1);
  // First call (zero-sized buffer) only queries the required buffer size.
  Operation *work1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      location: loc, args&: indexTp, args&: tokenTp, args&: token, args&: desc, args: gpu::TransposeMode::NON_TRANSPOSE,
      args: gpu::TransposeMode::NON_TRANSPOSE, args&: spMatA, args&: spMatB, args&: spMatC, args&: dnCType, args&: zero,
      args&: valC, args: gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
  Value bufferSz1 = work1->getResult(idx: 0);
  token = work1->getResult(idx: 1);
  auto buf1 = genAllocBuffer(builder&: rewriter, loc, size: bufferSz1, token);
  Value buffer1 = buf1.getResult(i: 0);
  token = buf1.getAsyncToken();
  // Second call performs the actual work estimation into buffer1.
  Operation *work2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      location: loc, args&: indexTp, args&: tokenTp, args&: token, args&: desc, args: gpu::TransposeMode::NON_TRANSPOSE,
      args: gpu::TransposeMode::NON_TRANSPOSE, args&: spMatA, args&: spMatB, args&: spMatC, args&: dnCType,
      args&: bufferSz1, args&: buffer1,
      args: gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
  token = work2->getResult(idx: 1); // size result (idx 0) intentionally unused

  // Compute step (same query-then-run pattern as work estimation).
  Operation *compute1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      location: loc, args&: indexTp, args&: tokenTp, args&: token, args&: desc, args: gpu::TransposeMode::NON_TRANSPOSE,
      args: gpu::TransposeMode::NON_TRANSPOSE, args&: spMatA, args&: spMatB, args&: spMatC, args&: dnCType, args&: zero,
      args&: valC, args: gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
  Value bufferSz2 = compute1->getResult(idx: 0);
  token = compute1->getResult(idx: 1);
  auto buf2 = genAllocBuffer(builder&: rewriter, loc, size: bufferSz2, token);
  Value buffer2 = buf2.getResult(i: 0);
  token = buf2.getAsyncToken();
  Operation *compute2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      location: loc, args&: indexTp, args&: tokenTp, args&: token, args&: desc, args: gpu::TransposeMode::NON_TRANSPOSE,
      args: gpu::TransposeMode::NON_TRANSPOSE, args&: spMatA, args&: spMatB, args&: spMatC, args&: dnCType,
      args&: bufferSz2, args&: buffer2, args: gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
  token = compute2->getResult(idx: 1);

  // Get sizes.
  Operation *sizes = rewriter.create<gpu::SpMatGetSizeOp>(
      location: loc, args&: indexTp, args&: indexTp, args&: indexTp, args&: tokenTp, args&: token, args&: spMatC);
  Value nnz = sizes->getResult(idx: 2); // number of nonzeros in C
  token = sizes->getResult(idx: 3);
  auto a2 = genAllocBuffer(builder&: rewriter, loc, type: cTp.getCrdType(), size: nnz, token);
  colC = a2.getResult(i: 0);
  token = a2.getAsyncToken();
  auto a3 = genAllocBuffer(builder&: rewriter, loc, type: dnCType, size: nnz, token);
  valC = a3.getResult(i: 0);
  token = a3.getAsyncToken();

  // Update C with new pointers and copy final product back into C.
  Operation *update = rewriter.create<gpu::SetCsrPointersOp>(
      location: loc, args&: tokenTp, args&: token, args&: spMatC, args&: rowC, args&: colC, args&: valC);
  token = update->getResult(idx: 0);
  Operation *copy = rewriter.create<gpu::SpGEMMCopyOp>(
      location: loc, args&: tokenTp, args&: token, args&: desc, args: gpu::TransposeMode::NON_TRANSPOSE,
      args: gpu::TransposeMode::NON_TRANSPOSE, args&: spMatA, args&: spMatB, args&: spMatC, args&: dnCType);
  token = copy->getResult(idx: 0);

  // Allocate buffers on host.
  Value rowH = genHostBuffer(builder&: rewriter, loc, type: cTp.getPosType(), size: mplus1);
  Value colH = genHostBuffer(builder&: rewriter, loc, type: cTp.getCrdType(), size: nnz);
  Value valH = genHostBuffer(builder&: rewriter, loc, type: dnCType, size: nnz);

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::SpGEMMDestroyDescrOp>(location: loc, args&: tokenTp, args&: token, args&: desc)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(location: loc, args&: tokenTp, args&: token, args&: spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(location: loc, args&: tokenTp, args&: token, args&: spMatB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(location: loc, args&: tokenTp, args&: token, args&: spMatC)
              .getAsyncToken();
  token = genCopyMemRef(builder&: rewriter, loc, dst: rowH, src: rowC, token);
  token = genCopyMemRef(builder&: rewriter, loc, dst: colH, src: colC, token);
  token = genCopyMemRef(builder&: rewriter, loc, dst: valH, src: valC, token);
  token = genDeallocMemRef(builder&: rewriter, loc, mem: rowA, token);
  token = genDeallocMemRef(builder&: rewriter, loc, mem: colA, token);
  token = genDeallocMemRef(builder&: rewriter, loc, mem: valA, token);
  token = genDeallocMemRef(builder&: rewriter, loc, mem: rowB, token);
  token = genDeallocMemRef(builder&: rewriter, loc, mem: colB, token);
  token = genDeallocMemRef(builder&: rewriter, loc, mem: valB, token);
  token = genDeallocMemRef(builder&: rewriter, loc, mem: rowC, token);
  token = genDeallocMemRef(builder&: rewriter, loc, mem: colC, token);
  token = genDeallocMemRef(builder&: rewriter, loc, mem: valC, token);
  token = genDeallocMemRef(builder&: rewriter, loc, mem: buffer1, token);
  token = genDeallocMemRef(builder&: rewriter, loc, mem: buffer2, token);
  tokens.push_back(Elt: token);
  genBlockingWait(builder&: rewriter, loc, operands: tokens);
  tokens.clear();

  // Done. Reassemble the sparse result from the host buffers.
  Value vt = rewriter.create<bufferization::ToTensorOp>(
      location: loc, args: memref::getTensorTypeFromMemRefType(type: valH.getType()), args&: valH);
  Value rt = rewriter.create<bufferization::ToTensorOp>(
      location: loc, args: memref::getTensorTypeFromMemRefType(type: rowH.getType()), args&: rowH);
  Value ct = rewriter.create<bufferization::ToTensorOp>(
      location: loc, args: memref::getTensorTypeFromMemRefType(type: colH.getType()), args&: colH);
  rewriter.replaceOpWithNewOp<AssembleOp>(op, args: c.getType(), args: ValueRange{rt, ct},
                                          args&: vt);
  return success();
}
| 938 | |
| 939 | // Match and rewrite 2:4 SpMM kernel. |
| 940 | static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter, |
| 941 | linalg::GenericOp op) { |
| 942 | Location loc = op.getLoc(); |
| 943 | Value A = op.getOperand(i: 0); |
| 944 | Value B = op.getOperand(i: 1); |
| 945 | Value C = op.getOperand(i: 2); // we have C = AB |
| 946 | SmallVector<Value> tokens; |
| 947 | |
| 948 | // The cuSparselt API currently only allows pruning and compression |
| 949 | // to occur on the device. So we recognize the pattern |
| 950 | // A' = convert A ; dense to 2:4 |
| 951 | // C = A'B ; 2:4 matrix mult |
| 952 | // and then perform compression and matrix multiplication on device. |
| 953 | auto cnv = A.getDefiningOp<ConvertOp>(); |
| 954 | assert(cnv); |
| 955 | A = cnv.getSource(); |
| 956 | |
| 957 | // All input should be dense tensors. |
| 958 | if (!isDenseTensor(v: A) || !isDenseTensor(v: B) || !isDenseTensor(v: C)) |
| 959 | return failure(); |
| 960 | |
| 961 | // Start sparse kernel and copy data from host to device. |
| 962 | // a : bufA -> matA |
| 963 | // b : bufB -> matB |
| 964 | // c : bufC -> matC |
| 965 | Value bufA = genTensorToMemref(rewriter, loc, tensor: A); |
| 966 | Value matA = genAllocCopy(builder&: rewriter, loc, b: bufA, tokens); |
| 967 | Value bufB = genTensorToMemref(rewriter, loc, tensor: B); |
| 968 | Value matB = genAllocCopy(builder&: rewriter, loc, b: bufB, tokens); |
| 969 | Value bufC = genTensorToMemref(rewriter, loc, tensor: C); |
| 970 | Value matC = genAllocCopy(builder&: rewriter, loc, b: bufC, tokens); |
| 971 | genBlockingWait(builder&: rewriter, loc, operands: tokens); |
| 972 | tokens.clear(); |
| 973 | |
| 974 | // Create sparse environment and sparse matrix/dense vector handles. |
| 975 | Value szm = linalg::createOrFoldDimOp(b&: rewriter, loc, val: matA, dim: 0); |
| 976 | Value szk = linalg::createOrFoldDimOp(b&: rewriter, loc, val: matB, dim: 0); |
| 977 | Value szn = linalg::createOrFoldDimOp(b&: rewriter, loc, val: matC, dim: 1); |
| 978 | Type indexTp = rewriter.getIndexType(); |
| 979 | Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>(); |
| 980 | Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); |
| 981 | Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); |
| 982 | Value token = genFirstWait(builder&: rewriter, loc); |
| 983 | Operation *spGenA = rewriter.create<gpu::Create2To4SpMatOp>( |
| 984 | location: loc, args&: spMatHandleTp, args&: tokenTp, args&: token, args&: szm, args&: szk, |
| 985 | args: gpu::Prune2To4SpMatFlag::PRUNE_AND_CHECK, args&: matA); |
| 986 | Value spMatA = spGenA->getResult(idx: 0); |
| 987 | token = spGenA->getResult(idx: 1); |
| 988 | auto dmatB = rewriter.create<gpu::CreateDnTensorOp>( |
| 989 | location: loc, args&: dnTensorHandleTp, args&: tokenTp, args&: token, args&: matB, |
| 990 | args: SmallVector<Value>{szk, szn}); |
| 991 | Value dnB = dmatB.getResult(i: 0); |
| 992 | token = dmatB.getAsyncToken(); |
| 993 | auto dmatC = rewriter.create<gpu::CreateDnTensorOp>( |
| 994 | location: loc, args&: dnTensorHandleTp, args&: tokenTp, args&: token, args&: matC, |
| 995 | args: SmallVector<Value>{szm, szn}); |
| 996 | Value dnC = dmatC.getResult(i: 0); |
| 997 | token = dmatC.getAsyncToken(); |
| 998 | auto dmatCType = llvm::cast<ShapedType>(Val: matC.getType()).getElementType(); |
| 999 | |
| 1000 | // Precompute buffersize for SpMM. |
| 1001 | SmallVector<Type> bufferTypes_{indexTp, indexTp, indexTp}; |
| 1002 | TypeRange bufferTypes(bufferTypes_); |
| 1003 | auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>( |
| 1004 | location: loc, args&: bufferTypes, args&: tokenTp, args&: token, args: gpu::TransposeMode::NON_TRANSPOSE, |
| 1005 | args: gpu::TransposeMode::NON_TRANSPOSE, args&: spMatA, args&: dnB, args&: dnC, |
| 1006 | /*computeType=*/args&: dmatCType); |
| 1007 | token = bufferComp.getAsyncToken(); |
| 1008 | |
| 1009 | // Allocate buffers on host. |
| 1010 | Value bufferSz1 = bufferComp.getResult(i: 0); |
| 1011 | auto buf1 = genAllocBuffer(builder&: rewriter, loc, size: bufferSz1, token); |
| 1012 | Value buffer1 = buf1.getResult(i: 0); |
| 1013 | token = buf1.getAsyncToken(); |
| 1014 | Value bufferSz2 = bufferComp.getResult(i: 1); |
| 1015 | auto buf2 = genAllocBuffer(builder&: rewriter, loc, size: bufferSz2, token); |
| 1016 | Value buffer2 = buf2.getResult(i: 0); |
| 1017 | token = buf2.getAsyncToken(); |
| 1018 | Value bufferSz3 = bufferComp.getResult(i: 2); |
| 1019 | auto buf3 = genAllocBuffer(builder&: rewriter, loc, size: bufferSz3, token); |
| 1020 | Value buffer3 = buf3.getResult(i: 0); |
| 1021 | token = buf3.getAsyncToken(); |
| 1022 | |
| 1023 | // Perform the SpMM. |
| 1024 | auto dnCType = llvm::cast<ShapedType>(Val: matC.getType()).getElementType(); |
| 1025 | auto spmmComp = rewriter.create<gpu::SpMMOp>( |
| 1026 | location: loc, args&: tokenTp, args&: token, args&: spMatA, args&: dnB, args&: dnC, /*computeType=*/args&: dnCType, |
| 1027 | args: SmallVector<Value>{buffer1, buffer2, buffer3}); |
| 1028 | token = spmmComp.getAsyncToken(); |
| 1029 | |
| 1030 | // Copy data back to host and free all the resources. |
| 1031 | token = rewriter.create<gpu::DestroySpMatOp>(location: loc, args&: tokenTp, args&: token, args&: spMatA) |
| 1032 | .getAsyncToken(); |
| 1033 | token = rewriter.create<gpu::DestroyDnTensorOp>(location: loc, args&: tokenTp, args&: token, args&: dnB) |
| 1034 | .getAsyncToken(); |
| 1035 | token = rewriter.create<gpu::DestroyDnTensorOp>(location: loc, args&: tokenTp, args&: token, args&: dnC) |
| 1036 | .getAsyncToken(); |
| 1037 | token = genDeallocMemRef(builder&: rewriter, loc, mem: buffer1, token); |
| 1038 | token = genDeallocMemRef(builder&: rewriter, loc, mem: buffer2, token); |
| 1039 | token = genDeallocMemRef(builder&: rewriter, loc, mem: buffer3, token); |
| 1040 | token = genDeallocMemRef(builder&: rewriter, loc, mem: matA, token); |
| 1041 | token = genDeallocMemRef(builder&: rewriter, loc, mem: matB, token); |
| 1042 | token = genCopyMemRef(builder&: rewriter, loc, dst: bufC, src: matC, token); |
| 1043 | token = genDeallocMemRef(builder&: rewriter, loc, mem: matC, token); |
| 1044 | tokens.push_back(Elt: token); |
| 1045 | genBlockingWait(builder&: rewriter, loc, operands: tokens); |
| 1046 | tokens.clear(); |
| 1047 | |
| 1048 | // Done. |
| 1049 | rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, args: C.getType(), args&: bufC); |
| 1050 | return success(); |
| 1051 | } |
| 1052 | |
/// Match and rewrite SDDMM kernel (sampled dense-dense matrix product).
/// The sparse operand c is the sampling matrix and receives the result in
/// place: only its values buffer is copied back to the host at the end.
static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
                                  linalg::GenericOp op, bool enableRT) {
  Location loc = op.getLoc();
  Value a = op.getOperand(i: 0);
  Value b = op.getOperand(i: 1);
  Value c = op.getOperand(i: 2);
  SmallVector<Value> tokens;

  // Only admissible sparse matrix format (no COO/CSC) and dense matrices.
  SparseTensorType aTp = getSparseTensorType(val: a);
  SparseTensorType bTp = getSparseTensorType(val: b);
  SparseTensorType cTp = getSparseTensorType(val: c);
  // Note: c is the sparse matrix here, so cTp and aTp are deliberately
  // swapped in the format query below.
  auto format = getCuSparseFormat(aTp: cTp, bTp, cTp: aTp, enableRT, /*isMatVec=*/false);
  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kCOO ||
      format == CuSparseFormat::kCSC)
    return failure();

  // The SDDMM does the in-place operation.
  // Start sparse kernel and copy data from host to device.
  //   a : bufA -> matA
  //   b : bufB -> matB
  //   c : memR/memC/memV -> rowC,colC,valC
  Value nseC = rewriter.create<NumberOfEntriesOp>(location: loc, args&: c);
  Value szm = linalg::createOrFoldDimOp(b&: rewriter, loc, val: a, dim: 0);
  Value szk = linalg::createOrFoldDimOp(b&: rewriter, loc, val: a, dim: 1);
  Value szn = linalg::createOrFoldDimOp(b&: rewriter, loc, val: b, dim: 1);
  Value bufA = genTensorToMemref(rewriter, loc, tensor: a);
  Value matA = genAllocCopy(builder&: rewriter, loc, b: bufA, tokens);
  Value bufB = genTensorToMemref(rewriter, loc, tensor: b);
  Value matB = genAllocCopy(builder&: rewriter, loc, b: bufB, tokens);
  Value memR = genFirstPosOrCrds(builder&: rewriter, loc, a: c, format, enableRT);
  Value memC = genSecondCrds(builder&: rewriter, loc, a: c, format, enableRT); // or empty
  Value memV = rewriter.create<ToValuesOp>(location: loc, args&: c);
  Value rowC = genAllocCopy(builder&: rewriter, loc, b: memR, tokens);
  Value colC = memC ? genAllocCopy(builder&: rewriter, loc, b: memC, tokens) : Value();
  Value valC = genAllocCopy(builder&: rewriter, loc, b: memV, tokens);
  genBlockingWait(builder&: rewriter, loc, operands: tokens);
  tokens.clear();

  // Create sparse environment and sparse matrix/dense matrix handles.
  Type indexTp = rewriter.getIndexType();
  Type dnMatHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(builder&: rewriter, loc);
  auto dmatA = rewriter.create<gpu::CreateDnTensorOp>(
      location: loc, args&: dnMatHandleTp, args&: tokenTp, args&: token, args&: matA, args: SmallVector<Value>{szm, szk});
  Value dnA = dmatA.getResult(i: 0);
  token = dmatA.getAsyncToken();
  auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
      location: loc, args&: dnMatHandleTp, args&: tokenTp, args&: token, args&: matB, args: SmallVector<Value>{szk, szn});
  Value dnB = dmatB.getResult(i: 0);
  token = dmatB.getAsyncToken();
  Operation *spGenC =
      genSpMat(builder&: rewriter, loc, aTp&: cTp, handleTp: spMatHandleTp, tokenTp, token, sz1: szm, sz2: szn,
               nseA: nseC, rowA: rowC, colA: colC, valA: valC, format, enableRT);
  Value spMatC = spGenC->getResult(idx: 0);
  token = spGenC->getResult(idx: 1);
  auto dnCType = llvm::cast<ShapedType>(Val: c.getType()).getElementType(); // compute type

  // Precompute buffersize for SDDMM.
  auto bufferComp = rewriter.create<gpu::SDDMMBufferSizeOp>(
      location: loc, args&: indexTp, args&: tokenTp, args&: token, args&: dnA, args&: dnB, args&: spMatC, args&: dnCType);
  Value bufferSz = bufferComp.getResult(i: 0);
  token = bufferComp.getAsyncToken();
  auto buf = genAllocBuffer(builder&: rewriter, loc, size: bufferSz, token);
  Value buffer = buf.getResult(i: 0);
  token = buf.getAsyncToken();

  // Perform the SDDMM.
  auto sddmmComp = rewriter.create<gpu::SDDMMOp>(location: loc, args&: tokenTp, args&: token, args&: dnA, args&: dnB,
                                                 args&: spMatC, args&: dnCType, args&: buffer);
  token = sddmmComp.getAsyncToken();

  // Copy data back to host and free all the resources. Only the values of
  // c are copied back; its positions/coordinates were not modified.
  token = rewriter.create<gpu::DestroyDnTensorOp>(location: loc, args&: tokenTp, args&: token, args&: dnA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(location: loc, args&: tokenTp, args&: token, args&: dnB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(location: loc, args&: tokenTp, args&: token, args&: spMatC)
              .getAsyncToken();
  token = genDeallocMemRef(builder&: rewriter, loc, mem: buffer, token);
  token = genDeallocMemRef(builder&: rewriter, loc, mem: matA, token);
  token = genDeallocMemRef(builder&: rewriter, loc, mem: matB, token);
  token = genDeallocMemRef(builder&: rewriter, loc, mem: rowC, token);
  if (colC)
    token = genDeallocMemRef(builder&: rewriter, loc, mem: colC, token);
  token = genCopyMemRef(builder&: rewriter, loc, dst: memV, src: valC, token);
  token = genDeallocMemRef(builder&: rewriter, loc, mem: valC, token);
  tokens.push_back(Elt: token);
  genBlockingWait(builder&: rewriter, loc, operands: tokens);
  tokens.clear();

  // Done.
  rewriter.replaceOpWithNewOp<sparse_tensor::LoadOp>(op, args&: c);
  return success();
}
| 1151 | |
| 1152 | //===----------------------------------------------------------------------===// |
| 1153 | // Rewriting rules for direct code generation. |
| 1154 | //===----------------------------------------------------------------------===// |
| 1155 | |
| 1156 | /// Proof-of-concept rewriter. This rule generates a GPU implementation |
| 1157 | /// for each outermost forall loop generated by the sparsifier. |
| 1158 | /// TODO: right now works with parallelization-strategy=dense-outer-loop |
| 1159 | /// but give this its own flags in the future |
| 1160 | struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> { |
| 1161 | using OpRewritePattern<scf::ParallelOp>::OpRewritePattern; |
| 1162 | |
| 1163 | ForallRewriter(MLIRContext *context, unsigned nT) |
| 1164 | : OpRewritePattern(context), numThreads(nT){}; |
| 1165 | |
| 1166 | LogicalResult matchAndRewrite(scf::ParallelOp forallOp, |
| 1167 | PatternRewriter &rewriter) const override { |
| 1168 | // Reject inadmissible loop form. |
| 1169 | // Essentially only accept a loop, generated by the sparsifier, |
| 1170 | // of the form |
| 1171 | // forall (i = 0; i < N; i++) |
| 1172 | // so that cyclic scheduling over the threads is easy. |
| 1173 | if (!forallOp->hasAttr(name: LoopEmitter::getLoopEmitterLoopAttrName()) || |
| 1174 | forallOp.getNumReductions() != 0 || forallOp.getNumLoops() != 1 || |
| 1175 | !matchPattern(value: forallOp.getLowerBound()[0], pattern: m_Zero()) || |
| 1176 | !matchPattern(value: forallOp.getStep()[0], pattern: m_One())) |
| 1177 | return failure(); |
| 1178 | // Collect every value that is computed outside the parallel loop. |
| 1179 | SetVector<Value> invariants; // stable iteration! |
| 1180 | forallOp->walk(callback: [&](Operation *op) { |
| 1181 | // Collect all values of admissible ops. |
| 1182 | for (OpOperand &o : op->getOpOperands()) { |
| 1183 | Value val = o.get(); |
| 1184 | Block *block; |
| 1185 | if (auto arg = dyn_cast<BlockArgument>(Val&: val)) |
| 1186 | block = arg.getOwner(); |
| 1187 | else |
| 1188 | block = val.getDefiningOp()->getBlock(); |
| 1189 | if (!forallOp.getRegion().findAncestorBlockInRegion(block&: *block)) |
| 1190 | invariants.insert(X: val); |
| 1191 | } |
| 1192 | }); |
| 1193 | // Outline the outside values as proper parameters. Fail when sharing |
| 1194 | // value between host and device is not straightforward. |
| 1195 | SmallVector<Value> constants; |
| 1196 | SmallVector<Value> scalars; |
| 1197 | SmallVector<Value> buffers; |
| 1198 | for (Value val : invariants) { |
| 1199 | Type tp = val.getType(); |
| 1200 | if (val.getDefiningOp<arith::ConstantOp>()) |
| 1201 | constants.push_back(Elt: val); |
| 1202 | else if (isa<FloatType>(Val: tp) || tp.isIntOrIndex()) |
| 1203 | scalars.push_back(Elt: val); |
| 1204 | else if (isa<MemRefType>(Val: tp)) |
| 1205 | buffers.push_back(Elt: val); |
| 1206 | else |
| 1207 | return failure(); // don't know how to share |
| 1208 | } |
| 1209 | // Pass outlined non-constant values. |
| 1210 | // TODO: Experiment with `useHostRegistrationForOut` to see if we want to |
| 1211 | // keep the feature at all (either through a heuristic or compiler |
| 1212 | // option for gpu codegen). |
| 1213 | Location loc = forallOp->getLoc(); |
| 1214 | SmallVector<Value> args; |
| 1215 | SmallVector<Value> tokens; |
| 1216 | Value out = genParametersIn(builder&: rewriter, loc, scalars, buffers, args, tokens, |
| 1217 | /*useHostRegistrationForOut=*/false); |
| 1218 | // Set up GPU module and construct GPU function. |
| 1219 | auto saveIp = rewriter.saveInsertionPoint(); |
| 1220 | ModuleOp topModule = forallOp->getParentOfType<ModuleOp>(); |
| 1221 | auto gpuModule = genGPUModule(builder&: rewriter, topModule); |
| 1222 | auto gpuFunc = genGPUFunc(builder&: rewriter, gpuModule, args); |
| 1223 | genGPUCode(rewriter, gpuFunc, forallOp, constants, scalars, buffers); |
| 1224 | // Generate code that launches the kernel asynchronously, blocking on all |
| 1225 | // opens tokens and yielding a new token for the output. |
| 1226 | // TODO: Passing in tokens to launch up does not seem to be properly lowered |
| 1227 | // by cubin yet, hence the current blocking wait. |
| 1228 | rewriter.restoreInsertionPoint(ip: saveIp); |
| 1229 | genBlockingWait(builder&: rewriter, loc, operands: tokens); |
| 1230 | tokens.clear(); |
| 1231 | Value kernelToken = |
| 1232 | genLaunchGPUFunc(builder&: rewriter, gpuFunc, args, tokens, numThreads); |
| 1233 | // Finalize the outlined arguments. |
| 1234 | genParametersOut(builder&: rewriter, loc, out, kernelToken, scalars, buffers, args, |
| 1235 | tokens); |
| 1236 | genBlockingWait(builder&: rewriter, loc, operands: tokens); |
| 1237 | rewriter.eraseOp(op: forallOp); |
| 1238 | return success(); |
| 1239 | } |
| 1240 | |
| 1241 | private: |
| 1242 | unsigned numThreads; |
| 1243 | }; |
| 1244 | |
| 1245 | //===----------------------------------------------------------------------===// |
| 1246 | // Rewriting rules for library recognition and code generation. |
| 1247 | //===----------------------------------------------------------------------===// |
| 1248 | |
| 1249 | /// Proof-of-concept rewriter. This rule recognizes certain math kernels |
| 1250 | /// and replaces these with corresponding calls into a sparse library. |
| 1251 | struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> { |
| 1252 | using OpRewritePattern<linalg::GenericOp>::OpRewritePattern; |
| 1253 | |
| 1254 | LinalgOpRewriter(MLIRContext *context, bool rt) |
| 1255 | : OpRewritePattern(context), enableRT(rt) {} |
| 1256 | |
| 1257 | LogicalResult matchAndRewrite(linalg::GenericOp op, |
| 1258 | PatternRewriter &rewriter) const override { |
| 1259 | if (op.getNumDpsInits() != 1) |
| 1260 | return failure(); // reject multi-output |
| 1261 | |
| 1262 | const unsigned numLoops = op.getNumLoops(); |
| 1263 | const unsigned numTensors = op->getNumOperands(); |
| 1264 | const auto iteratorTypes = op.getIteratorTypesArray(); |
| 1265 | SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray(); |
| 1266 | |
| 1267 | using MapList = ArrayRef<ArrayRef<AffineExpr>>; |
| 1268 | auto infer = [&](MapList m) { |
| 1269 | return AffineMap::inferFromExprList(exprsList: m, context: op.getContext()); |
| 1270 | }; |
| 1271 | AffineExpr i, j, k; |
| 1272 | bindDims(ctx: getContext(), exprs&: i, exprs&: j, exprs&: k); |
| 1273 | |
| 1274 | // TODO: more robust patterns, transposed versions, more kernels, |
| 1275 | // identify alpha and beta and pass them to the CUDA calls. |
| 1276 | |
| 1277 | // Recognize a SpMV kernel. |
| 1278 | if (numLoops == 2 && numTensors == 3 && |
| 1279 | linalg::isParallelIterator(iteratorType: iteratorTypes[0]) && |
| 1280 | linalg::isReductionIterator(iteratorType: iteratorTypes[1]) && |
| 1281 | maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) { |
| 1282 | return rewriteSpMV(rewriter, op, enableRT); |
| 1283 | } |
| 1284 | |
| 1285 | // Recognize a SpGEMM, 2:4-SpMM, or SpMM kernel. |
| 1286 | if (numLoops == 3 && numTensors == 3 && |
| 1287 | linalg::isParallelIterator(iteratorType: iteratorTypes[0]) && |
| 1288 | linalg::isParallelIterator(iteratorType: iteratorTypes[1]) && |
| 1289 | linalg::isReductionIterator(iteratorType: iteratorTypes[2]) && |
| 1290 | maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) { |
| 1291 | if (!isDenseTensor(v: op.getOperand(i: 0)) && !isDenseTensor(v: op.getOperand(i: 1))) |
| 1292 | return rewriteSpGEMM(rewriter, op, enableRT); |
| 1293 | if (isConversionInto24(v: op.getOperand(i: 0))) |
| 1294 | return rewrite2To4SpMM(rewriter, op); |
| 1295 | return rewriteSpMM(rewriter, op, enableRT); |
| 1296 | } |
| 1297 | |
| 1298 | // Recognize a SDDMM kernel. |
| 1299 | if (numLoops == 3 && numTensors == 3 && |
| 1300 | linalg::isParallelIterator(iteratorType: iteratorTypes[0]) && |
| 1301 | linalg::isParallelIterator(iteratorType: iteratorTypes[1]) && |
| 1302 | linalg::isReductionIterator(iteratorType: iteratorTypes[2]) && |
| 1303 | maps == infer({{i, k}, {k, j}, {i, j}}) && |
| 1304 | matchSumReductionOfMulUnary(op)) { |
| 1305 | return rewriteSDDMM(rewriter, op, enableRT); |
| 1306 | } |
| 1307 | |
| 1308 | return failure(); |
| 1309 | } |
| 1310 | |
| 1311 | private: |
| 1312 | bool enableRT; |
| 1313 | }; |
| 1314 | |
| 1315 | } // namespace |
| 1316 | |
| 1317 | //===----------------------------------------------------------------------===// |
| 1318 | // Public method for populating GPU rewriting rules. |
| 1319 | // |
// Currently two sets of rewriting rules are made available. The first set
// implements direct code generation, currently by means of converting the
// outermost parallel loop into GPU threads. The second set implements
// library recognition of a set of sparse operations. Eventually, the right
// combination of these two approaches has to be found.
| 1325 | //===----------------------------------------------------------------------===// |
| 1326 | |
| 1327 | void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns, |
| 1328 | unsigned numThreads) { |
| 1329 | patterns.add<ForallRewriter>(arg: patterns.getContext(), args&: numThreads); |
| 1330 | } |
| 1331 | |
| 1332 | void mlir::populateSparseGPULibgenPatterns(RewritePatternSet &patterns, |
| 1333 | bool enableRT) { |
| 1334 | patterns.add<LinalgOpRewriter>(arg: patterns.getContext(), args&: enableRT); |
| 1335 | } |
| 1336 | |