//===- SparseGPUCodegen.cpp - Generates GPU code --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is a prototype GPU code generator for the sparsifier.
// The objective is to eventually use the right combination of
// direct code generation and library calls into vendor-specific
// highly optimized sparse libraries (e.g. cuSparse for CUDA).
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

// Sparse formats supported by cuSparse.
enum class CuSparseFormat {
  kNone,
  kCOO,
  kCSR,
  kCSC,
  kBSR,
};

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Marks the given top module as a GPU container module.
static void markAsGPUContainer(ModuleOp topModule) {
  topModule->setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
                     UnitAttr::get(topModule->getContext()));
}

/// Constructs a new GPU module (for GPU kernels) inside the given top module,
/// or returns an existing GPU module if one was built previously.
static gpu::GPUModuleOp genGPUModule(OpBuilder &builder, ModuleOp topModule) {
  for (auto op : topModule.getBodyRegion().getOps<gpu::GPUModuleOp>())
    return op; // existing
  markAsGPUContainer(topModule);
  builder.setInsertionPointToStart(&topModule.getBodyRegion().front());
  return builder.create<gpu::GPUModuleOp>(topModule->getLoc(),
                                          "sparse_kernels");
}

/// Constructs a new GPU kernel in the given GPU module.
static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule,
                                 SmallVectorImpl<Value> &args) {
  // Get a unique kernel name. Not very creative,
  // but we simply try kernel0, kernel1, etc.
  unsigned kernelNumber = 0;
  SmallString<16> kernelName;
  do {
    kernelName.clear();
    ("kernel" + Twine(kernelNumber++)).toStringRef(kernelName);
  } while (gpuModule.lookupSymbol(kernelName));
  // Then we insert a new kernel with given arguments into the module.
  builder.setInsertionPointToStart(&gpuModule.getBodyRegion().front());
  SmallVector<Type> argsTp;
  for (auto arg : args)
    argsTp.push_back(arg.getType());
  FunctionType type = FunctionType::get(gpuModule->getContext(), argsTp, {});
  auto gpuFunc =
      builder.create<gpu::GPUFuncOp>(gpuModule->getLoc(), kernelName, type);
  gpuFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
                   builder.getUnitAttr());
  return gpuFunc;
}

/// Constructs code to launch GPU kernel.
static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc,
                              SmallVectorImpl<Value> &args,
                              SmallVectorImpl<Value> &tokens,
                              unsigned numThreads) {
  Location loc = gpuFunc->getLoc();
  Value none = TypedValue<::mlir::IntegerType>{};
  Value one = constantIndex(builder, loc, 1);
  Value numT = constantIndex(builder, loc, numThreads);
  gpu::KernelDim3 gridSize = {one, one, one};
  gpu::KernelDim3 blckSize = {numT, one, one};
  return builder
      .create<gpu::LaunchFuncOp>(loc, gpuFunc, gridSize, blckSize,
                                 /*dynSharedMemSz*/ none, args,
                                 builder.getType<gpu::AsyncTokenType>(), tokens)
      .getAsyncToken();
}

/// Maps the provided ranked host buffer into the device address space.
/// Writes from the host are guaranteed to be visible to device kernels
/// that are launched afterwards. Writes from the device are guaranteed
/// to be visible on the host after synchronizing with the device kernel
/// completion. Needs to cast the buffer to an unranked buffer.
static Value genHostRegisterMemref(OpBuilder &builder, Location loc,
                                   Value mem) {
  MemRefType memTp = cast<MemRefType>(mem.getType());
  UnrankedMemRefType resTp =
      UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0);
  Value cast = builder.create<memref::CastOp>(loc, resTp, mem);
  builder.create<gpu::HostRegisterOp>(loc, cast);
  return cast;
}

/// Unmaps the provided buffer, expecting the casted buffer.
static void genHostUnregisterMemref(OpBuilder &builder, Location loc,
                                    Value cast) {
  builder.create<gpu::HostUnregisterOp>(loc, cast);
}

/// Generates first wait in an asynchronous chain.
static Value genFirstWait(OpBuilder &builder, Location loc) {
  Type tokenType = builder.getType<gpu::AsyncTokenType>();
  return builder.create<gpu::WaitOp>(loc, tokenType, ValueRange())
      .getAsyncToken();
}

/// Generates last, blocking wait in an asynchronous chain.
static void genBlockingWait(OpBuilder &builder, Location loc,
                            ValueRange operands) {
  builder.create<gpu::WaitOp>(loc, Type(), operands);
}

/// Allocates memory on the device.
/// TODO: A `host_shared` attribute could be used to indicate that
/// the buffer is visible by both host and device, but lowering
/// that feature does not seem to be fully supported yet.
static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem,
                                   Value token) {
  auto tp = cast<ShapedType>(mem.getType());
  auto elemTp = tp.getElementType();
  auto shape = tp.getShape();
  auto memTp = MemRefType::get(shape, elemTp);
  SmallVector<Value> dynamicSizes;
  for (unsigned r = 0, rank = tp.getRank(); r < rank; r++) {
    if (shape[r] == ShapedType::kDynamic) {
      Value dimOp = linalg::createOrFoldDimOp(builder, loc, mem, r);
      dynamicSizes.push_back(dimOp);
    }
  }
  return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
                                      token, dynamicSizes, ValueRange());
}

// Allocates a typed buffer on the host with given size.
static Value genHostBuffer(OpBuilder &builder, Location loc, Type type,
                           Value size) {
  const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
  return builder.create<memref::AllocOp>(loc, memTp, size).getResult();
}

// Allocates a typed buffer on the device with given size.
static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Type type,
                                   Value size, Value token) {
  const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
  return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
                                      token, size, ValueRange());
}

// Allocates a void buffer on the device with given size.
static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size,
                                   Value token) {
  return genAllocBuffer(builder, loc, builder.getI8Type(), size, token);
}

/// Deallocates memory from the device.
static Value genDeallocMemRef(OpBuilder &builder, Location loc, Value mem,
                              Value token) {
  return builder.create<gpu::DeallocOp>(loc, token.getType(), token, mem)
      .getAsyncToken();
}

/// Copies memory between host and device (direction is implicit).
static Value genCopyMemRef(OpBuilder &builder, Location loc, Value dst,
                           Value src, Value token) {
  return builder.create<gpu::MemcpyOp>(loc, token.getType(), token, dst, src)
      .getAsyncToken();
}

/// Generates an alloc/copy pair.
static Value genAllocCopy(OpBuilder &builder, Location loc, Value b,
                          SmallVectorImpl<Value> &tokens) {
  Value firstToken = genFirstWait(builder, loc);
  auto alloc = genAllocMemRef(builder, loc, b, firstToken);
  Value devMem = alloc.getResult(0);
  Value depToken = alloc.getAsyncToken(); // copy-after-alloc
  tokens.push_back(genCopyMemRef(builder, loc, devMem, b, depToken));
  return devMem;
}

/// Generates a memref-from-tensor operation.
static Value genTensorToMemref(PatternRewriter &rewriter, Location loc,
                               Value tensor) {
  auto tensorType = llvm::cast<ShapedType>(tensor.getType());
  auto memrefType =
      MemRefType::get(tensorType.getShape(), tensorType.getElementType());
  return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor);
}

/// Prepares the outlined arguments, passing scalars and buffers in. Here we
/// assume that the first buffer is the one allocated for output. We create
/// a set of properly chained asynchronous allocation/copy pairs to increase
/// overlap before launching the kernel.
static Value genParametersIn(OpBuilder &builder, Location loc,
                             SmallVectorImpl<Value> &scalars,
                             SmallVectorImpl<Value> &buffers,
                             SmallVectorImpl<Value> &args,
                             SmallVectorImpl<Value> &tokens,
                             bool useHostRegistrationForOut) {
  Value out;
  // Scalars are passed by value.
  for (Value s : scalars)
    args.push_back(s);
  // Buffers need to be made visible on the device.
  for (Value b : buffers) {
    if (useHostRegistrationForOut) {
      out = genHostRegisterMemref(builder, loc, b);
      args.push_back(b);
      useHostRegistrationForOut = false;
      continue;
    }
    args.push_back(genAllocCopy(builder, loc, b, tokens));
  }
  return out;
}

/// Finalizes the outlined arguments. The output buffer is copied depending
/// on the kernel token and then deallocated. All other buffers are simply
/// deallocated. Then we wait for all operations to complete.
static void genParametersOut(OpBuilder &builder, Location loc, Value out,
                             Value kernelToken, SmallVectorImpl<Value> &scalars,
                             SmallVectorImpl<Value> &buffers,
                             SmallVectorImpl<Value> &args,
                             SmallVectorImpl<Value> &tokens) {
  unsigned base = scalars.size();
  for (unsigned i = base, e = args.size(); i < e; i++) {
    Value firstToken;
    if (i == base) {
      // Assumed output parameter: unregister or copy-out.
      if (out) {
        genHostUnregisterMemref(builder, loc, out);
        out = Value();
        continue;
      }
      firstToken =
          genCopyMemRef(builder, loc, buffers[0], args[i], kernelToken);
    } else {
      firstToken = genFirstWait(builder, loc);
    }
    tokens.push_back(genDeallocMemRef(builder, loc, args[i], firstToken));
  }
}

/// Constructs code for new GPU kernel.
static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
                       scf::ParallelOp forallOp,
                       SmallVectorImpl<Value> &constants,
                       SmallVectorImpl<Value> &scalars,
                       SmallVectorImpl<Value> &buffers) {
  Location loc = gpuFunc->getLoc();
  Block &block = gpuFunc.getBody().front();
  rewriter.setInsertionPointToStart(&block);

  // Re-generate the constants, recapture all arguments.
  unsigned arg = 0;
  IRMapping irMap;
  for (Value c : constants)
    irMap.map(c, rewriter.clone(*c.getDefiningOp())->getResult(0));
  for (Value s : scalars)
    irMap.map(s, block.getArgument(arg++));
  for (Value b : buffers)
    irMap.map(b, block.getArgument(arg++));

  // Assume 1-dimensional grid/block configuration (only x dimension),
  // so that:
  //   row = blockIdx.x * blockDim.x + threadIdx.x
  //   inc = blockDim.x * gridDim.x
  Value bid = rewriter.create<gpu::BlockIdOp>(loc, gpu::Dimension::x);
  Value bsz = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
  Value tid = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
  Value gsz = rewriter.create<gpu::GridDimOp>(loc, gpu::Dimension::x);
  Value mul = rewriter.create<arith::MulIOp>(loc, bid, bsz);
  Value row = rewriter.create<arith::AddIOp>(loc, mul, tid);
  Value inc = rewriter.create<arith::MulIOp>(loc, bsz, gsz);

  // Construct the iteration over the computational space that
  // accounts for the fact that the total number of threads and
  // the amount of work to be done usually do not match precisely.
  //   for (r = row; r < N; r += inc) {
  //     <loop-body>
  //   }
  Value upper = irMap.lookup(forallOp.getUpperBound()[0]);
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, row, upper, inc);
  // The scf.for builder creates an empty block. scf.for does not allow multiple
  // blocks in its region, so delete the block before `cloneRegionBefore` adds
  // an additional block.
  rewriter.eraseBlock(forOp.getBody());
  rewriter.cloneRegionBefore(forallOp.getRegion(), forOp.getRegion(),
                             forOp.getRegion().begin(), irMap);
  // Replace the scf.reduce terminator.
  rewriter.setInsertionPoint(forOp.getBody()->getTerminator());
  rewriter.replaceOpWithNewOp<scf::YieldOp>(forOp.getBody()->getTerminator());

  // Done.
  rewriter.setInsertionPointAfter(forOp);
  rewriter.create<gpu::ReturnOp>(gpuFunc->getLoc());
}

//===----------------------------------------------------------------------===//
// Library helper methods.
//===----------------------------------------------------------------------===//

/// Helper to detect a + b with arguments taken from given block.
static bool matchAddOfArgs(Block *block, Value val) {
  if (auto *def = val.getDefiningOp()) {
    if (isa<arith::AddFOp, arith::AddIOp>(def)) {
      Value a = block->getArguments()[0];
      Value b = block->getArguments()[1];
      return (def->getOperand(0) == a && def->getOperand(1) == b) ||
             (def->getOperand(0) == b && def->getOperand(1) == a);
    }
  }
  return false;
}

/// Helper to detect a * b with arguments taken from given block.
static bool matchMulOfArgs(Block *block, Value val) {
  if (auto *def = val.getDefiningOp()) {
    if (isa<arith::MulFOp, arith::MulIOp>(def)) {
      Value a = block->getArguments()[0];
      Value b = block->getArguments()[1];
      return (def->getOperand(0) == a && def->getOperand(1) == b) ||
             (def->getOperand(0) == b && def->getOperand(1) == a);
    }
  }
  return false;
}

/// Helper to detect x = x + a * b
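/// For example, for f32 operands the matched linalg.generic body would look
/// roughly as follows (illustrative sketch only; SSA names are made up):
///   ^bb0(%a: f32, %b: f32, %x: f32):
///     %0 = arith.mulf %a, %b : f32
///     %1 = arith.addf %x, %0 : f32
///     linalg.yield %1 : f32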
static bool matchSumOfMultOfArgs(linalg::GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::AddFOp, arith::AddIOp>(def)) {
      Value x = op.getBlock()->getArguments()[2];
      return (def->getOperand(0) == x &&
              matchMulOfArgs(op.getBlock(), def->getOperand(1))) ||
             (def->getOperand(1) == x &&
              matchMulOfArgs(op.getBlock(), def->getOperand(0)));
    }
  }
  return false;
}

// Helper to detect c += spy(s) x (a * b)
static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  // The linalg yields a custom reduce result.
  Value s_out = op.getBlock()->getArguments()[2];
  if (auto redOp =
          yieldOp.getOperand(0).getDefiningOp<sparse_tensor::ReduceOp>()) {
    // The reduce consumes the output.
    Value other;
    if (s_out == redOp->getOperand(0))
      other = redOp->getOperand(1);
    else if (s_out == redOp->getOperand(1))
      other = redOp->getOperand(0);
    else
      return false;
    // The reduce op also consumes a unary which also consumes the output
    // and does not define an absent value.
    if (auto unOp = other.getDefiningOp<sparse_tensor::UnaryOp>()) {
      if (s_out != unOp->getOperand(0) || !unOp.getAbsentRegion().empty())
        return false;
      // And the bodies are as expected.
      auto yieldUn = cast<sparse_tensor::YieldOp>(
          unOp.getRegion(0).front().getTerminator());
      auto yieldRed = cast<sparse_tensor::YieldOp>(
          redOp.getRegion().front().getTerminator());
      return matchMulOfArgs(op.getBlock(), yieldUn.getOperand(0)) &&
             matchAddOfArgs(&redOp.getRegion().front(), yieldRed.getOperand(0));
    }
  }
  return false;
}

/// Test for dense tensor.
static bool isDenseTensor(Value v) {
  auto sTp = getSparseTensorType(v);
  return sTp.getDimRank() == sTp.getLvlRank() && sTp.isAllDense();
}

/// Test for suitable positions/coordinates width.
static bool isAdmissibleMetaData(SparseTensorType &aTp) {
  return (aTp.getPosWidth() == 0 || aTp.getPosWidth() >= 16) &&
         (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() >= 16);
}

/// Test for sorted COO matrix with suitable metadata.
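/// For instance, a sorted COO matrix would typically carry an encoding along
/// these lines (sketch only; the exact attribute syntax may differ by version):
///   #SortedCOO = #sparse_tensor.encoding<{
///     map = (i, j) -> (i : compressed(nonunique), j : singleton)
///   }>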
static bool isAdmissibleCOO(SparseTensorType &aTp) {
  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
         aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
         aTp.isSingletonLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
         isAdmissibleMetaData(aTp);
}

/// Test for CSR matrix with suitable metadata.
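/// A matching CSR type would, for example, use an encoding along these lines
/// (sketch only; the exact attribute syntax may differ by version):
///   #CSR = #sparse_tensor.encoding<{
///     map = (i, j) -> (i : dense, j : compressed)
///   }>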
static bool isAdmissibleCSR(SparseTensorType &aTp) {
  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
         aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) &&
         aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
}

/// Test for CSC matrix with suitable metadata.
static bool isAdmissibleCSC(SparseTensorType &aTp) {
  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && !aTp.isIdentity() &&
         aTp.isPermutation() && aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) &&
         aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
}

/// Test for BSR matrix with suitable metadata.
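/// For example, a BSR matrix with 2x2 blocks would typically be declared with
/// an encoding along these lines (sketch only; syntax may differ by version):
///   #BSR = #sparse_tensor.encoding<{
///     map = (i, j) -> (i floordiv 2 : dense, j floordiv 2 : compressed,
///                      i mod 2 : dense, j mod 2 : dense)
///   }>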
static bool isAdmissibleBSR(SparseTensorType &aTp) {
  if (aTp.getDimRank() == 2 && aTp.getLvlRank() == 4 && aTp.isDenseLvl(0) &&
      aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
      aTp.isDenseLvl(2) && aTp.isDenseLvl(3) && isAdmissibleMetaData(aTp)) {
    // CuSparse only supports "square" blocks currently.
    SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
    assert(dims.size() == 2);
    return dims[0] == dims[1] && dims[0] > 1;
  }
  return false;
}

/// Test for 2:4 matrix with suitable metadata.
static bool isAdmissible24(SparseTensorType &aTp) {
  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 3 && aTp.isDenseLvl(0) &&
         aTp.isDenseLvl(1) && aTp.isNOutOfMLvl(2) && isAdmissibleMetaData(aTp);
}

/// Test for conversion into 2:4 matrix.
static bool isConversionInto24(Value v) {
  if (auto cnv = v.getDefiningOp<ConvertOp>()) {
    Value a = cnv.getResult();
    Value d = cnv.getSource();
    SparseTensorType aTp = getSparseTensorType(a);
    return isDenseTensor(d) && isAdmissible24(aTp);
  }
  return false;
}

/// Returns a suitable sparse format for the operation and given operand
/// types with cuSparse, or kNone if none is available.
static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
                                        SparseTensorType bTp,
                                        SparseTensorType cTp, bool enableRT,
                                        bool isMatVec) {
  // The other operands have a dense type.
  if (bTp.hasEncoding() || cTp.hasEncoding())
    return CuSparseFormat::kNone;
  // Now check for suitable operand type for the main operand.
  if (isAdmissibleCOO(aTp))
#ifdef CUSPARSE_COO_AOS
    return isMatVec ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
#else
    return enableRT ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
#endif
  if (isAdmissibleCSR(aTp))
    return CuSparseFormat::kCSR;
  if (isAdmissibleCSC(aTp))
    return CuSparseFormat::kCSC;
  if (isAdmissibleBSR(aTp))
    return CuSparseFormat::kBSR;
  return CuSparseFormat::kNone;
}

/// Generates the first positions/coordinates of a sparse matrix.
static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
                               CuSparseFormat format, bool enableRT) {
  if (format == CuSparseFormat::kCOO) {
    // Library uses SoA COO, direct IR uses AoS COO.
    if (enableRT)
      return builder.create<ToCoordinatesOp>(loc, a, 0);
    return builder.create<ToCoordinatesBufferOp>(loc, a);
  }
  // Formats CSR/CSC and BSR use positions at 1.
  return builder.create<ToPositionsOp>(loc, a, 1);
}

/// Generates the second coordinates of a sparse matrix.
static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
                           CuSparseFormat format, bool enableRT) {
  bool isCOO = format == CuSparseFormat::kCOO;
  if (isCOO && !enableRT)
    return Value(); // nothing needed
  // Formats CSR/CSC and BSR use coordinates at 1.
  return builder.create<ToCoordinatesOp>(loc, a, 1);
}

/// Generates the sparse matrix handle.
static Operation *genSpMat(OpBuilder &builder, Location loc,
                           SparseTensorType &aTp, Type handleTp, Type tokenTp,
                           Value token, Value sz1, Value sz2, Value nseA,
                           Value rowA, Value colA, Value valA,
                           CuSparseFormat format, bool enableRT) {
  if (format == CuSparseFormat::kCOO) {
    // Library uses SoA COO, direct IR uses AoS COO.
    if (enableRT) {
      assert(colA);
      return builder.create<gpu::CreateCooOp>(loc, handleTp, tokenTp, token,
                                              sz1, sz2, nseA, rowA, colA, valA);
    }
#ifdef CUSPARSE_COO_AOS
    assert(!colA);
    return builder.create<gpu::CreateCooAoSOp>(loc, handleTp, tokenTp, token,
                                               sz1, sz2, nseA, rowA, valA);
#else
    llvm_unreachable("gpu::CreateCooAoSOp is deprecated");
#endif
  }
  assert(colA);
  if (format == CuSparseFormat::kCSR)
    return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
                                            sz2, nseA, rowA, colA, valA);
  if (format == CuSparseFormat::kCSC)
    return builder.create<gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1,
                                            sz2, nseA, rowA, colA, valA);
  // BSR requires a bit more work since we need to pass in the block size
  // and all other sizes in terms of blocks (#block-rows, #block-cols,
  // #nonzero-blocks).
  assert(format == CuSparseFormat::kBSR);
  SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
  assert(dims.size() == 2 && dims[0] == dims[1]);
  uint64_t b = dims[0];
  Value bSz = constantIndex(builder, loc, b);
  Value bRows = builder.create<arith::DivUIOp>(loc, sz1, bSz);
  Value bCols = builder.create<arith::DivUIOp>(loc, sz2, bSz);
  Value bNum = builder.create<arith::DivUIOp>(
      loc, nseA, constantIndex(builder, loc, b * b));
  return builder.create<gpu::CreateBsrOp>(loc, handleTp, tokenTp, token, bRows,
                                          bCols, bNum, bSz, bSz, rowA, colA,
                                          valA);
}

/// Match and rewrite SpMV kernel.
static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
                                 linalg::GenericOp op, bool enableRT) {
  Location loc = op.getLoc();
  Value a = op.getOperand(0);
  Value x = op.getOperand(1);
  Value y = op.getOperand(2); // we have y = Ax
  SmallVector<Value> tokens;

  // Only admissible sparse matrix format and dense vectors (no BSR).
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType xTp = getSparseTensorType(x);
  SparseTensorType yTp = getSparseTensorType(y);
  auto format = getCuSparseFormat(aTp, xTp, yTp, enableRT, /*isMatVec=*/true);
  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
    return failure();

  // Start sparse kernel and copy data from host to device.
  // a : memR/memC/memV -> rowA,colA,valA
  // x : memX -> vecX
  // y : memY -> vecY
  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
  Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
  Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
  Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
  Value memV = rewriter.create<ToValuesOp>(loc, a);
  Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
  Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
  Value valA = genAllocCopy(rewriter, loc, memV, tokens);
  Value memX = genTensorToMemref(rewriter, loc, x);
  Value vecX = genAllocCopy(rewriter, loc, memX, tokens);
  Value memY = genTensorToMemref(rewriter, loc, y);
  Value vecY = genAllocCopy(rewriter, loc, memY, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create sparse environment and sparse matrix/dense vector handles.
  Type indexTp = rewriter.getIndexType();
  Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
  Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  Operation *spGenA =
      genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szY, szX,
               nseA, rowA, colA, valA, format, enableRT);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  auto dvecX = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, vecX, szX);
  Value dnX = dvecX.getResult(0);
  token = dvecX.getAsyncToken();
  auto dvecY = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, vecY, szY);
  Value dnY = dvecY.getResult(0);
  token = dvecY.getAsyncToken();
  auto dnYType = llvm::cast<ShapedType>(y.getType()).getElementType();

  // Precompute buffer size for SpMV.
  auto bufferComp = rewriter.create<gpu::SpMVBufferSizeOp>(
      loc, indexTp, tokenTp, token, spMatA, dnX, dnY,
      /*computeType=*/dnYType);
  Value bufferSz = bufferComp.getResult(0);
  token = bufferComp.getAsyncToken();
  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
  Value buffer = buf.getResult(0);
  token = buf.getAsyncToken();

  // Perform the SpMV.
  auto spmvComp = rewriter.create<gpu::SpMVOp>(
      loc, tokenTp, token, spMatA, dnX, dnY, /*computeType=*/dnYType, buffer);
  token = spmvComp.getAsyncToken();

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnX)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnY)
              .getAsyncToken();
  token = genDeallocMemRef(rewriter, loc, rowA, token);
  if (colA)
    token = genDeallocMemRef(rewriter, loc, colA, token);
  token = genDeallocMemRef(rewriter, loc, valA, token);
  token = genDeallocMemRef(rewriter, loc, buffer, token);
  token = genDeallocMemRef(rewriter, loc, vecX, token);
  token = genCopyMemRef(rewriter, loc, memY, vecY, token);
  token = genDeallocMemRef(rewriter, loc, vecY, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done.
  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, memY);
  return success();
}

/// Match and rewrite SpMM kernel.
static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
                                 linalg::GenericOp op, bool enableRT) {
  Location loc = op.getLoc();
  Value a = op.getOperand(0);
  Value b = op.getOperand(1);
  Value c = op.getOperand(2); // we have C = AB
  SmallVector<Value> tokens;

  // Only admissible sparse matrix format and dense matrices (no BSR).
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType bTp = getSparseTensorType(b);
  SparseTensorType cTp = getSparseTensorType(c);
  auto format = getCuSparseFormat(aTp, bTp, cTp, enableRT, /*isMatVec=*/false);
  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
    return failure();

  // Start sparse kernel and copy data from host to device.
  // a : memR/memC/memV -> rowA,colA,valA
  // b : bufB -> matB
  // c : bufC -> matC
  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
  Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
  Value memV = rewriter.create<ToValuesOp>(loc, a);
  Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
  Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
  Value valA = genAllocCopy(rewriter, loc, memV, tokens);
  Value bufB = genTensorToMemref(rewriter, loc, b);
  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
  Value bufC = genTensorToMemref(rewriter, loc, c);
  Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create sparse environment and sparse matrix/dense matrix handles.
  Type indexTp = rewriter.getIndexType();
  Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  Operation *spGenA =
      genSpMat(rewriter, loc, aTp, spMatHandleTp, tokenTp, token, szm, szk,
               nseA, rowA, colA, valA, format, enableRT);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matB,
      SmallVector<Value>{szk, szn});
  Value dnB = dmatB.getResult(0);
  token = dmatB.getAsyncToken();
  auto dmatC = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matC,
      SmallVector<Value>{szm, szn});
  Value dnC = dmatC.getResult(0);
  token = dmatC.getAsyncToken();
  auto dmatCType = llvm::cast<ShapedType>(c.getType()).getElementType();

  // Precompute buffer size for SpMM.
  auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
      loc, indexTp, tokenTp, token, spMatA, dnB, dnC,
      /*computeType=*/dmatCType);
  Value bufferSz = bufferComp.getResult(0);
  token = bufferComp.getAsyncToken();
  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
  Value buffer = buf.getResult(0);
  token = buf.getAsyncToken();
  auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();

  // Perform the SpMM.
  auto spmmComp = rewriter.create<gpu::SpMMOp>(
      loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType, buffer);
  token = spmmComp.getAsyncToken();

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
              .getAsyncToken();
  token = genDeallocMemRef(rewriter, loc, rowA, token);
  if (colA)
    token = genDeallocMemRef(rewriter, loc, colA, token);
  token = genDeallocMemRef(rewriter, loc, valA, token);
  token = genDeallocMemRef(rewriter, loc, buffer, token);
  token = genDeallocMemRef(rewriter, loc, matB, token);
  token = genCopyMemRef(rewriter, loc, bufC, matC, token);
  token = genDeallocMemRef(rewriter, loc, matC, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done.
  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
  return success();
}

// Match and rewrite SpGEMM kernel.
static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
                                   linalg::GenericOp op, bool enableRT) {
  Location loc = op.getLoc();
  Value a = op.getOperand(0);
  Value b = op.getOperand(1);
  Value c = op.getOperand(2); // we have C = AB
  SmallVector<Value> tokens;

  // Only CSR <- CSR x CSR supported.
  auto format = CuSparseFormat::kCSR;
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType bTp = getSparseTensorType(b);
  SparseTensorType cTp = getSparseTensorType(c);
  if (!isAdmissibleCSR(aTp) || !isAdmissibleCSR(bTp) || !isAdmissibleCSR(cTp))
    return failure();

  // Start sparse kernel and copy data from host to device.
  // a : amemR/amemC/amemV -> rowA,colA,valA
  // b : bmemR/bmemC/bmemV -> rowB,colB,valB
  // c : materializes
  auto dnCType = cTp.getElementType();
  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
  Value nseB = rewriter.create<NumberOfEntriesOp>(loc, b);
  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
  Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
  Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT); // not empty
  Value amemV = rewriter.create<ToValuesOp>(loc, a);
  Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT);
  Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT); // not empty
  Value bmemV = rewriter.create<ToValuesOp>(loc, b);
  Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
  Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
  Value valA = genAllocCopy(rewriter, loc, amemV, tokens);
  Value rowB = genAllocCopy(rewriter, loc, bmemR, tokens);
  Value colB = genAllocCopy(rewriter, loc, bmemC, tokens);
  Value valB = genAllocCopy(rewriter, loc, bmemV, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create sparse environment and sparse matrix/dense vector handles.
  Type indexTp = rewriter.getIndexType();
  Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type descTp = rewriter.getType<gpu::SparseSpGEMMOpHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  Operation *spGenA =
      genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szm, szk,
               nseA, rowA, colA, valA, format, enableRT);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  Operation *spGenB =
      genSpMat(rewriter, loc, bTp, spmatHandleTp, tokenTp, token, szk, szn,
               nseB, rowB, colB, valB, format, enableRT);
  Value spMatB = spGenB->getResult(0);
  token = spGenB->getResult(1);

  // Sparse matrix C materializes (also assumes beta == 0).
  Value zero = constantIndex(rewriter, loc, 0);
  Value one = constantIndex(rewriter, loc, 1);
  Value mplus1 = rewriter.create<arith::AddIOp>(loc, szm, one);
  auto e1 = genAllocBuffer(rewriter, loc, cTp.getPosType(), mplus1, token);
  Value rowC = e1.getResult(0);
  token = e1.getAsyncToken();
  auto e2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), zero, token);
  Value colC = e2.getResult(0); // no free needed
  token = e2.getAsyncToken();
  auto e3 = genAllocBuffer(rewriter, loc, dnCType, zero, token);
  Value valC = e3.getResult(0); // no free needed
  token = e3.getAsyncToken();
  Operation *spGenC =
      genSpMat(rewriter, loc, cTp, spmatHandleTp, tokenTp, token, szm, szn,
               zero, rowC, colC, valC, format, enableRT);
  Value spMatC = spGenC->getResult(0);
  token = spGenC->getResult(1);

  // Precompute buffer sizes for SpGEMM.
  Operation *descOp =
      rewriter.create<gpu::SpGEMMCreateDescrOp>(loc, descTp, tokenTp, token);
  Value desc = descOp->getResult(0);
  token = descOp->getResult(1);
  Operation *work1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
      valC, gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
  Value bufferSz1 = work1->getResult(0);
  token = work1->getResult(1);
  auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
  Value buffer1 = buf1.getResult(0);
  token = buf1.getAsyncToken();
  Operation *work2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
      bufferSz1, buffer1,
      gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
  token = work2->getResult(1);

  // Compute step.
  Operation *compute1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
      valC, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
  Value bufferSz2 = compute1->getResult(0);
  token = compute1->getResult(1);
  auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
  Value buffer2 = buf2.getResult(0);
  token = buf2.getAsyncToken();
  Operation *compute2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
      bufferSz2, buffer2, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
  token = compute2->getResult(1);

  // Get sizes.
  Operation *sizes = rewriter.create<gpu::SpMatGetSizeOp>(
      loc, indexTp, indexTp, indexTp, tokenTp, token, spMatC);
  Value nnz = sizes->getResult(2);
  token = sizes->getResult(3);
  auto a2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), nnz, token);
  colC = a2.getResult(0);
  token = a2.getAsyncToken();
  auto a3 = genAllocBuffer(rewriter, loc, dnCType, nnz, token);
  valC = a3.getResult(0);
  token = a3.getAsyncToken();

  // Update C with new pointers and copy final product back into C.
  Operation *update = rewriter.create<gpu::SetCsrPointersOp>(
      loc, tokenTp, token, spMatC, rowC, colC, valC);
  token = update->getResult(0);
  Operation *copy = rewriter.create<gpu::SpGEMMCopyOp>(
      loc, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType);
  token = copy->getResult(0);

  // Allocate buffers on host.
  Value rowH = genHostBuffer(rewriter, loc, cTp.getPosType(), mplus1);
  Value colH = genHostBuffer(rewriter, loc, cTp.getCrdType(), nnz);
  Value valH = genHostBuffer(rewriter, loc, dnCType, nnz);

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::SpGEMMDestroyDescrOp>(loc, tokenTp, token, desc)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
              .getAsyncToken();
  token = genCopyMemRef(rewriter, loc, rowH, rowC, token);
  token = genCopyMemRef(rewriter, loc, colH, colC, token);
  token = genCopyMemRef(rewriter, loc, valH, valC, token);
  token = genDeallocMemRef(rewriter, loc, rowA, token);
  token = genDeallocMemRef(rewriter, loc, colA, token);
  token = genDeallocMemRef(rewriter, loc, valA, token);
  token = genDeallocMemRef(rewriter, loc, rowB, token);
  token = genDeallocMemRef(rewriter, loc, colB, token);
  token = genDeallocMemRef(rewriter, loc, valB, token);
  token = genDeallocMemRef(rewriter, loc, rowC, token);
  token = genDeallocMemRef(rewriter, loc, colC, token);
  token = genDeallocMemRef(rewriter, loc, valC, token);
  token = genDeallocMemRef(rewriter, loc, buffer1, token);
  token = genDeallocMemRef(rewriter, loc, buffer2, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done.
  Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH);
  Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH);
  Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH);
  rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), ValueRange{rt, ct},
                                          vt);
  return success();
}

// Match and rewrite 2:4 SpMM kernel.
static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter,
                                     linalg::GenericOp op) {
  Location loc = op.getLoc();
  Value A = op.getOperand(0);
  Value B = op.getOperand(1);
  Value C = op.getOperand(2); // we have C = AB
  SmallVector<Value> tokens;

  // The cuSPARSELt API currently only allows pruning and compression
  // to occur on the device. So we recognize the pattern
  //   A' = convert A  ; dense to 2:4
  //   C  = A'B        ; 2:4 matrix mult
  // and then perform compression and matrix multiplication on device.
  auto cnv = A.getDefiningOp<ConvertOp>();
  assert(cnv);
  A = cnv.getSource();

  // All inputs should be dense tensors.
  if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C))
    return failure();

  // Start sparse kernel and copy data from host to device.
  // a : bufA -> matA
  // b : bufB -> matB
  // c : bufC -> matC
  Value bufA = genTensorToMemref(rewriter, loc, A);
  Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
  Value bufB = genTensorToMemref(rewriter, loc, B);
  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
  Value bufC = genTensorToMemref(rewriter, loc, C);
  Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create sparse environment and sparse matrix/dense vector handles.
  Value szm = linalg::createOrFoldDimOp(rewriter, loc, matA, 0);
  Value szk = linalg::createOrFoldDimOp(rewriter, loc, matB, 0);
  Value szn = linalg::createOrFoldDimOp(rewriter, loc, matC, 1);
  Type indexTp = rewriter.getIndexType();
  Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  Operation *spGenA = rewriter.create<gpu::Create2To4SpMatOp>(
      loc, spMatHandleTp, tokenTp, token, szm, szk,
      gpu::Prune2To4SpMatFlag::PRUNE_AND_CHECK, matA);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matB,
      SmallVector<Value>{szk, szn});
  Value dnB = dmatB.getResult(0);
  token = dmatB.getAsyncToken();
  auto dmatC = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matC,
      SmallVector<Value>{szm, szn});
  Value dnC = dmatC.getResult(0);
  token = dmatC.getAsyncToken();
  auto dmatCType = llvm::cast<ShapedType>(matC.getType()).getElementType();

  // Precompute buffer sizes for SpMM.
  SmallVector<Type> bufferTypes_{indexTp, indexTp, indexTp};
  TypeRange bufferTypes(bufferTypes_);
  auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
      loc, bufferTypes, tokenTp, token, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, dnB, dnC,
      /*computeType=*/dmatCType);
  token = bufferComp.getAsyncToken();

  // Allocate buffers on host.
  Value bufferSz1 = bufferComp.getResult(0);
  auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
  Value buffer1 = buf1.getResult(0);
  token = buf1.getAsyncToken();
  Value bufferSz2 = bufferComp.getResult(1);
  auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
  Value buffer2 = buf2.getResult(0);
  token = buf2.getAsyncToken();
  Value bufferSz3 = bufferComp.getResult(2);
  auto buf3 = genAllocBuffer(rewriter, loc, bufferSz3, token);
  Value buffer3 = buf3.getResult(0);
  token = buf3.getAsyncToken();

  // Perform the SpMM.
  auto dnCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
  auto spmmComp = rewriter.create<gpu::SpMMOp>(
      loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType,
      SmallVector<Value>{buffer1, buffer2, buffer3});
  token = spmmComp.getAsyncToken();

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
              .getAsyncToken();
  SmallVector<Value> newDynamicSizes;
  token = genDeallocMemRef(rewriter, loc, buffer1, token);
  token = genDeallocMemRef(rewriter, loc, buffer2, token);
  token = genDeallocMemRef(rewriter, loc, buffer3, token);
  token = genDeallocMemRef(rewriter, loc, matA, token);
  token = genDeallocMemRef(rewriter, loc, matB, token);
  token = genCopyMemRef(rewriter, loc, bufC, matC, token);
  token = genDeallocMemRef(rewriter, loc, matC, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done.
  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
  return success();
}

/// Match and rewrite SDDMM kernel.
static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
                                  linalg::GenericOp op, bool enableRT) {
  Location loc = op.getLoc();
  Value a = op.getOperand(0);
  Value b = op.getOperand(1);
  Value c = op.getOperand(2);
  SmallVector<Value> tokens;

  // Only admissible sparse matrix format (no COO/CSC) and dense matrices.
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType bTp = getSparseTensorType(b);
  SparseTensorType cTp = getSparseTensorType(c);
  auto format = getCuSparseFormat(cTp, bTp, aTp, enableRT, /*isMatVec=*/false);
  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kCOO ||
      format == CuSparseFormat::kCSC)
    return failure();

  // The SDDMM does the operation in place.
  // Start sparse kernel and copy data from host to device.
  // a : bufA -> matA
  // b : bufB -> matB
  // c : memR/memC/memV -> rowC,colC,valC
  Value nseC = rewriter.create<NumberOfEntriesOp>(loc, c);
  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
  Value bufA = genTensorToMemref(rewriter, loc, a);
  Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
  Value bufB = genTensorToMemref(rewriter, loc, b);
  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
  Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT);
  Value memC = genSecondCrds(rewriter, loc, c, format, enableRT); // or empty
  Value memV = rewriter.create<ToValuesOp>(loc, c);
  Value rowC = genAllocCopy(rewriter, loc, memR, tokens);
  Value colC = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
  Value valC = genAllocCopy(rewriter, loc, memV, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create sparse environment and sparse matrix/dense matrix handles.
  Type indexTp = rewriter.getIndexType();
  Type dnMatHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  auto dmatA = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnMatHandleTp, tokenTp, token, matA, SmallVector<Value>{szm, szk});
  Value dnA = dmatA.getResult(0);
  token = dmatA.getAsyncToken();
  auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnMatHandleTp, tokenTp, token, matB, SmallVector<Value>{szk, szn});
  Value dnB = dmatB.getResult(0);
  token = dmatB.getAsyncToken();
  Operation *spGenC =
      genSpMat(rewriter, loc, cTp, spMatHandleTp, tokenTp, token, szm, szn,
               nseC, rowC, colC, valC, format, enableRT);
  Value spMatC = spGenC->getResult(0);
  token = spGenC->getResult(1);
  auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();

  // Precompute buffer size for SDDMM.
  auto bufferComp = rewriter.create<gpu::SDDMMBufferSizeOp>(
      loc, indexTp, tokenTp, token, dnA, dnB, spMatC, dnCType);
  Value bufferSz = bufferComp.getResult(0);
  token = bufferComp.getAsyncToken();
  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
  Value buffer = buf.getResult(0);
  token = buf.getAsyncToken();

  // Perform the SDDMM.
  auto sddmmComp = rewriter.create<gpu::SDDMMOp>(loc, tokenTp, token, dnA, dnB,
                                                 spMatC, dnCType, buffer);
  token = sddmmComp.getAsyncToken();

  // Copy data back to host and free all the resources.
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
              .getAsyncToken();
  token = genDeallocMemRef(rewriter, loc, buffer, token);
  token = genDeallocMemRef(rewriter, loc, matA, token);
  token = genDeallocMemRef(rewriter, loc, matB, token);
  token = genDeallocMemRef(rewriter, loc, rowC, token);
  if (colC)
    token = genDeallocMemRef(rewriter, loc, colC, token);
  token = genCopyMemRef(rewriter, loc, memV, valC, token);
  token = genDeallocMemRef(rewriter, loc, valC, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done.
  rewriter.replaceOpWithNewOp<sparse_tensor::LoadOp>(op, c);
  return success();
}

//===----------------------------------------------------------------------===//
// Rewriting rules for direct code generation.
//===----------------------------------------------------------------------===//

/// Proof-of-concept rewriter. This rule generates a GPU implementation
/// for each outermost forall loop generated by the sparsifier.
/// TODO: right now works with parallelization-strategy=dense-outer-loop
/// but give this its own flags in the future
struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
  using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;

  ForallRewriter(MLIRContext *context, unsigned nT)
      : OpRewritePattern(context), numThreads(nT){};

  LogicalResult matchAndRewrite(scf::ParallelOp forallOp,
                                PatternRewriter &rewriter) const override {
    // Reject inadmissible loop form.
    // Essentially only accept a loop, generated by the sparsifier,
    // of the form
    //   forall (i = 0; i < N; i++)
    // so that cyclic scheduling over the threads is easy.
    if (!forallOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ||
        forallOp.getNumReductions() != 0 || forallOp.getNumLoops() != 1 ||
        !matchPattern(forallOp.getLowerBound()[0], m_Zero()) ||
        !matchPattern(forallOp.getStep()[0], m_One()))
      return failure();
    // Collect every value that is computed outside the parallel loop.
    SetVector<Value> invariants; // stable iteration!
    forallOp->walk([&](Operation *op) {
      // Collect all values of admissible ops.
      for (OpOperand &o : op->getOpOperands()) {
        Value val = o.get();
        Block *block;
        if (auto arg = dyn_cast<BlockArgument>(val))
          block = arg.getOwner();
        else
          block = val.getDefiningOp()->getBlock();
        if (!forallOp.getRegion().findAncestorBlockInRegion(*block))
          invariants.insert(val);
      }
    });
    // Outline the outside values as proper parameters. Fail when sharing
    // a value between host and device is not straightforward.
    SmallVector<Value> constants;
    SmallVector<Value> scalars;
    SmallVector<Value> buffers;
    for (Value val : invariants) {
      Type tp = val.getType();
      if (val.getDefiningOp<arith::ConstantOp>())
        constants.push_back(val);
      else if (isa<FloatType>(tp) || tp.isIntOrIndex())
        scalars.push_back(val);
      else if (isa<MemRefType>(tp))
        buffers.push_back(val);
      else
        return failure(); // don't know how to share
    }
    // Pass outlined non-constant values.
    // TODO: Experiment with `useHostRegistrationForOut` to see if we want to
    // keep the feature at all (either through a heuristic or compiler
    // option for gpu codegen).
    Location loc = forallOp->getLoc();
    SmallVector<Value> args;
    SmallVector<Value> tokens;
    Value out = genParametersIn(rewriter, loc, scalars, buffers, args, tokens,
                                /*useHostRegistrationForOut=*/false);
    // Set up GPU module and construct GPU function.
    auto saveIp = rewriter.saveInsertionPoint();
    ModuleOp topModule = forallOp->getParentOfType<ModuleOp>();
    auto gpuModule = genGPUModule(rewriter, topModule);
    auto gpuFunc = genGPUFunc(rewriter, gpuModule, args);
    genGPUCode(rewriter, gpuFunc, forallOp, constants, scalars, buffers);
    // Generate code that launches the kernel asynchronously, blocking on all
    // open tokens and yielding a new token for the output.
    // TODO: Passing in tokens to the launch op does not seem to be properly
    // lowered by cubin yet, hence the current blocking wait.
    rewriter.restoreInsertionPoint(saveIp);
    genBlockingWait(rewriter, loc, tokens);
    tokens.clear();
    Value kernelToken =
        genLaunchGPUFunc(rewriter, gpuFunc, args, tokens, numThreads);
    // Finalize the outlined arguments.
    genParametersOut(rewriter, loc, out, kernelToken, scalars, buffers, args,
                     tokens);
    genBlockingWait(rewriter, loc, tokens);
    rewriter.eraseOp(forallOp);
    return success();
  }

private:
  unsigned numThreads;
};

//===----------------------------------------------------------------------===//
// Rewriting rules for library recognition and code generation.
//===----------------------------------------------------------------------===//

/// Proof-of-concept rewriter. This rule recognizes certain math kernels
/// and replaces these with corresponding calls into a sparse library.
struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;

  LinalgOpRewriter(MLIRContext *context, bool rt)
      : OpRewritePattern(context), enableRT(rt) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumDpsInits() != 1)
      return failure(); // reject multi-output

    const unsigned numLoops = op.getNumLoops();
    const unsigned numTensors = op->getNumOperands();
    const auto iteratorTypes = op.getIteratorTypesArray();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();

    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [&](MapList m) {
      return AffineMap::inferFromExprList(m, op.getContext());
    };
    AffineExpr i, j, k;
    bindDims(getContext(), i, j, k);

    // TODO: more robust patterns, transposed versions, more kernels,
    // identify alpha and beta and pass them to the CUDA calls.

    // Recognize a SpMV kernel.
    if (numLoops == 2 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isReductionIterator(iteratorTypes[1]) &&
        maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
      return rewriteSpMV(rewriter, op, enableRT);
    }

    // Recognize a SpGEMM, 2:4-SpMM, or SpMM kernel.
    if (numLoops == 3 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isParallelIterator(iteratorTypes[1]) &&
        linalg::isReductionIterator(iteratorTypes[2]) &&
        maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
      if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1)))
        return rewriteSpGEMM(rewriter, op, enableRT);
      if (isConversionInto24(op.getOperand(0)))
        return rewrite2To4SpMM(rewriter, op);
      return rewriteSpMM(rewriter, op, enableRT);
    }

    // Recognize a SDDMM kernel.
    if (numLoops == 3 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isParallelIterator(iteratorTypes[1]) &&
        linalg::isReductionIterator(iteratorTypes[2]) &&
        maps == infer({{i, k}, {k, j}, {i, j}}) &&
        matchSumReductionOfMulUnary(op)) {
      return rewriteSDDMM(rewriter, op, enableRT);
    }

    return failure();
  }

private:
  bool enableRT;
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating GPU rewriting rules.
//
// Currently two sets of rewriting rules are made available. The first set
// implements direct code generation, currently by means of converting the
// outermost parallel loop into GPU threads. The second set implements
// library recognition of a set of sparse operations. Eventually, the right
// combination of these two approaches has to be found.
//===----------------------------------------------------------------------===//
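
// As a rough usage sketch (illustrative only; the actual pass wiring lives in
// the SparseTensor pass definitions and may differ), a client pass could
// populate and apply these patterns along the following lines:
//
//   RewritePatternSet patterns(ctx);
//   populateSparseGPULibgenPatterns(patterns, /*enableRT=*/true);
//   populateSparseGPUCodegenPatterns(patterns, /*numThreads=*/1024);
//   (void)applyPatternsAndFoldGreedily(module, std::move(patterns));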

void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
                                            unsigned numThreads) {
  patterns.add<ForallRewriter>(patterns.getContext(), numThreads);
}

void mlir::populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
                                           bool enableRT) {
  patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT);
}
