| 1 | //===- SparseGPUCodegen.cpp - Generates GPU code --------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This is a prototype GPU code generator for the sparsifier. |
| 10 | // The objective is to eventually use the right combination of |
| 11 | // direct code generation and library calls into vendor-specific |
| 12 | // highly optimized sparse libraries (e.g. cuSparse for CUDA). |
| 13 | // |
| 14 | //===----------------------------------------------------------------------===// |
| 15 | |
| 16 | #include "Utils/CodegenUtils.h" |
| 17 | #include "Utils/LoopEmitter.h" |
| 18 | |
| 19 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
| 20 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| 21 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 22 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
| 23 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 24 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 25 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
| 26 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" |
| 27 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
| 28 | #include "mlir/IR/IRMapping.h" |
| 29 | #include "mlir/IR/Matchers.h" |
| 30 | |
| 31 | using namespace mlir; |
| 32 | using namespace mlir::sparse_tensor; |
| 33 | |
| 34 | namespace { |
| 35 | |
| 36 | // Sparse formats supported by cuSparse. |
| 37 | enum class CuSparseFormat { |
| 38 | kNone, |
| 39 | kCOO, |
| 40 | kCSR, |
| 41 | kCSC, |
| 42 | kBSR, |
| 43 | }; |
| 44 | |
| 45 | //===----------------------------------------------------------------------===// |
| 46 | // Helper methods. |
| 47 | //===----------------------------------------------------------------------===// |
| 48 | |
| 49 | /// Marks the given top module as a GPU container module. |
| 50 | static void markAsGPUContainer(ModuleOp topModule) { |
| 51 | topModule->setAttr(gpu::GPUDialect::getContainerModuleAttrName(), |
| 52 | UnitAttr::get(topModule->getContext())); |
| 53 | } |
| 54 | |
| 55 | /// Constructs a new GPU module (for GPU kernels) inside the given top module, |
| 56 | /// or returns an existing GPU module if one was built previously. |
| 57 | static gpu::GPUModuleOp genGPUModule(OpBuilder &builder, ModuleOp topModule) { |
| 58 | for (auto op : topModule.getBodyRegion().getOps<gpu::GPUModuleOp>()) |
| 59 | return op; // existing |
| 60 | markAsGPUContainer(topModule); |
| 61 | builder.setInsertionPointToStart(topModule.getBody()); |
| 62 | return builder.create<gpu::GPUModuleOp>(topModule->getLoc(), |
| 63 | "sparse_kernels" ); |
| 64 | } |
| 65 | |
| 66 | /// Constructs a new GPU kernel in the given GPU module. |
| 67 | static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule, |
| 68 | SmallVectorImpl<Value> &args) { |
| 69 | // Get a unique kernel name. Not very creative, |
| 70 | // but we simply try kernel0, kernel1, etc. |
| 71 | unsigned kernelNumber = 0; |
| 72 | SmallString<16> kernelName; |
| 73 | do { |
| 74 | kernelName.clear(); |
| 75 | ("kernel" + Twine(kernelNumber++)).toStringRef(Out&: kernelName); |
| 76 | } while (gpuModule.lookupSymbol(kernelName)); |
| 77 | // Then we insert a new kernel with given arguments into the module. |
| 78 | builder.setInsertionPointToStart(gpuModule.getBody()); |
| 79 | SmallVector<Type> argsTp; |
| 80 | for (auto arg : args) |
| 81 | argsTp.push_back(arg.getType()); |
| 82 | FunctionType type = FunctionType::get(gpuModule->getContext(), argsTp, {}); |
| 83 | auto gpuFunc = |
| 84 | builder.create<gpu::GPUFuncOp>(gpuModule->getLoc(), kernelName, type); |
| 85 | gpuFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(), |
| 86 | builder.getUnitAttr()); |
| 87 | return gpuFunc; |
| 88 | } |
| 89 | |
| 90 | /// Constructs code to launch GPU kernel. |
| 91 | static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc, |
| 92 | SmallVectorImpl<Value> &args, |
| 93 | SmallVectorImpl<Value> &tokens, |
| 94 | unsigned numThreads) { |
| 95 | Location loc = gpuFunc->getLoc(); |
| 96 | Value none = TypedValue<::mlir::IntegerType>{}; |
| 97 | Value one = constantIndex(builder, loc, 1); |
| 98 | Value numT = constantIndex(builder, loc, numThreads); |
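| | // Launch a one-dimensional configuration: a single block of numThreads threads along x. |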
| 99 | gpu::KernelDim3 gridSize = {one, one, one}; |
| 100 | gpu::KernelDim3 blckSize = {numT, one, one}; |
| 101 | return builder |
| 102 | .create<gpu::LaunchFuncOp>(loc, gpuFunc, gridSize, blckSize, |
| 103 | /*dynSharedMemSz*/ none, args, |
| 104 | builder.getType<gpu::AsyncTokenType>(), tokens) |
| 105 | .getAsyncToken(); |
| 106 | } |
| 107 | |
| 108 | /// Maps the provided ranked host buffer into the device address space. |
| 109 | /// Writes from the host are guaranteed to be visible to device kernels |
| 110 | /// that are launched afterwards. Writes from the device are guaranteed |
| 111 | /// to be visible on the host after synchronizing with the device kernel |
| 112 | /// completion. Needs to cast the buffer to an unranked buffer. |
| 113 | static Value genHostRegisterMemref(OpBuilder &builder, Location loc, |
| 114 | Value mem) { |
| 115 | MemRefType memTp = cast<MemRefType>(mem.getType()); |
| 116 | UnrankedMemRefType resTp = |
| 117 | UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0); |
| 118 | Value cast = builder.create<memref::CastOp>(loc, resTp, mem); |
| 119 | builder.create<gpu::HostRegisterOp>(loc, cast); |
| 120 | return cast; |
| 121 | } |
| 122 | |
| 123 | /// Unmaps the provided buffer, expecting the casted buffer. |
| 124 | static void genHostUnregisterMemref(OpBuilder &builder, Location loc, |
| 125 | Value cast) { |
| 126 | builder.create<gpu::HostUnregisterOp>(loc, cast); |
| 127 | } |
| 128 | |
| 129 | /// Generates first wait in an asynchronous chain. |
| 130 | static Value genFirstWait(OpBuilder &builder, Location loc) { |
| 131 | Type tokenType = builder.getType<gpu::AsyncTokenType>(); |
| 132 | return builder.create<gpu::WaitOp>(loc, tokenType, ValueRange()) |
| 133 | .getAsyncToken(); |
| 134 | } |
| 135 | |
| 136 | /// Generates last, blocking wait in an asynchronous chain. |
| 137 | static void genBlockingWait(OpBuilder &builder, Location loc, |
| 138 | ValueRange operands) { |
| 139 | builder.create<gpu::WaitOp>(loc, Type(), operands); |
| 140 | } |
| 141 | |
| 142 | /// Allocates memory on the device. |
| 143 | /// TODO: A `host_shared` attribute could be used to indicate that |
| 144 | /// the buffer is visible by both host and device, but lowering |
| 145 | /// that feature does not seem to be fully supported yet. |
| 146 | static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem, |
| 147 | Value token) { |
| 148 | auto tp = cast<ShapedType>(mem.getType()); |
| 149 | auto elemTp = tp.getElementType(); |
| 150 | auto shape = tp.getShape(); |
| 151 | auto memTp = MemRefType::get(shape, elemTp); |
| 152 | SmallVector<Value> dynamicSizes; |
| 153 | for (unsigned r = 0, rank = tp.getRank(); r < rank; r++) { |
| 154 | if (shape[r] == ShapedType::kDynamic) { |
| 155 | Value dimOp = linalg::createOrFoldDimOp(builder, loc, mem, r); |
| 156 | dynamicSizes.push_back(dimOp); |
| 157 | } |
| 158 | } |
| 159 | return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}), |
| 160 | token, dynamicSizes, ValueRange()); |
| 161 | } |
| 162 | |
| 163 | // Allocates a typed buffer on the host with given size. |
| 164 | static Value genHostBuffer(OpBuilder &builder, Location loc, Type type, |
| 165 | Value size) { |
| 166 | const auto memTp = MemRefType::get({ShapedType::kDynamic}, type); |
| 167 | return builder.create<memref::AllocOp>(loc, memTp, size).getResult(); |
| 168 | } |
| 169 | |
| 170 | // Allocates a typed buffer on the device with given size. |
| 171 | static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Type type, |
| 172 | Value size, Value token) { |
| 173 | const auto memTp = MemRefType::get({ShapedType::kDynamic}, type); |
| 174 | return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}), |
| 175 | token, size, ValueRange()); |
| 176 | } |
| 177 | |
| 178 | // Allocates a void buffer on the device with given size. |
| 179 | static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size, |
| 180 | Value token) { |
| 181 | return genAllocBuffer(builder, loc, builder.getI8Type(), size, token); |
| 182 | } |
| 183 | |
| 184 | /// Deallocates memory from the device. |
| 185 | static Value genDeallocMemRef(OpBuilder &builder, Location loc, Value mem, |
| 186 | Value token) { |
| 187 | return builder.create<gpu::DeallocOp>(loc, token.getType(), token, mem) |
| 188 | .getAsyncToken(); |
| 189 | } |
| 190 | |
| 191 | /// Copies memory between host and device (direction is implicit). |
| 192 | static Value genCopyMemRef(OpBuilder &builder, Location loc, Value dst, |
| 193 | Value src, Value token) { |
| 194 | return builder.create<gpu::MemcpyOp>(loc, token.getType(), token, dst, src) |
| 195 | .getAsyncToken(); |
| 196 | } |
| 197 | |
| 198 | /// Generates an alloc/copy pair. |
| 199 | static Value genAllocCopy(OpBuilder &builder, Location loc, Value b, |
| 200 | SmallVectorImpl<Value> &tokens) { |
| 201 | Value firstToken = genFirstWait(builder, loc); |
| 202 | auto alloc = genAllocMemRef(builder, loc, b, firstToken); |
| 203 | Value devMem = alloc.getResult(0); |
| 204 | Value depToken = alloc.getAsyncToken(); // copy-after-alloc |
| 205 | tokens.push_back(genCopyMemRef(builder, loc, devMem, b, depToken)); |
| 206 | return devMem; |
| 207 | } |
| 208 | |
| 209 | /// Generates a memref from tensor operation. |
| 210 | static Value genTensorToMemref(PatternRewriter &rewriter, Location loc, |
| 211 | Value tensor) { |
| 212 | auto tensorType = llvm::cast<ShapedType>(tensor.getType()); |
| 213 | auto memrefType = |
| 214 | MemRefType::get(tensorType.getShape(), tensorType.getElementType()); |
| 215 | return rewriter.create<bufferization::ToBufferOp>(loc, memrefType, tensor); |
| 216 | } |
| 217 | |
| 218 | /// Prepares the outlined arguments, passing scalars and buffers in. Here we |
| 219 | /// assume that the first buffer is the one allocated for output. We create |
| 220 | /// a set of properly chained asynchronous allocation/copy pairs to increase |
| 221 | /// overlap before launching the kernel. |
| 222 | static Value genParametersIn(OpBuilder &builder, Location loc, |
| 223 | SmallVectorImpl<Value> &scalars, |
| 224 | SmallVectorImpl<Value> &buffers, |
| 225 | SmallVectorImpl<Value> &args, |
| 226 | SmallVectorImpl<Value> &tokens, |
| 227 | bool useHostRegistrationForOut) { |
| 228 | Value out; |
| 229 | // Scalars are passed by value. |
| 230 | for (Value s : scalars) |
| 231 | args.push_back(s); |
| 232 | // Buffers need to be made visible on the device. |
| 233 | for (Value b : buffers) { |
| 234 | if (useHostRegistrationForOut) { |
| 235 | out = genHostRegisterMemref(builder, loc, b); |
| 236 | args.push_back(b); |
| 237 | useHostRegistrationForOut = false; |
| 238 | continue; |
| 239 | } |
| 240 | args.push_back(genAllocCopy(builder, loc, b, tokens)); |
| 241 | } |
| 242 | return out; |
| 243 | } |
| 244 | |
| 245 | /// Finalizes the outlined arguments. The output buffer is copied depending |
| 246 | /// on the kernel token and then deallocated. All other buffers are simply |
| 247 | /// deallocated. Then we wait for all operations to complete. |
| 248 | static void genParametersOut(OpBuilder &builder, Location loc, Value out, |
| 249 | Value kernelToken, SmallVectorImpl<Value> &scalars, |
| 250 | SmallVectorImpl<Value> &buffers, |
| 251 | SmallVectorImpl<Value> &args, |
| 252 | SmallVectorImpl<Value> &tokens) { |
| 253 | unsigned base = scalars.size(); |
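| | // Only the buffer arguments (which follow the scalars) need copy-out or deallocation. |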
| 254 | for (unsigned i = base, e = args.size(); i < e; i++) { |
| 255 | Value firstToken; |
| 256 | if (i == base) { |
| 257 | // Assumed output parameter: unregister or copy-out. |
| 258 | if (out) { |
| 259 | genHostUnregisterMemref(builder, loc, out); |
| 260 | out = Value(); |
| 261 | continue; |
| 262 | } |
| 263 | firstToken = |
| 264 | genCopyMemRef(builder, loc, buffers[0], args[i], kernelToken); |
| 265 | } else { |
| 266 | firstToken = genFirstWait(builder, loc); |
| 267 | } |
| 268 | tokens.push_back(genDeallocMemRef(builder, loc, args[i], firstToken)); |
| 269 | } |
| 270 | } |
| 271 | |
| 272 | /// Constructs code for new GPU kernel. |
| 273 | static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc, |
| 274 | scf::ParallelOp forallOp, |
| 275 | SmallVectorImpl<Value> &constants, |
| 276 | SmallVectorImpl<Value> &scalars, |
| 277 | SmallVectorImpl<Value> &buffers) { |
| 278 | Location loc = gpuFunc->getLoc(); |
| 279 | Block &block = gpuFunc.getBody().front(); |
| 280 | rewriter.setInsertionPointToStart(&block); |
| 281 | |
| 282 | // Re-generate the constants, recapture all arguments. |
| 283 | unsigned arg = 0; |
| 284 | IRMapping irMap; |
| 285 | for (Value c : constants) |
| 286 | irMap.map(c, rewriter.clone(*c.getDefiningOp())->getResult(0)); |
| 287 | for (Value s : scalars) |
| 288 | irMap.map(s, block.getArgument(arg++)); |
| 289 | for (Value b : buffers) |
| 290 | irMap.map(b, block.getArgument(arg++)); |
| 291 | |
| 292 | // Assume 1-dimensional grid/block configuration (only x dimension), |
| 293 | // so that: |
| 294 | // row = blockIdx.x * blockDim.x + threadIdx.x |
| 295 | // inc = blockDim.x * gridDim.x |
| 296 | Value bid = rewriter.create<gpu::BlockIdOp>(loc, gpu::Dimension::x); |
| 297 | Value bsz = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x); |
| 298 | Value tid = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x); |
| 299 | Value gsz = rewriter.create<gpu::GridDimOp>(loc, gpu::Dimension::x); |
| 300 | Value mul = rewriter.create<arith::MulIOp>(loc, bid, bsz); |
| 301 | Value row = rewriter.create<arith::AddIOp>(loc, mul, tid); |
| 302 | Value inc = rewriter.create<arith::MulIOp>(loc, bsz, gsz); |
| 303 | |
| 304 | // Construct the iteration over the computational space that |
| 305 | // accounts for the fact that the total number of threads and |
| 306 | // the amount of work to be done usually do not match precisely. |
| 307 | // for (r = row; r < N; r += inc) { |
| 308 | // <loop-body> |
| 309 | // } |
| 310 | Value upper = irMap.lookup(forallOp.getUpperBound()[0]); |
| 311 | scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, row, upper, inc); |
| 312 | // The scf.for builder creates an empty block. scf.for does not allow multiple |
| 313 | // blocks in its region, so delete the block before `cloneRegionBefore` adds |
| 314 | // an additional block. |
| 315 | rewriter.eraseBlock(forOp.getBody()); |
| 316 | rewriter.cloneRegionBefore(forallOp.getRegion(), forOp.getRegion(), |
| 317 | forOp.getRegion().begin(), irMap); |
| 318 | // Replace the scf.reduce terminator. |
| 319 | rewriter.setInsertionPoint(forOp.getBody()->getTerminator()); |
| 320 | rewriter.replaceOpWithNewOp<scf::YieldOp>(forOp.getBody()->getTerminator()); |
| 321 | |
| 322 | // Done. |
| 323 | rewriter.setInsertionPointAfter(forOp); |
| 324 | rewriter.create<gpu::ReturnOp>(gpuFunc->getLoc()); |
| 325 | } |
| 326 | |
| 327 | //===----------------------------------------------------------------------===// |
| 328 | // Library helper methods. |
| 329 | //===----------------------------------------------------------------------===// |
| 330 | |
| 331 | /// Helper to detect a + b with arguments taken from given block. |
| 332 | static bool matchAddOfArgs(Block *block, Value val) { |
| 333 | if (auto *def = val.getDefiningOp()) { |
| 334 | if (isa<arith::AddFOp, arith::AddIOp>(def)) { |
| 335 | Value a = block->getArguments()[0]; |
| 336 | Value b = block->getArguments()[1]; |
| 337 | return (def->getOperand(0) == a && def->getOperand(1) == b) || |
| 338 | (def->getOperand(0) == b && def->getOperand(1) == a); |
| 339 | } |
| 340 | } |
| 341 | return false; |
| 342 | } |
| 343 | |
| 344 | /// Helper to detect a * b with arguments taken from given block. |
| 345 | static bool matchMulOfArgs(Block *block, Value val) { |
| 346 | if (auto *def = val.getDefiningOp()) { |
| 347 | if (isa<arith::MulFOp, arith::MulIOp>(def)) { |
| 348 | Value a = block->getArguments()[0]; |
| 349 | Value b = block->getArguments()[1]; |
| 350 | return (def->getOperand(0) == a && def->getOperand(1) == b) || |
| 351 | (def->getOperand(0) == b && def->getOperand(1) == a); |
| 352 | } |
| 353 | } |
| 354 | return false; |
| 355 | } |
| 356 | |
| 357 | /// Helper to detect x = x + a * b |
| 358 | static bool matchSumOfMultOfArgs(linalg::GenericOp op) { |
| 359 | auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator()); |
| 360 | if (auto *def = yieldOp.getOperand(0).getDefiningOp()) { |
| 361 | if (isa<arith::AddFOp, arith::AddIOp>(def)) { |
| 362 | Value x = op.getBlock()->getArguments()[2]; |
| 363 | return (def->getOperand(0) == x && |
| 364 | matchMulOfArgs(op.getBlock(), def->getOperand(1))) || |
| 365 | (def->getOperand(1) == x && |
| 366 | matchMulOfArgs(op.getBlock(), def->getOperand(0))); |
| 367 | } |
| 368 | } |
| 369 | return false; |
| 370 | } |
| 371 | |
| 372 | // Helper to detect c += spy(s) x (a * b) |
| 373 | static bool matchSumReductionOfMulUnary(linalg::GenericOp op) { |
| 374 | auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator()); |
| 375 | // The linalg yields a custom reduce result. |
| 376 | Value s_out = op.getBlock()->getArguments()[2]; |
| 377 | if (auto redOp = |
| 378 | yieldOp.getOperand(0).getDefiningOp<sparse_tensor::ReduceOp>()) { |
| 379 | // The reduce consumes the output. |
| 380 | Value other; |
| 381 | if (s_out == redOp->getOperand(0)) |
| 382 | other = redOp->getOperand(1); |
| 383 | else if (s_out == redOp->getOperand(1)) |
| 384 | other = redOp->getOperand(0); |
| 385 | else |
| 386 | return false; |
| 387 | // The reduce op also consumes a unary op that in turn consumes the output |
| 388 | // and does not define an absent value. |
| 389 | if (auto unOp = other.getDefiningOp<sparse_tensor::UnaryOp>()) { |
| 390 | if (s_out != unOp->getOperand(0) || !unOp.getAbsentRegion().empty()) |
| 391 | return false; |
| 392 | // And the bodies are as expected. |
| 393 | auto yieldUn = cast<sparse_tensor::YieldOp>( |
| 394 | unOp.getRegion(0).front().getTerminator()); |
| 395 | auto yieldRed = cast<sparse_tensor::YieldOp>( |
| 396 | redOp.getRegion().front().getTerminator()); |
| 397 | return matchMulOfArgs(op.getBlock(), yieldUn.getOperand(0)) && |
| 398 | matchAddOfArgs(&redOp.getRegion().front(), yieldRed.getOperand(0)); |
| 399 | } |
| 400 | } |
| 401 | return false; |
| 402 | } |
| 403 | |
| 404 | /// Test for dense tensor. |
| 405 | static bool isDenseTensor(Value v) { |
| 406 | auto sTp = getSparseTensorType(v); |
| 407 | return sTp.getDimRank() == sTp.getLvlRank() && sTp.isAllDense(); |
| 408 | } |
| 409 | |
| 410 | /// Test for suitable positions/coordinates width. |
| 411 | static bool isAdmissibleMetaData(SparseTensorType &aTp) { |
| 412 | return (aTp.getPosWidth() == 0 || aTp.getPosWidth() >= 16) && |
| 413 | (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() >= 16); |
| 414 | } |
| 415 | |
| 416 | /// Test for sorted COO matrix with suitable metadata. |
| 417 | static bool isAdmissibleCOO(SparseTensorType &aTp) { |
| 418 | return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() && |
| 419 | aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) && |
| 420 | aTp.isSingletonLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && |
| 421 | isAdmissibleMetaData(aTp); |
| 422 | } |
| 423 | |
| 424 | /// Test for CSR matrix with suitable metadata. |
| 425 | static bool isAdmissibleCSR(SparseTensorType &aTp) { |
| 426 | return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() && |
| 427 | aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) && |
| 428 | aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp); |
| 429 | } |
| 430 | |
| 431 | /// Test for CSC matrix with suitable metadata. |
| 432 | static bool isAdmissibleCSC(SparseTensorType &aTp) { |
| 433 | return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && !aTp.isIdentity() && |
| 434 | aTp.isPermutation() && aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && |
| 435 | aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp); |
| 436 | } |
| 437 | |
| 438 | /// Test for BSR matrix with suitable metadata. |
| 439 | static bool isAdmissibleBSR(SparseTensorType &aTp) { |
| 440 | if (aTp.getDimRank() == 2 && aTp.getLvlRank() == 4 && aTp.isDenseLvl(0) && |
| 441 | aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && |
| 442 | aTp.isDenseLvl(2) && aTp.isDenseLvl(3) && isAdmissibleMetaData(aTp)) { |
| 443 | // CuSparse only supports "square" blocks currently. |
| 444 | SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl()); |
| 445 | assert(dims.size() == 2); |
| 446 | return dims[0] == dims[1] && dims[0] > 1; |
| 447 | } |
| 448 | return false; |
| 449 | } |
| 450 | |
| 451 | /// Test for 2:4 matrix with suitable metadata. |
| 452 | static bool isAdmissible24(SparseTensorType &aTp) { |
| 453 | return aTp.getDimRank() == 2 && aTp.getLvlRank() == 3 && aTp.isDenseLvl(0) && |
| 454 | aTp.isDenseLvl(1) && aTp.isNOutOfMLvl(2) && isAdmissibleMetaData(aTp); |
| 455 | } |
| 456 | |
| 457 | /// Test for conversion into 2:4 matrix. |
| 458 | static bool isConversionInto24(Value v) { |
| 459 | if (auto cnv = v.getDefiningOp<ConvertOp>()) { |
| 460 | Value a = cnv.getResult(); |
| 461 | Value d = cnv.getSource(); |
| 462 | SparseTensorType aTp = getSparseTensorType(a); |
| 463 | return isDenseTensor(d) && isAdmissible24(aTp); |
| 464 | } |
| 465 | return false; |
| 466 | } |
| 467 | |
| 468 | /// Returns a suitable sparse format for the operation and given operand |
| 469 | /// types with cuSparse, or kNone if none is available. |
| 470 | static CuSparseFormat getCuSparseFormat(SparseTensorType aTp, |
| 471 | SparseTensorType bTp, |
| 472 | SparseTensorType cTp, bool enableRT, |
| 473 | bool isMatVec) { |
| 474 | // The other operands have a dense type. |
| 475 | if (bTp.hasEncoding() || cTp.hasEncoding()) |
| 476 | return CuSparseFormat::kNone; |
| 477 | // Now check for suitable operand type for the main operand. |
| 478 | if (isAdmissibleCOO(aTp)) |
| 479 | #ifdef CUSPARSE_COO_AOS |
| 480 | return isMatVec ? CuSparseFormat::kCOO : CuSparseFormat::kNone; |
| 481 | #else |
| 482 | return enableRT ? CuSparseFormat::kCOO : CuSparseFormat::kNone; |
| 483 | #endif |
| 484 | if (isAdmissibleCSR(aTp)) |
| 485 | return CuSparseFormat::kCSR; |
| 486 | if (isAdmissibleCSC(aTp)) |
| 487 | return CuSparseFormat::kCSC; |
| 488 | if (isAdmissibleBSR(aTp)) |
| 489 | return CuSparseFormat::kBSR; |
| 490 | return CuSparseFormat::kNone; |
| 491 | } |
| 492 | |
| 493 | /// Generates the first positions/coordinates of a sparse matrix. |
| 494 | static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a, |
| 495 | CuSparseFormat format, bool enableRT) { |
| 496 | if (format == CuSparseFormat::kCOO) { |
| 497 | // Library uses SoA COO, direct IR uses AoS COO. |
| 498 | if (enableRT) |
| 499 | return builder.create<ToCoordinatesOp>(loc, a, 0); |
| 500 | return builder.create<ToCoordinatesBufferOp>(loc, a); |
| 501 | } |
| 502 | // Formats CSR/CSC and BSR use positions at 1. |
| 503 | return builder.create<ToPositionsOp>(loc, a, 1); |
| 504 | } |
| 505 | |
| 506 | /// Generates the second coordinates of a sparse matrix. |
| 507 | static Value genSecondCrds(OpBuilder &builder, Location loc, Value a, |
| 508 | CuSparseFormat format, bool enableRT) { |
| 509 | bool isCOO = format == CuSparseFormat::kCOO; |
| 510 | if (isCOO && !enableRT) |
| 511 | return Value(); // nothing needed |
| 512 | // Formats CSR/CSC and BSR use coordinates at 1. |
| 513 | return builder.create<ToCoordinatesOp>(loc, a, 1); |
| 514 | } |
| 515 | |
| 516 | /// Generates the sparse matrix handle. |
| 517 | static Operation *genSpMat(OpBuilder &builder, Location loc, |
| 518 | SparseTensorType &aTp, Type handleTp, Type tokenTp, |
| 519 | Value token, Value sz1, Value sz2, Value nseA, |
| 520 | Value rowA, Value colA, Value valA, |
| 521 | CuSparseFormat format, bool enableRT) { |
| 522 | if (format == CuSparseFormat::kCOO) { |
| 523 | // Library uses SoA COO, direct IR uses AoS COO. |
| 524 | if (enableRT) { |
| 525 | assert(colA); |
| 526 | return builder.create<gpu::CreateCooOp>(loc, handleTp, tokenTp, token, |
| 527 | sz1, sz2, nseA, rowA, colA, valA); |
| 528 | } |
| 529 | #ifdef CUSPARSE_COO_AOS |
| 530 | assert(!colA); |
| 531 | return builder.create<gpu::CreateCooAoSOp>(loc, handleTp, tokenTp, token, |
| 532 | sz1, sz2, nseA, rowA, valA); |
| 533 | #else |
| 534 | llvm_unreachable("gpu::CreateCooAoSOp is deprecated" ); |
| 535 | #endif |
| 536 | } |
| 537 | assert(colA); |
| 538 | if (format == CuSparseFormat::kCSR) |
| 539 | return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1, |
| 540 | sz2, nseA, rowA, colA, valA); |
| 541 | if (format == CuSparseFormat::kCSC) |
| 542 | return builder.create<gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1, |
| 543 | sz2, nseA, rowA, colA, valA); |
| 544 | // BSR requires a bit more work since we need to pass in the block size |
| 545 | // and all other sizes in terms of blocks (#block-rows, #block-cols, |
| 546 | // #nonzero-blocks). |
| 547 | assert(format == CuSparseFormat::kBSR); |
| 548 | SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl()); |
| 549 | assert(dims.size() == 2 && dims[0] == dims[1]); |
| 550 | uint64_t b = dims[0]; |
| 551 | Value bSz = constantIndex(builder, loc, b); |
| 552 | Value bRows = builder.create<arith::DivUIOp>(loc, sz1, bSz); |
| 553 | Value bCols = builder.create<arith::DivUIOp>(loc, sz2, bSz); |
| 554 | Value bNum = builder.create<arith::DivUIOp>( |
| 555 | loc, nseA, constantIndex(builder, loc, b * b)); |
| 556 | return builder.create<gpu::CreateBsrOp>(loc, handleTp, tokenTp, token, bRows, |
| 557 | bCols, bNum, bSz, bSz, rowA, colA, |
| 558 | valA); |
| 559 | } |
| 560 | |
| 561 | /// Match and rewrite SpMV kernel. |
| 562 | static LogicalResult rewriteSpMV(PatternRewriter &rewriter, |
| 563 | linalg::GenericOp op, bool enableRT) { |
| 564 | Location loc = op.getLoc(); |
| 565 | Value a = op.getOperand(0); |
| 566 | Value x = op.getOperand(1); |
| 567 | Value y = op.getOperand(2); // we have y = Ax |
| 568 | SmallVector<Value> tokens; |
| 569 | |
| 570 | // Only admissible sparse matrix format and dense vectors (no BSR). |
| 571 | SparseTensorType aTp = getSparseTensorType(a); |
| 572 | SparseTensorType xTp = getSparseTensorType(x); |
| 573 | SparseTensorType yTp = getSparseTensorType(y); |
| 574 | auto format = getCuSparseFormat(aTp, xTp, yTp, enableRT, /*isMatVec=*/true); |
| 575 | if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR) |
| 576 | return failure(); |
| 577 | |
| 578 | // Start sparse kernel and copy data from host to device. |
| 579 | // a : memR/memC/memV -> rowA,colA,valA |
| 580 | // x : memX -> vecX |
| 581 | // y : memY -> vecY |
| 582 | Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a); |
| 583 | Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0); |
| 584 | Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1); |
| 585 | Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT); |
| 586 | Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty |
| 587 | Value memV = rewriter.create<ToValuesOp>(loc, a); |
| 588 | Value rowA = genAllocCopy(rewriter, loc, memR, tokens); |
| 589 | Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value(); |
| 590 | Value valA = genAllocCopy(rewriter, loc, memV, tokens); |
| 591 | Value memX = genTensorToMemref(rewriter, loc, x); |
| 592 | Value vecX = genAllocCopy(rewriter, loc, memX, tokens); |
| 593 | Value memY = genTensorToMemref(rewriter, loc, y); |
| 594 | Value vecY = genAllocCopy(rewriter, loc, memY, tokens); |
| 595 | genBlockingWait(rewriter, loc, tokens); |
| 596 | tokens.clear(); |
| 597 | |
| 598 | // Create sparse environment and sparse matrix/dense vector handles. |
| 599 | Type indexTp = rewriter.getIndexType(); |
| 600 | Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>(); |
| 601 | Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); |
| 602 | Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); |
| 603 | Value token = genFirstWait(rewriter, loc); |
| 604 | Operation *spGenA = |
| 605 | genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szY, szX, |
| 606 | nseA, rowA, colA, valA, format, enableRT); |
| 607 | Value spMatA = spGenA->getResult(0); |
| 608 | token = spGenA->getResult(1); |
| 609 | auto dvecX = rewriter.create<gpu::CreateDnTensorOp>( |
| 610 | loc, dnTensorHandleTp, tokenTp, token, vecX, szX); |
| 611 | Value dnX = dvecX.getResult(0); |
| 612 | token = dvecX.getAsyncToken(); |
| 613 | auto dvecY = rewriter.create<gpu::CreateDnTensorOp>( |
| 614 | loc, dnTensorHandleTp, tokenTp, token, vecY, szY); |
| 615 | Value dnY = dvecY.getResult(0); |
| 616 | token = dvecY.getAsyncToken(); |
| 617 | auto dnYType = llvm::cast<ShapedType>(y.getType()).getElementType(); |
| 618 | |
| 619 | // Precompute buffersize for SpMV. |
| 620 | auto bufferComp = rewriter.create<gpu::SpMVBufferSizeOp>( |
| 621 | loc, indexTp, tokenTp, token, spMatA, dnX, dnY, |
| 622 | /*computeType=*/dnYType); |
| 623 | Value bufferSz = bufferComp.getResult(0); |
| 624 | token = bufferComp.getAsyncToken(); |
| 625 | auto buf = genAllocBuffer(rewriter, loc, bufferSz, token); |
| 626 | Value buffer = buf.getResult(0); |
| 627 | token = buf.getAsyncToken(); |
| 628 | |
| 629 | // Perform the SpMV. |
| 630 | auto spmvComp = rewriter.create<gpu::SpMVOp>( |
| 631 | loc, tokenTp, token, spMatA, dnX, dnY, /*computeType=*/dnYType, buffer); |
| 632 | token = spmvComp.getAsyncToken(); |
| 633 | |
| 634 | // Copy data back to host and free all the resources. |
| 635 | token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA) |
| 636 | .getAsyncToken(); |
| 637 | token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnX) |
| 638 | .getAsyncToken(); |
| 639 | token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnY) |
| 640 | .getAsyncToken(); |
| 641 | token = genDeallocMemRef(rewriter, loc, rowA, token); |
| 642 | if (colA) |
| 643 | token = genDeallocMemRef(rewriter, loc, colA, token); |
| 644 | token = genDeallocMemRef(rewriter, loc, valA, token); |
| 645 | token = genDeallocMemRef(rewriter, loc, buffer, token); |
| 646 | token = genDeallocMemRef(rewriter, loc, vecX, token); |
| 647 | token = genCopyMemRef(rewriter, loc, memY, vecY, token); |
| 648 | token = genDeallocMemRef(rewriter, loc, vecY, token); |
| 649 | tokens.push_back(token); |
| 650 | genBlockingWait(rewriter, loc, tokens); |
| 651 | tokens.clear(); |
| 652 | |
| 653 | // Done. |
| 654 | rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, memY); |
| 655 | return success(); |
| 656 | } |
| 657 | |
| 658 | /// Match and rewrite SpMM kernel. |
| 659 | static LogicalResult rewriteSpMM(PatternRewriter &rewriter, |
| 660 | linalg::GenericOp op, bool enableRT) { |
| 661 | Location loc = op.getLoc(); |
| 662 | Value a = op.getOperand(0); |
| 663 | Value b = op.getOperand(1); |
| 664 | Value c = op.getOperand(2); // we have C = AB |
| 665 | SmallVector<Value> tokens; |
| 666 | |
| 667 | // Only admissible sparse matrix format and dense matrices (no BSR). |
| 668 | SparseTensorType aTp = getSparseTensorType(a); |
| 669 | SparseTensorType bTp = getSparseTensorType(b); |
| 670 | SparseTensorType cTp = getSparseTensorType(c); |
| 671 | auto format = getCuSparseFormat(aTp, bTp, cTp, enableRT, /*isMatVec=*/false); |
| 672 | if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR) |
| 673 | return failure(); |
| 674 | |
| 675 | // Start sparse kernel and copy data from host to device. |
| 676 | // a : memR/memC/memV -> rowA,colA,valA |
| 677 | // b : bufB -> matB |
| 678 | // c : bufC -> matC |
| 679 | Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a); |
| 680 | Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0); |
| 681 | Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1); |
| 682 | Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1); |
| 683 | Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT); |
| 684 | Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty |
| 685 | Value memV = rewriter.create<ToValuesOp>(loc, a); |
| 686 | Value rowA = genAllocCopy(rewriter, loc, memR, tokens); |
| 687 | Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value(); |
| 688 | Value valA = genAllocCopy(rewriter, loc, memV, tokens); |
| 689 | Value bufB = genTensorToMemref(rewriter, loc, b); |
| 690 | Value matB = genAllocCopy(rewriter, loc, bufB, tokens); |
| 691 | Value bufC = genTensorToMemref(rewriter, loc, c); |
| 692 | Value matC = genAllocCopy(rewriter, loc, bufC, tokens); |
| 693 | genBlockingWait(rewriter, loc, tokens); |
| 694 | tokens.clear(); |
| 695 | |
| 696 | // Create sparse environment and sparse matrix/dense matrix handles. |
| 697 | Type indexTp = rewriter.getIndexType(); |
| 698 | Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>(); |
| 699 | Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); |
| 700 | Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); |
| 701 | Value token = genFirstWait(rewriter, loc); |
| 702 | Operation *spGenA = |
| 703 | genSpMat(rewriter, loc, aTp, spMatHandleTp, tokenTp, token, szm, szk, |
| 704 | nseA, rowA, colA, valA, format, enableRT); |
| 705 | Value spMatA = spGenA->getResult(0); |
| 706 | token = spGenA->getResult(1); |
| 707 | auto dmatB = rewriter.create<gpu::CreateDnTensorOp>( |
| 708 | loc, dnTensorHandleTp, tokenTp, token, matB, |
| 709 | SmallVector<Value>{szk, szn}); |
| 710 | Value dnB = dmatB.getResult(0); |
| 711 | token = dmatB.getAsyncToken(); |
| 712 | auto dmatC = rewriter.create<gpu::CreateDnTensorOp>( |
| 713 | loc, dnTensorHandleTp, tokenTp, token, matC, |
| 714 | SmallVector<Value>{szm, szn}); |
| 715 | Value dnC = dmatC.getResult(0); |
| 716 | token = dmatC.getAsyncToken(); |
| 717 | auto dmatCType = llvm::cast<ShapedType>(c.getType()).getElementType(); |
| 718 | |
| 719 | // Precompute buffersize for SpMM. |
| 720 | auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>( |
| 721 | loc, indexTp, tokenTp, token, spMatA, dnB, dnC, |
| 722 | /*computeType=*/dmatCType); |
| 723 | Value bufferSz = bufferComp.getResult(0); |
| 724 | token = bufferComp.getAsyncToken(); |
| 725 | auto buf = genAllocBuffer(rewriter, loc, bufferSz, token); |
| 726 | Value buffer = buf.getResult(0); |
| 727 | token = buf.getAsyncToken(); |
| 728 | auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType(); |
| 729 | |
| 730 | // Perform the SpMM. |
| 731 | auto spmmComp = rewriter.create<gpu::SpMMOp>( |
| 732 | loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType, buffer); |
| 733 | token = spmmComp.getAsyncToken(); |
| 734 | |
| 735 | // Copy data back to host and free all the resources. |
| 736 | token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA) |
| 737 | .getAsyncToken(); |
| 738 | token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB) |
| 739 | .getAsyncToken(); |
| 740 | token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC) |
| 741 | .getAsyncToken(); |
| 742 | token = genDeallocMemRef(rewriter, loc, rowA, token); |
| 743 | if (colA) |
| 744 | token = genDeallocMemRef(rewriter, loc, colA, token); |
| 745 | token = genDeallocMemRef(rewriter, loc, valA, token); |
| 746 | token = genDeallocMemRef(rewriter, loc, buffer, token); |
| 747 | token = genDeallocMemRef(rewriter, loc, matB, token); |
| 748 | token = genCopyMemRef(rewriter, loc, bufC, matC, token); |
| 749 | token = genDeallocMemRef(rewriter, loc, matC, token); |
| 750 | tokens.push_back(token); |
| 751 | genBlockingWait(rewriter, loc, tokens); |
| 752 | tokens.clear(); |
| 753 | |
| 754 | // Done. |
| 755 | rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC); |
| 756 | return success(); |
| 757 | } |
| 758 | |
| 759 | // Match and rewrite SpGEMM kernel. |
| 760 | static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter, |
| 761 | linalg::GenericOp op, bool enableRT) { |
| 762 | Location loc = op.getLoc(); |
| 763 | Value a = op.getOperand(0); |
| 764 | Value b = op.getOperand(1); |
| 765 | Value c = op.getOperand(2); // we have C = AB |
| 766 | SmallVector<Value> tokens; |
| 767 | |
| 768 | // Only CSR <- CSR x CSR supported. |
| 769 | auto format = CuSparseFormat::kCSR; |
| 770 | SparseTensorType aTp = getSparseTensorType(a); |
| 771 | SparseTensorType bTp = getSparseTensorType(b); |
| 772 | SparseTensorType cTp = getSparseTensorType(c); |
| 773 | if (!isAdmissibleCSR(aTp) || !isAdmissibleCSR(bTp) || !isAdmissibleCSR(cTp)) |
| 774 | return failure(); |
| 775 | |
| 776 | // Start sparse kernel and copy data from host to device. |
| 777 | // a : amemR/amemC/amemV -> rowA,colA,valA |
| 778 | // b : bmemR/bmemC/bmemV -> rowB,colB,valB |
| 779 | // c : materializes |
| 780 | auto dnCType = cTp.getElementType(); |
| 781 | Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a); |
| 782 | Value nseB = rewriter.create<NumberOfEntriesOp>(loc, b); |
| 783 | Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0); |
| 784 | Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1); |
| 785 | Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1); |
| 786 | Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT); |
| 787 | Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT); // not empty |
| 788 | Value amemV = rewriter.create<ToValuesOp>(loc, a); |
| 789 | Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT); |
| 790 | Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT); // not empty |
| 791 | Value bmemV = rewriter.create<ToValuesOp>(loc, b); |
| 792 | Value rowA = genAllocCopy(rewriter, loc, amemR, tokens); |
| 793 | Value colA = genAllocCopy(rewriter, loc, amemC, tokens); |
| 794 | Value valA = genAllocCopy(rewriter, loc, amemV, tokens); |
| 795 | Value rowB = genAllocCopy(rewriter, loc, bmemR, tokens); |
| 796 | Value colB = genAllocCopy(rewriter, loc, bmemC, tokens); |
| 797 | Value valB = genAllocCopy(rewriter, loc, bmemV, tokens); |
| 798 | genBlockingWait(rewriter, loc, tokens); |
| 799 | tokens.clear(); |
| 800 | |
| 801 | // Create sparse environment and sparse matrix/dense vector handles. |
| 802 | Type indexTp = rewriter.getIndexType(); |
| 803 | Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); |
| 804 | Type descTp = rewriter.getType<gpu::SparseSpGEMMOpHandleType>(); |
| 805 | Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); |
| 806 | Value token = genFirstWait(rewriter, loc); |
| 807 | Operation *spGenA = |
| 808 | genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szm, szk, |
| 809 | nseA, rowA, colA, valA, format, enableRT); |
| 810 | Value spMatA = spGenA->getResult(0); |
| 811 | token = spGenA->getResult(1); |
| 812 | Operation *spGenB = |
| 813 | genSpMat(rewriter, loc, bTp, spmatHandleTp, tokenTp, token, szk, szn, |
| 814 | nseB, rowB, colB, valB, format, enableRT); |
| 815 | Value spMatB = spGenB->getResult(0); |
| 816 | token = spGenB->getResult(1); |
| 817 | |
| 818 | // Sparse matrix C materializes (also assumes beta == 0). |
| 819 | Value zero = constantIndex(rewriter, loc, 0); |
| 820 | Value one = constantIndex(rewriter, loc, 1); |
| 821 | Value mplus1 = rewriter.create<arith::AddIOp>(loc, szm, one); |
| 822 | auto e1 = genAllocBuffer(rewriter, loc, cTp.getPosType(), mplus1, token); |
| 823 | Value rowC = e1.getResult(0); |
| 824 | token = e1.getAsyncToken(); |
| 825 | auto e2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), zero, token); |
| 826 | Value colC = e2.getResult(0); // no free needed |
| 827 | token = e2.getAsyncToken(); |
| 828 | auto e3 = genAllocBuffer(rewriter, loc, dnCType, zero, token); |
| 829 | Value valC = e3.getResult(0); // no free needed |
| 830 | token = e3.getAsyncToken(); |
| 831 | Operation *spGenC = |
| 832 | genSpMat(rewriter, loc, cTp, spmatHandleTp, tokenTp, token, szm, szn, |
| 833 | zero, rowC, colC, valC, format, enableRT); |
| 834 | Value spMatC = spGenC->getResult(0); |
| 835 | token = spGenC->getResult(1); |
| 836 | |
| 837 | // Precompute buffersizes for SpGEMM. |
| 838 | Operation *descOp = |
| 839 | rewriter.create<gpu::SpGEMMCreateDescrOp>(loc, descTp, tokenTp, token); |
| 840 | Value desc = descOp->getResult(0); |
| 841 | token = descOp->getResult(1); |
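| | // Each SpGEMM stage (work estimation, compute) is issued twice: first to query the required buffer size, then again with the allocated buffer. |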
| 842 | Operation *work1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>( |
| 843 | loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, |
| 844 | gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero, |
| 845 | valC, gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION); |
| 846 | Value bufferSz1 = work1->getResult(0); |
| 847 | token = work1->getResult(1); |
| 848 | auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token); |
| 849 | Value buffer1 = buf1.getResult(0); |
| 850 | token = buf1.getAsyncToken(); |
| 851 | Operation *work2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>( |
| 852 | loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, |
| 853 | gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, |
| 854 | bufferSz1, buffer1, |
| 855 | gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION); |
| 856 | token = work2->getResult(1); |
| 857 | |
| 858 | // Compute step. |
| 859 | Operation *compute1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>( |
| 860 | loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, |
| 861 | gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero, |
| 862 | valC, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE); |
| 863 | Value bufferSz2 = compute1->getResult(0); |
| 864 | token = compute1->getResult(1); |
| 865 | auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token); |
| 866 | Value buffer2 = buf2.getResult(0); |
| 867 | token = buf2.getAsyncToken(); |
| 868 | Operation *compute2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>( |
| 869 | loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, |
| 870 | gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, |
| 871 | bufferSz2, buffer2, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE); |
| 872 | token = compute2->getResult(1); |
| 873 | |
| 874 | // Get sizes. |
| 875 | Operation *sizes = rewriter.create<gpu::SpMatGetSizeOp>( |
| 876 | loc, indexTp, indexTp, indexTp, tokenTp, token, spMatC); |
| 877 | Value nnz = sizes->getResult(2); |
| 878 | token = sizes->getResult(3); |
| 879 | auto a2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), nnz, token); |
| 880 | colC = a2.getResult(0); |
| 881 | token = a2.getAsyncToken(); |
| 882 | auto a3 = genAllocBuffer(rewriter, loc, dnCType, nnz, token); |
| 883 | valC = a3.getResult(0); |
| 884 | token = a3.getAsyncToken(); |
| 885 | |
| 886 | // Update C with new pointers and copy final product back into C. |
| 887 | Operation *update = rewriter.create<gpu::SetCsrPointersOp>( |
| 888 | loc, tokenTp, token, spMatC, rowC, colC, valC); |
| 889 | token = update->getResult(0); |
| 890 | Operation *copy = rewriter.create<gpu::SpGEMMCopyOp>( |
| 891 | loc, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, |
| 892 | gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType); |
| 893 | token = copy->getResult(0); |
| 894 | |
| 895 | // Allocate buffers on host. |
| 896 | Value rowH = genHostBuffer(rewriter, loc, cTp.getPosType(), mplus1); |
| 897 | Value colH = genHostBuffer(rewriter, loc, cTp.getCrdType(), nnz); |
| 898 | Value valH = genHostBuffer(rewriter, loc, dnCType, nnz); |
| 899 | |
| 900 | // Copy data back to host and free all the resources. |
| 901 | token = rewriter.create<gpu::SpGEMMDestroyDescrOp>(loc, tokenTp, token, desc) |
| 902 | .getAsyncToken(); |
| 903 | token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA) |
| 904 | .getAsyncToken(); |
| 905 | token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatB) |
| 906 | .getAsyncToken(); |
| 907 | token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC) |
| 908 | .getAsyncToken(); |
| 909 | token = genCopyMemRef(rewriter, loc, rowH, rowC, token); |
| 910 | token = genCopyMemRef(rewriter, loc, colH, colC, token); |
| 911 | token = genCopyMemRef(rewriter, loc, valH, valC, token); |
| 912 | token = genDeallocMemRef(rewriter, loc, rowA, token); |
| 913 | token = genDeallocMemRef(rewriter, loc, colA, token); |
| 914 | token = genDeallocMemRef(rewriter, loc, valA, token); |
| 915 | token = genDeallocMemRef(rewriter, loc, rowB, token); |
| 916 | token = genDeallocMemRef(rewriter, loc, colB, token); |
| 917 | token = genDeallocMemRef(rewriter, loc, valB, token); |
| 918 | token = genDeallocMemRef(rewriter, loc, rowC, token); |
| 919 | token = genDeallocMemRef(rewriter, loc, colC, token); |
| 920 | token = genDeallocMemRef(rewriter, loc, valC, token); |
| 921 | token = genDeallocMemRef(rewriter, loc, buffer1, token); |
| 922 | token = genDeallocMemRef(rewriter, loc, buffer2, token); |
| 923 | tokens.push_back(token); |
| 924 | genBlockingWait(rewriter, loc, tokens); |
| 925 | tokens.clear(); |
| 926 | |
| 927 | // Done. |
| 928 | Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH); |
| 929 | Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH); |
| 930 | Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH); |
| 931 | rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), ValueRange{rt, ct}, |
| 932 | vt); |
| 933 | return success(); |
| 934 | } |
| 935 | |
| 936 | // Match and rewrite 2:4 SpMM kernel. |
| 937 | static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter, |
| 938 | linalg::GenericOp op) { |
| 939 | Location loc = op.getLoc(); |
| 940 | Value A = op.getOperand(0); |
| 941 | Value B = op.getOperand(1); |
| 942 | Value C = op.getOperand(2); // we have C = AB |
| 943 | SmallVector<Value> tokens; |
| 944 | |
| 945 | // The cuSparselt API currently only allows pruning and compression |
| 946 | // to occur on the device. So we recognize the pattern |
| 947 | // A' = convert A ; dense to 2:4 |
| 948 | // C = A'B ; 2:4 matrix mult |
| 949 | // and then perform compression and matrix multiplication on device. |
| 950 | auto cnv = A.getDefiningOp<ConvertOp>(); |
| 951 | assert(cnv); |
| 952 | A = cnv.getSource(); |
| 953 | |
| 954 | // All input should be dense tensors. |
| 955 | if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C)) |
| 956 | return failure(); |
| 957 | |
| 958 | // Start sparse kernel and copy data from host to device. |
| 959 | // a : bufA -> matA |
| 960 | // b : bufB -> matB |
| 961 | // c : bufC -> matC |
| 962 | Value bufA = genTensorToMemref(rewriter, loc, A); |
| 963 | Value matA = genAllocCopy(rewriter, loc, bufA, tokens); |
| 964 | Value bufB = genTensorToMemref(rewriter, loc, B); |
| 965 | Value matB = genAllocCopy(rewriter, loc, bufB, tokens); |
| 966 | Value bufC = genTensorToMemref(rewriter, loc, C); |
| 967 | Value matC = genAllocCopy(rewriter, loc, bufC, tokens); |
| 968 | genBlockingWait(rewriter, loc, tokens); |
| 969 | tokens.clear(); |
| 970 | |
| 971 | // Create sparse environment and sparse matrix/dense vector handles. |
| 972 | Value szm = linalg::createOrFoldDimOp(rewriter, loc, matA, 0); |
| 973 | Value szk = linalg::createOrFoldDimOp(rewriter, loc, matB, 0); |
| 974 | Value szn = linalg::createOrFoldDimOp(rewriter, loc, matC, 1); |
| 975 | Type indexTp = rewriter.getIndexType(); |
| 976 | Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>(); |
| 977 | Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); |
| 978 | Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); |
| 979 | Value token = genFirstWait(rewriter, loc); |
| 980 | Operation *spGenA = rewriter.create<gpu::Create2To4SpMatOp>( |
| 981 | loc, spMatHandleTp, tokenTp, token, szm, szk, |
| 982 | gpu::Prune2To4SpMatFlag::PRUNE_AND_CHECK, matA); |
| 983 | Value spMatA = spGenA->getResult(0); |
| 984 | token = spGenA->getResult(1); |
| 985 | auto dmatB = rewriter.create<gpu::CreateDnTensorOp>( |
| 986 | loc, dnTensorHandleTp, tokenTp, token, matB, |
| 987 | SmallVector<Value>{szk, szn}); |
| 988 | Value dnB = dmatB.getResult(0); |
| 989 | token = dmatB.getAsyncToken(); |
| 990 | auto dmatC = rewriter.create<gpu::CreateDnTensorOp>( |
| 991 | loc, dnTensorHandleTp, tokenTp, token, matC, |
| 992 | SmallVector<Value>{szm, szn}); |
| 993 | Value dnC = dmatC.getResult(0); |
| 994 | token = dmatC.getAsyncToken(); |
| 995 | auto dmatCType = llvm::cast<ShapedType>(matC.getType()).getElementType(); |
| 996 | |
| 997 | // Precompute buffersize for SpMM. |
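| | // The 2:4 SpMM buffer-size query returns three workspace sizes, so three buffers are allocated below. |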
| 998 | SmallVector<Type> bufferTypes_{indexTp, indexTp, indexTp}; |
| 999 | TypeRange bufferTypes(bufferTypes_); |
| 1000 | auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>( |
| 1001 | loc, bufferTypes, tokenTp, token, gpu::TransposeMode::NON_TRANSPOSE, |
| 1002 | gpu::TransposeMode::NON_TRANSPOSE, spMatA, dnB, dnC, |
| 1003 | /*computeType=*/dmatCType); |
| 1004 | token = bufferComp.getAsyncToken(); |
| 1005 | |
| 1006 | // Allocate buffers on host. |
| 1007 | Value bufferSz1 = bufferComp.getResult(0); |
| 1008 | auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token); |
| 1009 | Value buffer1 = buf1.getResult(0); |
| 1010 | token = buf1.getAsyncToken(); |
| 1011 | Value bufferSz2 = bufferComp.getResult(1); |
| 1012 | auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token); |
| 1013 | Value buffer2 = buf2.getResult(0); |
| 1014 | token = buf2.getAsyncToken(); |
| 1015 | Value bufferSz3 = bufferComp.getResult(2); |
| 1016 | auto buf3 = genAllocBuffer(rewriter, loc, bufferSz3, token); |
| 1017 | Value buffer3 = buf3.getResult(0); |
| 1018 | token = buf3.getAsyncToken(); |
| 1019 | |
| 1020 | // Perform the SpMM. |
| 1021 | auto dnCType = llvm::cast<ShapedType>(matC.getType()).getElementType(); |
| 1022 | auto spmmComp = rewriter.create<gpu::SpMMOp>( |
| 1023 | loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType, |
| 1024 | SmallVector<Value>{buffer1, buffer2, buffer3}); |
| 1025 | token = spmmComp.getAsyncToken(); |
| 1026 | |
| 1027 | // Copy data back to host and free all the resources. |
| 1028 | token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA) |
| 1029 | .getAsyncToken(); |
| 1030 | token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB) |
| 1031 | .getAsyncToken(); |
| 1032 | token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC) |
| 1033 | .getAsyncToken(); |
| 1034 | token = genDeallocMemRef(rewriter, loc, buffer1, token); |
| 1035 | token = genDeallocMemRef(rewriter, loc, buffer2, token); |
| 1036 | token = genDeallocMemRef(rewriter, loc, buffer3, token); |
| 1037 | token = genDeallocMemRef(rewriter, loc, matA, token); |
| 1038 | token = genDeallocMemRef(rewriter, loc, matB, token); |
| 1039 | token = genCopyMemRef(rewriter, loc, bufC, matC, token); |
| 1040 | token = genDeallocMemRef(rewriter, loc, matC, token); |
| 1041 | tokens.push_back(token); |
| 1042 | genBlockingWait(rewriter, loc, tokens); |
| 1043 | tokens.clear(); |
| 1044 | |
| 1045 | // Done. |
| 1046 | rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC); |
| 1047 | return success(); |
| 1048 | } |
| 1049 | |
| 1050 | /// Match and rewrite SDDMM kernel. |
| 1051 | static LogicalResult rewriteSDDMM(PatternRewriter &rewriter, |
| 1052 | linalg::GenericOp op, bool enableRT) { |
| 1053 | Location loc = op.getLoc(); |
| 1054 | Value a = op.getOperand(0); |
| 1055 | Value b = op.getOperand(1); |
| 1056 | Value c = op.getOperand(2); |
| 1057 | SmallVector<Value> tokens; |
| 1058 | |
| 1059 | // Only admissible sparse matrix format (no COO/CSC) and dense matrices. |
| 1060 | SparseTensorType aTp = getSparseTensorType(a); |
| 1061 | SparseTensorType bTp = getSparseTensorType(b); |
| 1062 | SparseTensorType cTp = getSparseTensorType(c); |
| 1063 | auto format = getCuSparseFormat(cTp, bTp, aTp, enableRT, /*isMatVec=*/false); |
| 1064 | if (format == CuSparseFormat::kNone || format == CuSparseFormat::kCOO || |
| 1065 | format == CuSparseFormat::kCSC) |
| 1066 | return failure(); |
| 1067 | |
| 1068 | // The SDDMM updates the sparse output matrix in place. |
| 1069 | // Start sparse kernel and copy data from host to device. |
| 1070 | // a : bufA -> matA |
| 1071 | // b : bufB -> matB |
| 1072 | // c : memR/memC/memV -> rowC,colC,valC |
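| | // Because the result is produced in place, only the values buffer is copied back to the host afterwards. |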
| 1073 | Value nseC = rewriter.create<NumberOfEntriesOp>(loc, c); |
| 1074 | Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0); |
| 1075 | Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1); |
| 1076 | Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1); |
| 1077 | Value bufA = genTensorToMemref(rewriter, loc, a); |
| 1078 | Value matA = genAllocCopy(rewriter, loc, bufA, tokens); |
| 1079 | Value bufB = genTensorToMemref(rewriter, loc, b); |
| 1080 | Value matB = genAllocCopy(rewriter, loc, bufB, tokens); |
| 1081 | Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT); |
| 1082 | Value memC = genSecondCrds(rewriter, loc, c, format, enableRT); // or empty |
| 1083 | Value memV = rewriter.create<ToValuesOp>(loc, c); |
| 1084 | Value rowC = genAllocCopy(rewriter, loc, memR, tokens); |
| 1085 | Value colC = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value(); |
| 1086 | Value valC = genAllocCopy(rewriter, loc, memV, tokens); |
| 1087 | genBlockingWait(rewriter, loc, tokens); |
| 1088 | tokens.clear(); |
| 1089 | |
| 1090 | // Create sparse environment and sparse matrix/dense matrix handles. |
| 1091 | Type indexTp = rewriter.getIndexType(); |
| 1092 | Type dnMatHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>(); |
| 1093 | Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); |
| 1094 | Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); |
| 1095 | Value token = genFirstWait(rewriter, loc); |
| 1096 | auto dmatA = rewriter.create<gpu::CreateDnTensorOp>( |
| 1097 | loc, dnMatHandleTp, tokenTp, token, matA, SmallVector<Value>{szm, szk}); |
| 1098 | Value dnA = dmatA.getResult(0); |
| 1099 | token = dmatA.getAsyncToken(); |
| 1100 | auto dmatB = rewriter.create<gpu::CreateDnTensorOp>( |
| 1101 | loc, dnMatHandleTp, tokenTp, token, matB, SmallVector<Value>{szk, szn}); |
| 1102 | Value dnB = dmatB.getResult(0); |
| 1103 | token = dmatB.getAsyncToken(); |
| 1104 | Operation *spGenC = |
| 1105 | genSpMat(rewriter, loc, cTp, spMatHandleTp, tokenTp, token, szm, szn, |
| 1106 | nseC, rowC, colC, valC, format, enableRT); |
| 1107 | Value spMatC = spGenC->getResult(idx: 0); |
| 1108 | token = spGenC->getResult(idx: 1); |
| 1109 | auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType(); |
| 1110 | |
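  // As with the other cuSPARSE-backed kernels, the computation follows a
  // two-phase protocol: query the required workspace size first, allocate a
  // scratch buffer of that size, and only then launch the actual operation.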
  // Precompute the buffer size for SDDMM.
  auto bufferComp = rewriter.create<gpu::SDDMMBufferSizeOp>(
      loc, indexTp, tokenTp, token, dnA, dnB, spMatC, dnCType);
  Value bufferSz = bufferComp.getResult(0);
  token = bufferComp.getAsyncToken();
  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
  Value buffer = buf.getResult(0);
  token = buf.getAsyncToken();

  // Perform the SDDMM.
  auto sddmmComp = rewriter.create<gpu::SDDMMOp>(loc, tokenTp, token, dnA, dnB,
                                                 spMatC, dnCType, buffer);
  token = sddmmComp.getAsyncToken();

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
              .getAsyncToken();
  token = genDeallocMemRef(rewriter, loc, buffer, token);
  token = genDeallocMemRef(rewriter, loc, matA, token);
  token = genDeallocMemRef(rewriter, loc, matB, token);
  token = genDeallocMemRef(rewriter, loc, rowC, token);
  if (colC)
    token = genDeallocMemRef(rewriter, loc, colC, token);
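  // Only the values of C were modified on the device; the positions and
  // coordinates are unchanged, so just the value buffer is copied back.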
  token = genCopyMemRef(rewriter, loc, memV, valC, token);
  token = genDeallocMemRef(rewriter, loc, valC, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done.
  rewriter.replaceOpWithNewOp<sparse_tensor::LoadOp>(op, c);
  return success();
}

//===----------------------------------------------------------------------===//
// Rewriting rules for direct code generation.
//===----------------------------------------------------------------------===//

/// Proof-of-concept rewriter. This rule generates a GPU implementation
/// for each outermost forall loop generated by the sparsifier.
/// TODO: right now this only works with
/// parallelization-strategy=dense-outer-loop; give it its own flag in the
/// future.
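///
/// For example, a sparsifier-generated loop such as
///   scf.parallel (%i) = (%c0) to (%N) step (%c1) { ... }
/// is outlined into a gpu.func and launched as a GPU kernel, with the
/// iterations distributed cyclically over the launched threads (thread t
/// handling iterations t, t + T, t + 2T, ... for T threads in total).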
struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
  using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;

  ForallRewriter(MLIRContext *context, unsigned nT)
      : OpRewritePattern(context), numThreads(nT) {}

  LogicalResult matchAndRewrite(scf::ParallelOp forallOp,
                                PatternRewriter &rewriter) const override {
    // Reject inadmissible loop forms. Essentially, we only accept a loop,
    // generated by the sparsifier, of the form
    //   forall (i = 0; i < N; i++)
    // so that cyclic scheduling over the threads is easy.
    if (!forallOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ||
        forallOp.getNumReductions() != 0 || forallOp.getNumLoops() != 1 ||
        !matchPattern(forallOp.getLowerBound()[0], m_Zero()) ||
        !matchPattern(forallOp.getStep()[0], m_One()))
      return failure();
    // Collect every value that is computed outside the parallel loop.
    SetVector<Value> invariants; // stable iteration!
    forallOp->walk([&](Operation *op) {
      // Collect all values of admissible ops.
      for (OpOperand &o : op->getOpOperands()) {
        Value val = o.get();
        Block *block;
        if (auto arg = dyn_cast<BlockArgument>(val))
          block = arg.getOwner();
        else
          block = val.getDefiningOp()->getBlock();
        if (!forallOp.getRegion().findAncestorBlockInRegion(*block))
          invariants.insert(val);
      }
    });
    // Outline the outside values as proper parameters. Fail when sharing a
    // value between host and device is not straightforward.
    SmallVector<Value> constants;
    SmallVector<Value> scalars;
    SmallVector<Value> buffers;
    for (Value val : invariants) {
      Type tp = val.getType();
      if (val.getDefiningOp<arith::ConstantOp>())
        constants.push_back(val);
      else if (isa<FloatType>(tp) || tp.isIntOrIndex())
        scalars.push_back(val);
      else if (isa<MemRefType>(tp))
        buffers.push_back(val);
      else
        return failure(); // don't know how to share
    }
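    // Note that constants are rematerialized inside the generated kernel
    // (they are not passed as arguments), whereas scalars and buffers become
    // actual kernel parameters below.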
    // Pass outlined non-constant values.
    // TODO: Experiment with `useHostRegistrationForOut` to see if we want to
    //       keep the feature at all (either through a heuristic or a compiler
    //       option for GPU codegen).
    Location loc = forallOp->getLoc();
    SmallVector<Value> args;
    SmallVector<Value> tokens;
    Value out = genParametersIn(rewriter, loc, scalars, buffers, args, tokens,
                                /*useHostRegistrationForOut=*/false);
    // Set up the GPU module and construct the GPU function.
    auto saveIp = rewriter.saveInsertionPoint();
    ModuleOp topModule = forallOp->getParentOfType<ModuleOp>();
    auto gpuModule = genGPUModule(rewriter, topModule);
    auto gpuFunc = genGPUFunc(rewriter, gpuModule, args);
    genGPUCode(rewriter, gpuFunc, forallOp, constants, scalars, buffers);
    // Generate code that launches the kernel asynchronously, blocking on all
    // open tokens and yielding a new token for the output.
    // TODO: Passing tokens into the launch op does not seem to be properly
    //       lowered by cubin yet, hence the current blocking wait.
    rewriter.restoreInsertionPoint(saveIp);
    genBlockingWait(rewriter, loc, tokens);
    tokens.clear();
    Value kernelToken =
        genLaunchGPUFunc(rewriter, gpuFunc, args, tokens, numThreads);
    // Finalize the outlined arguments.
    genParametersOut(rewriter, loc, out, kernelToken, scalars, buffers, args,
                     tokens);
    genBlockingWait(rewriter, loc, tokens);
    rewriter.eraseOp(forallOp);
    return success();
  }

private:
  unsigned numThreads;
};

//===----------------------------------------------------------------------===//
// Rewriting rules for library recognition and code generation.
//===----------------------------------------------------------------------===//

/// Proof-of-concept rewriter. This rule recognizes certain math kernels
/// and replaces these with corresponding calls into a sparse library.
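/// Currently recognized are SpMV, SpMM (including the SpGEMM and 2:4-SpMM
/// variants), and SDDMM; anything else is rejected by this pattern.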
struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;

  LinalgOpRewriter(MLIRContext *context, bool rt)
      : OpRewritePattern(context), enableRT(rt) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumDpsInits() != 1)
      return failure(); // reject multi-output

    const unsigned numLoops = op.getNumLoops();
    const unsigned numTensors = op->getNumOperands();
    const auto iteratorTypes = op.getIteratorTypesArray();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();

    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [&](MapList m) {
      return AffineMap::inferFromExprList(m, op.getContext());
    };
    AffineExpr i, j, k;
    bindDims(getContext(), i, j, k);
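    // For reference, the access patterns matched below correspond to
    //   SpMV:  y(i)   += A(i,j) * x(j)
    //   SpMM:  C(i,j) += A(i,k) * B(k,j)   (also SpGEMM and 2:4-SpMM)
    //   SDDMM: C(i,j) += A(i,k) * B(k,j), sampled at the stored entries of C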

    // TODO: more robust patterns, transposed versions, more kernels,
    //       identify alpha and beta and pass them to the CUDA calls.

    // Recognize an SpMV kernel.
    if (numLoops == 2 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isReductionIterator(iteratorTypes[1]) &&
        maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
      return rewriteSpMV(rewriter, op, enableRT);
    }

    // Recognize an SpGEMM, 2:4-SpMM, or SpMM kernel.
    if (numLoops == 3 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isParallelIterator(iteratorTypes[1]) &&
        linalg::isReductionIterator(iteratorTypes[2]) &&
        maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
      if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1)))
        return rewriteSpGEMM(rewriter, op, enableRT);
      if (isConversionInto24(op.getOperand(0)))
        return rewrite2To4SpMM(rewriter, op);
      return rewriteSpMM(rewriter, op, enableRT);
    }

    // Recognize an SDDMM kernel.
    if (numLoops == 3 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isParallelIterator(iteratorTypes[1]) &&
        linalg::isReductionIterator(iteratorTypes[2]) &&
        maps == infer({{i, k}, {k, j}, {i, j}}) &&
        matchSumReductionOfMulUnary(op)) {
      return rewriteSDDMM(rewriter, op, enableRT);
    }

    return failure();
  }

private:
  bool enableRT;
};

} // namespace

//===----------------------------------------------------------------------===//
// Public methods for populating GPU rewriting rules.
//
// Currently two sets of rewriting rules are made available. The first set
// implements direct code generation, currently by means of converting the
// outermost parallel loop into GPU threads. The second set implements
// library recognition of a set of sparse operations. Eventually, the right
// combination of these two approaches has to be found.
//===----------------------------------------------------------------------===//
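//
// A minimal usage sketch (illustrative only; the greedy driver call and the
// flag values below are assumptions, not mandated by this file):
//
//   RewritePatternSet patterns(module.getContext());
//   populateSparseGPULibgenPatterns(patterns, /*enableRT=*/true);
//   populateSparseGPUCodegenPatterns(patterns, /*numThreads=*/1024);
//   (void)applyPatternsAndFoldGreedily(module, std::move(patterns));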

void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
                                            unsigned numThreads) {
  patterns.add<ForallRewriter>(patterns.getContext(), numThreads);
}

void mlir::populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
                                           bool enableRT) {
  patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT);
}