1 | //===- SparseGPUCodegen.cpp - Generates GPU code --------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This is a prototype GPU code generator for the sparsifier.
// The objective is to eventually use the right combination of
// direct code generation and library calls into vendor-specific
// highly optimized sparse libraries (e.g. cuSparse for CUDA).
13 | // |
14 | //===----------------------------------------------------------------------===// |
15 | |
16 | #include "Utils/CodegenUtils.h" |
17 | #include "Utils/LoopEmitter.h" |
18 | |
19 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
20 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
21 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
22 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
23 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
24 | #include "mlir/Dialect/SCF/IR/SCF.h" |
25 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
26 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" |
27 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
28 | #include "mlir/IR/IRMapping.h" |
29 | #include "mlir/IR/Matchers.h" |
30 | |
31 | using namespace mlir; |
32 | using namespace mlir::sparse_tensor; |
33 | |
34 | namespace { |
35 | |
36 | // Sparse formats supported by cuSparse. |
37 | enum class CuSparseFormat { |
38 | kNone, |
39 | kCOO, |
40 | kCSR, |
41 | kCSC, |
42 | kBSR, |
43 | }; |
44 | |
45 | //===----------------------------------------------------------------------===// |
46 | // Helper methods. |
47 | //===----------------------------------------------------------------------===// |
48 | |
49 | /// Marks the given top module as a GPU container module. |
50 | static void markAsGPUContainer(ModuleOp topModule) { |
51 | topModule->setAttr(gpu::GPUDialect::getContainerModuleAttrName(), |
52 | UnitAttr::get(topModule->getContext())); |
53 | } |
54 | |
55 | /// Constructs a new GPU module (for GPU kernels) inside the given top module, |
56 | /// or returns an existing GPU module if one was built previously. |
57 | static gpu::GPUModuleOp genGPUModule(OpBuilder &builder, ModuleOp topModule) { |
58 | for (auto op : topModule.getBodyRegion().getOps<gpu::GPUModuleOp>()) |
59 | return op; // existing |
60 | markAsGPUContainer(topModule); |
61 | builder.setInsertionPointToStart(&topModule.getBodyRegion().front()); |
  return builder.create<gpu::GPUModuleOp>(topModule->getLoc(),
                                          "sparse_kernels");
64 | } |
65 | |
66 | /// Constructs a new GPU kernel in the given GPU module. |
67 | static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule, |
68 | SmallVectorImpl<Value> &args) { |
69 | // Get a unique kernel name. Not very creative, |
70 | // but we simply try kernel0, kernel1, etc. |
71 | unsigned kernelNumber = 0; |
72 | SmallString<16> kernelName; |
73 | do { |
74 | kernelName.clear(); |
75 | ("kernel" + Twine(kernelNumber++)).toStringRef(Out&: kernelName); |
76 | } while (gpuModule.lookupSymbol(kernelName)); |
77 | // Then we insert a new kernel with given arguments into the module. |
78 | builder.setInsertionPointToStart(&gpuModule.getBodyRegion().front()); |
79 | SmallVector<Type> argsTp; |
80 | for (auto arg : args) |
    argsTp.push_back(arg.getType());
82 | FunctionType type = FunctionType::get(gpuModule->getContext(), argsTp, {}); |
83 | auto gpuFunc = |
84 | builder.create<gpu::GPUFuncOp>(gpuModule->getLoc(), kernelName, type); |
85 | gpuFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(), |
86 | builder.getUnitAttr()); |
87 | return gpuFunc; |
88 | } |
89 | |
90 | /// Constructs code to launch GPU kernel. |
91 | static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc, |
92 | SmallVectorImpl<Value> &args, |
93 | SmallVectorImpl<Value> &tokens, |
94 | unsigned numThreads) { |
95 | Location loc = gpuFunc->getLoc(); |
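  // A null value stands in for the optional dynamic shared memory size below.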
96 | Value none = TypedValue<::mlir::IntegerType>{}; |
  Value one = constantIndex(builder, loc, 1);
  Value numT = constantIndex(builder, loc, numThreads);
  gpu::KernelDim3 gridSize = {one, one, one};
  gpu::KernelDim3 blckSize = {numT, one, one};
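  // Launch a one-dimensional grid with a single block of `numThreads` threads.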
101 | return builder |
102 | .create<gpu::LaunchFuncOp>(loc, gpuFunc, gridSize, blckSize, |
103 | /*dynSharedMemSz*/ none, args, |
104 | builder.getType<gpu::AsyncTokenType>(), tokens) |
105 | .getAsyncToken(); |
106 | } |
107 | |
108 | /// Maps the provided ranked host buffer into the device address space. |
109 | /// Writes from the host are guaranteed to be visible to device kernels |
110 | /// that are launched afterwards. Writes from the device are guaranteed |
111 | /// to be visible on the host after synchronizing with the device kernel |
/// completion. Needs to cast the buffer to an unranked buffer.
113 | static Value genHostRegisterMemref(OpBuilder &builder, Location loc, |
114 | Value mem) { |
115 | MemRefType memTp = cast<MemRefType>(mem.getType()); |
116 | UnrankedMemRefType resTp = |
117 | UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0); |
118 | Value cast = builder.create<memref::CastOp>(loc, resTp, mem); |
119 | builder.create<gpu::HostRegisterOp>(loc, cast); |
120 | return cast; |
121 | } |
122 | |
123 | /// Unmaps the provided buffer, expecting the casted buffer. |
124 | static void genHostUnregisterMemref(OpBuilder &builder, Location loc, |
125 | Value cast) { |
126 | builder.create<gpu::HostUnregisterOp>(loc, cast); |
127 | } |
128 | |
129 | /// Generates first wait in an asynchronous chain. |
130 | static Value genFirstWait(OpBuilder &builder, Location loc) { |
131 | Type tokenType = builder.getType<gpu::AsyncTokenType>(); |
132 | return builder.create<gpu::WaitOp>(loc, tokenType, ValueRange()) |
133 | .getAsyncToken(); |
134 | } |
135 | |
136 | /// Generates last, blocking wait in an asynchronous chain. |
137 | static void genBlockingWait(OpBuilder &builder, Location loc, |
138 | ValueRange operands) { |
139 | builder.create<gpu::WaitOp>(loc, Type(), operands); |
140 | } |
141 | |
142 | /// Allocates memory on the device. |
143 | /// TODO: A `host_shared` attribute could be used to indicate that |
144 | /// the buffer is visible by both host and device, but lowering |
145 | /// that feature does not seem to be fully supported yet. |
146 | static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem, |
147 | Value token) { |
148 | auto tp = cast<ShapedType>(mem.getType()); |
149 | auto elemTp = tp.getElementType(); |
150 | auto shape = tp.getShape(); |
151 | auto memTp = MemRefType::get(shape, elemTp); |
152 | SmallVector<Value> dynamicSizes; |
153 | for (unsigned r = 0, rank = tp.getRank(); r < rank; r++) { |
154 | if (shape[r] == ShapedType::kDynamic) { |
155 | Value dimOp = linalg::createOrFoldDimOp(b&: builder, loc, val: mem, dim: r); |
156 | dynamicSizes.push_back(Elt: dimOp); |
157 | } |
158 | } |
159 | return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}), |
160 | token, dynamicSizes, ValueRange()); |
161 | } |
162 | |
163 | // Allocates a typed buffer on the host with given size. |
164 | static Value genHostBuffer(OpBuilder &builder, Location loc, Type type, |
165 | Value size) { |
166 | const auto memTp = MemRefType::get({ShapedType::kDynamic}, type); |
167 | return builder.create<memref::AllocOp>(loc, memTp, size).getResult(); |
168 | } |
169 | |
170 | // Allocates a typed buffer on the device with given size. |
171 | static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Type type, |
172 | Value size, Value token) { |
173 | const auto memTp = MemRefType::get({ShapedType::kDynamic}, type); |
174 | return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}), |
175 | token, size, ValueRange()); |
176 | } |
177 | |
178 | // Allocates a void buffer on the device with given size. |
179 | static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size, |
180 | Value token) { |
181 | return genAllocBuffer(builder, loc, builder.getI8Type(), size, token); |
182 | } |
183 | |
184 | /// Deallocates memory from the device. |
185 | static Value genDeallocMemRef(OpBuilder &builder, Location loc, Value mem, |
186 | Value token) { |
187 | return builder.create<gpu::DeallocOp>(loc, token.getType(), token, mem) |
188 | .getAsyncToken(); |
189 | } |
190 | |
191 | /// Copies memory between host and device (direction is implicit). |
192 | static Value genCopyMemRef(OpBuilder &builder, Location loc, Value dst, |
193 | Value src, Value token) { |
194 | return builder.create<gpu::MemcpyOp>(loc, token.getType(), token, dst, src) |
195 | .getAsyncToken(); |
196 | } |
197 | |
198 | /// Generates an alloc/copy pair. |
199 | static Value genAllocCopy(OpBuilder &builder, Location loc, Value b, |
200 | SmallVectorImpl<Value> &tokens) { |
201 | Value firstToken = genFirstWait(builder, loc); |
202 | auto alloc = genAllocMemRef(builder, loc, b, firstToken); |
203 | Value devMem = alloc.getResult(0); |
204 | Value depToken = alloc.getAsyncToken(); // copy-after-alloc |
  tokens.push_back(genCopyMemRef(builder, loc, devMem, b, depToken));
206 | return devMem; |
207 | } |
208 | |
209 | /// Generates a memref from tensor operation. |
210 | static Value genTensorToMemref(PatternRewriter &rewriter, Location loc, |
211 | Value tensor) { |
212 | auto tensorType = llvm::cast<ShapedType>(tensor.getType()); |
213 | auto memrefType = |
214 | MemRefType::get(tensorType.getShape(), tensorType.getElementType()); |
215 | return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor); |
216 | } |
217 | |
218 | /// Prepares the outlined arguments, passing scalars and buffers in. Here we |
219 | /// assume that the first buffer is the one allocated for output. We create |
220 | /// a set of properly chained asynchronous allocation/copy pairs to increase |
221 | /// overlap before launching the kernel. |
222 | static Value genParametersIn(OpBuilder &builder, Location loc, |
223 | SmallVectorImpl<Value> &scalars, |
224 | SmallVectorImpl<Value> &buffers, |
225 | SmallVectorImpl<Value> &args, |
226 | SmallVectorImpl<Value> &tokens, |
227 | bool useHostRegistrationForOut) { |
228 | Value out; |
229 | // Scalars are passed by value. |
230 | for (Value s : scalars) |
    args.push_back(s);
  // Buffers need to be made visible on the device.
  for (Value b : buffers) {
    if (useHostRegistrationForOut) {
      out = genHostRegisterMemref(builder, loc, b);
      args.push_back(b);
      useHostRegistrationForOut = false;
      continue;
    }
    args.push_back(genAllocCopy(builder, loc, b, tokens));
241 | } |
242 | return out; |
243 | } |
244 | |
245 | /// Finalizes the outlined arguments. The output buffer is copied depending |
246 | /// on the kernel token and then deallocated. All other buffers are simply |
247 | /// deallocated. Then we wait for all operations to complete. |
248 | static void genParametersOut(OpBuilder &builder, Location loc, Value out, |
249 | Value kernelToken, SmallVectorImpl<Value> &scalars, |
250 | SmallVectorImpl<Value> &buffers, |
251 | SmallVectorImpl<Value> &args, |
252 | SmallVectorImpl<Value> &tokens) { |
253 | unsigned base = scalars.size(); |
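  // Buffer arguments start right after the scalar arguments.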
254 | for (unsigned i = base, e = args.size(); i < e; i++) { |
255 | Value firstToken; |
256 | if (i == base) { |
257 | // Assumed output parameter: unregister or copy-out. |
258 | if (out) { |
        genHostUnregisterMemref(builder, loc, out);
260 | out = Value(); |
261 | continue; |
262 | } |
      firstToken =
          genCopyMemRef(builder, loc, buffers[0], args[i], kernelToken);
265 | } else { |
266 | firstToken = genFirstWait(builder, loc); |
267 | } |
    tokens.push_back(genDeallocMemRef(builder, loc, args[i], firstToken));
269 | } |
270 | } |
271 | |
272 | /// Constructs code for new GPU kernel. |
273 | static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc, |
274 | scf::ParallelOp forallOp, |
275 | SmallVectorImpl<Value> &constants, |
276 | SmallVectorImpl<Value> &scalars, |
277 | SmallVectorImpl<Value> &buffers) { |
278 | Location loc = gpuFunc->getLoc(); |
279 | Block &block = gpuFunc.getBody().front(); |
280 | rewriter.setInsertionPointToStart(&block); |
281 | |
282 | // Re-generate the constants, recapture all arguments. |
283 | unsigned arg = 0; |
284 | IRMapping irMap; |
  for (Value c : constants)
    irMap.map(c, rewriter.clone(*c.getDefiningOp())->getResult(0));
  for (Value s : scalars)
    irMap.map(s, block.getArgument(arg++));
  for (Value b : buffers)
    irMap.map(b, block.getArgument(arg++));
291 | |
292 | // Assume 1-dimensional grid/block configuration (only x dimension), |
293 | // so that: |
294 | // row = blockIdx.x * blockDim.x + threadIdx.x |
295 | // inc = blockDim.x * gridDim.x |
296 | Value bid = rewriter.create<gpu::BlockIdOp>(loc, gpu::Dimension::x); |
297 | Value bsz = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x); |
298 | Value tid = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x); |
299 | Value gsz = rewriter.create<gpu::GridDimOp>(loc, gpu::Dimension::x); |
300 | Value mul = rewriter.create<arith::MulIOp>(loc, bid, bsz); |
301 | Value row = rewriter.create<arith::AddIOp>(loc, mul, tid); |
302 | Value inc = rewriter.create<arith::MulIOp>(loc, bsz, gsz); |
303 | |
304 | // Construct the iteration over the computational space that |
305 | // accounts for the fact that the total number of threads and |
306 | // the amount of work to be done usually do not match precisely. |
307 | // for (r = row; r < N; r += inc) { |
308 | // <loop-body> |
309 | // } |
310 | Value upper = irMap.lookup(forallOp.getUpperBound()[0]); |
311 | scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, row, upper, inc); |
312 | // The scf.for builder creates an empty block. scf.for does not allow multiple |
313 | // blocks in its region, so delete the block before `cloneRegionBefore` adds |
314 | // an additional block. |
  rewriter.eraseBlock(forOp.getBody());
316 | rewriter.cloneRegionBefore(forallOp.getRegion(), forOp.getRegion(), |
317 | forOp.getRegion().begin(), irMap); |
318 | // Replace the scf.reduce terminator. |
319 | rewriter.setInsertionPoint(forOp.getBody()->getTerminator()); |
320 | rewriter.replaceOpWithNewOp<scf::YieldOp>(forOp.getBody()->getTerminator()); |
321 | |
322 | // Done. |
323 | rewriter.setInsertionPointAfter(forOp); |
324 | rewriter.create<gpu::ReturnOp>(gpuFunc->getLoc()); |
325 | } |
326 | |
327 | //===----------------------------------------------------------------------===// |
328 | // Library helper methods. |
329 | //===----------------------------------------------------------------------===// |
330 | |
331 | /// Helper to detect a + b with arguments taken from given block. |
332 | static bool matchAddOfArgs(Block *block, Value val) { |
333 | if (auto *def = val.getDefiningOp()) { |
334 | if (isa<arith::AddFOp, arith::AddIOp>(def)) { |
335 | Value a = block->getArguments()[0]; |
336 | Value b = block->getArguments()[1]; |
      return (def->getOperand(0) == a && def->getOperand(1) == b) ||
             (def->getOperand(0) == b && def->getOperand(1) == a);
339 | } |
340 | } |
341 | return false; |
342 | } |
343 | |
344 | /// Helper to detect a * b with arguments taken from given block. |
345 | static bool matchMulOfArgs(Block *block, Value val) { |
346 | if (auto *def = val.getDefiningOp()) { |
347 | if (isa<arith::MulFOp, arith::MulIOp>(def)) { |
348 | Value a = block->getArguments()[0]; |
349 | Value b = block->getArguments()[1]; |
      return (def->getOperand(0) == a && def->getOperand(1) == b) ||
             (def->getOperand(0) == b && def->getOperand(1) == a);
352 | } |
353 | } |
354 | return false; |
355 | } |
356 | |
357 | /// Helper to detect x = x + a * b |
358 | static bool matchSumOfMultOfArgs(linalg::GenericOp op) { |
359 | auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator()); |
360 | if (auto *def = yieldOp.getOperand(0).getDefiningOp()) { |
361 | if (isa<arith::AddFOp, arith::AddIOp>(def)) { |
362 | Value x = op.getBlock()->getArguments()[2]; |
363 | return (def->getOperand(0) == x && |
364 | matchMulOfArgs(op.getBlock(), def->getOperand(1))) || |
365 | (def->getOperand(1) == x && |
366 | matchMulOfArgs(op.getBlock(), def->getOperand(0))); |
367 | } |
368 | } |
369 | return false; |
370 | } |
371 | |
372 | // Helper to detect c += spy(s) x (a * b) |
373 | static bool matchSumReductionOfMulUnary(linalg::GenericOp op) { |
374 | auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator()); |
375 | // The linalg yields a custom reduce result. |
376 | Value s_out = op.getBlock()->getArguments()[2]; |
377 | if (auto redOp = |
378 | yieldOp.getOperand(0).getDefiningOp<sparse_tensor::ReduceOp>()) { |
379 | // The reduce consumes the output. |
380 | Value other; |
381 | if (s_out == redOp->getOperand(0)) |
382 | other = redOp->getOperand(1); |
383 | else if (s_out == redOp->getOperand(1)) |
384 | other = redOp->getOperand(0); |
385 | else |
386 | return false; |
    // The reduce op also consumes a unary op that likewise consumes the
    // output and does not define an absent value.
389 | if (auto unOp = other.getDefiningOp<sparse_tensor::UnaryOp>()) { |
390 | if (s_out != unOp->getOperand(0) || !unOp.getAbsentRegion().empty()) |
391 | return false; |
392 | // And the bodies are as expected. |
393 | auto yieldUn = cast<sparse_tensor::YieldOp>( |
394 | unOp.getRegion(0).front().getTerminator()); |
395 | auto yieldRed = cast<sparse_tensor::YieldOp>( |
396 | redOp.getRegion().front().getTerminator()); |
397 | return matchMulOfArgs(op.getBlock(), yieldUn.getOperand(0)) && |
398 | matchAddOfArgs(&redOp.getRegion().front(), yieldRed.getOperand(0)); |
399 | } |
400 | } |
401 | return false; |
402 | } |
403 | |
404 | /// Test for dense tensor. |
405 | static bool isDenseTensor(Value v) { |
  auto sTp = getSparseTensorType(v);
407 | return sTp.getDimRank() == sTp.getLvlRank() && sTp.isAllDense(); |
408 | } |
409 | |
410 | /// Test for suitable positions/coordinates width. |
411 | static bool isAdmissibleMetaData(SparseTensorType &aTp) { |
412 | return (aTp.getPosWidth() == 0 || aTp.getPosWidth() >= 16) && |
413 | (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() >= 16); |
414 | } |
415 | |
416 | /// Test for sorted COO matrix with suitable metadata. |
417 | static bool isAdmissibleCOO(SparseTensorType &aTp) { |
418 | return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() && |
         aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
         aTp.isSingletonLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
421 | isAdmissibleMetaData(aTp); |
422 | } |
423 | |
424 | /// Test for CSR matrix with suitable metadata. |
425 | static bool isAdmissibleCSR(SparseTensorType &aTp) { |
426 | return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() && |
         aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) &&
         aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
429 | } |
430 | |
431 | /// Test for CSC matrix with suitable metadata. |
432 | static bool isAdmissibleCSC(SparseTensorType &aTp) { |
433 | return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && !aTp.isIdentity() && |
         aTp.isPermutation() && aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) &&
         aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
436 | } |
437 | |
438 | /// Test for BSR matrix with suitable metadata. |
439 | static bool isAdmissibleBSR(SparseTensorType &aTp) { |
  if (aTp.getDimRank() == 2 && aTp.getLvlRank() == 4 && aTp.isDenseLvl(0) &&
      aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
      aTp.isDenseLvl(2) && aTp.isDenseLvl(3) && isAdmissibleMetaData(aTp)) {
    // CuSparse only supports "square" blocks currently.
    SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
445 | assert(dims.size() == 2); |
446 | return dims[0] == dims[1] && dims[0] > 1; |
447 | } |
448 | return false; |
449 | } |
450 | |
451 | /// Test for 2:4 matrix with suitable metadata. |
452 | static bool isAdmissible24(SparseTensorType &aTp) { |
  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 3 && aTp.isDenseLvl(0) &&
         aTp.isDenseLvl(1) && aTp.isNOutOfMLvl(2) && isAdmissibleMetaData(aTp);
455 | } |
456 | |
457 | /// Test for conversion into 2:4 matrix. |
458 | static bool isConversionInto24(Value v) { |
459 | if (auto cnv = v.getDefiningOp<ConvertOp>()) { |
460 | Value a = cnv.getResult(); |
461 | Value d = cnv.getSource(); |
    SparseTensorType aTp = getSparseTensorType(a);
    return isDenseTensor(d) && isAdmissible24(aTp);
464 | } |
465 | return false; |
466 | } |
467 | |
468 | /// Returns a suitable sparse format for the operation and given operand |
469 | /// types with cuSparse, or kNone if none is available. |
470 | static CuSparseFormat getCuSparseFormat(SparseTensorType aTp, |
471 | SparseTensorType bTp, |
472 | SparseTensorType cTp, bool enableRT, |
473 | bool isMatVec) { |
474 | // The other operands have a dense type. |
475 | if (bTp.hasEncoding() || cTp.hasEncoding()) |
476 | return CuSparseFormat::kNone; |
477 | // Now check for suitable operand type for the main operand. |
478 | if (isAdmissibleCOO(aTp)) |
479 | #ifdef CUSPARSE_COO_AOS |
480 | return isMatVec ? CuSparseFormat::kCOO : CuSparseFormat::kNone; |
481 | #else |
482 | return enableRT ? CuSparseFormat::kCOO : CuSparseFormat::kNone; |
483 | #endif |
484 | if (isAdmissibleCSR(aTp)) |
485 | return CuSparseFormat::kCSR; |
486 | if (isAdmissibleCSC(aTp)) |
487 | return CuSparseFormat::kCSC; |
488 | if (isAdmissibleBSR(aTp)) |
489 | return CuSparseFormat::kBSR; |
490 | return CuSparseFormat::kNone; |
491 | } |
492 | |
493 | /// Generates the first positions/coordinates of a sparse matrix. |
494 | static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a, |
495 | CuSparseFormat format, bool enableRT) { |
496 | if (format == CuSparseFormat::kCOO) { |
497 | // Library uses SoA COO, direct IR uses AoS COO. |
498 | if (enableRT) |
499 | return builder.create<ToCoordinatesOp>(loc, a, 0); |
500 | return builder.create<ToCoordinatesBufferOp>(loc, a); |
501 | } |
502 | // Formats CSR/CSC and BSR use positions at 1. |
503 | return builder.create<ToPositionsOp>(loc, a, 1); |
504 | } |
505 | |
506 | /// Generates the second coordinates of a sparse matrix. |
507 | static Value genSecondCrds(OpBuilder &builder, Location loc, Value a, |
508 | CuSparseFormat format, bool enableRT) { |
509 | bool isCOO = format == CuSparseFormat::kCOO; |
510 | if (isCOO && !enableRT) |
511 | return Value(); // nothing needed |
512 | // Formats CSR/CSC and BSR use coordinates at 1. |
513 | return builder.create<ToCoordinatesOp>(loc, a, 1); |
514 | } |
515 | |
516 | /// Generates the sparse matrix handle. |
517 | static Operation *genSpMat(OpBuilder &builder, Location loc, |
518 | SparseTensorType &aTp, Type handleTp, Type tokenTp, |
519 | Value token, Value sz1, Value sz2, Value nseA, |
520 | Value rowA, Value colA, Value valA, |
521 | CuSparseFormat format, bool enableRT) { |
522 | if (format == CuSparseFormat::kCOO) { |
523 | // Library uses SoA COO, direct IR uses AoS COO. |
524 | if (enableRT) { |
525 | assert(colA); |
526 | return builder.create<gpu::CreateCooOp>(loc, handleTp, tokenTp, token, |
527 | sz1, sz2, nseA, rowA, colA, valA); |
528 | } |
529 | #ifdef CUSPARSE_COO_AOS |
530 | assert(!colA); |
531 | return builder.create<gpu::CreateCooAoSOp>(loc, handleTp, tokenTp, token, |
532 | sz1, sz2, nseA, rowA, valA); |
533 | #else |
534 | llvm_unreachable("gpu::CreateCooAoSOp is deprecated" ); |
535 | #endif |
536 | } |
537 | assert(colA); |
538 | if (format == CuSparseFormat::kCSR) |
539 | return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1, |
540 | sz2, nseA, rowA, colA, valA); |
541 | if (format == CuSparseFormat::kCSC) |
542 | return builder.create<gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1, |
543 | sz2, nseA, rowA, colA, valA); |
  // BSR requires a bit more work since we need to pass in the block size
  // and all other sizes in terms of blocks (#block-rows, #block-cols,
  // #nonzero-blocks).
547 | assert(format == CuSparseFormat::kBSR); |
  SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
549 | assert(dims.size() == 2 && dims[0] == dims[1]); |
550 | uint64_t b = dims[0]; |
  Value bSz = constantIndex(builder, loc, b);
552 | Value bRows = builder.create<arith::DivUIOp>(loc, sz1, bSz); |
553 | Value bCols = builder.create<arith::DivUIOp>(loc, sz2, bSz); |
554 | Value bNum = builder.create<arith::DivUIOp>( |
555 | loc, nseA, constantIndex(builder, loc, b * b)); |
556 | return builder.create<gpu::CreateBsrOp>(loc, handleTp, tokenTp, token, bRows, |
557 | bCols, bNum, bSz, bSz, rowA, colA, |
558 | valA); |
559 | } |
560 | |
561 | /// Match and rewrite SpMV kernel. |
562 | static LogicalResult rewriteSpMV(PatternRewriter &rewriter, |
563 | linalg::GenericOp op, bool enableRT) { |
564 | Location loc = op.getLoc(); |
565 | Value a = op.getOperand(0); |
566 | Value x = op.getOperand(1); |
567 | Value y = op.getOperand(2); // we have y = Ax |
568 | SmallVector<Value> tokens; |
569 | |
570 | // Only admissible sparse matrix format and dense vectors (no BSR). |
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType xTp = getSparseTensorType(x);
  SparseTensorType yTp = getSparseTensorType(y);
  auto format = getCuSparseFormat(aTp, xTp, yTp, enableRT, /*isMatVec=*/true);
575 | if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR) |
576 | return failure(); |
577 | |
578 | // Start sparse kernel and copy data from host to device. |
579 | // a : memR/memC/memV -> rowA,colA,valA |
580 | // x : memX -> vecX |
581 | // y : memY -> vecY |
582 | Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a); |
583 | Value szY = linalg::createOrFoldDimOp(b&: rewriter, loc, val: a, dim: 0); |
584 | Value szX = linalg::createOrFoldDimOp(b&: rewriter, loc, val: a, dim: 1); |
585 | Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT); |
586 | Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty |
587 | Value memV = rewriter.create<ToValuesOp>(loc, a); |
588 | Value rowA = genAllocCopy(builder&: rewriter, loc, b: memR, tokens); |
589 | Value colA = memC ? genAllocCopy(builder&: rewriter, loc, b: memC, tokens) : Value(); |
590 | Value valA = genAllocCopy(builder&: rewriter, loc, b: memV, tokens); |
591 | Value memX = genTensorToMemref(rewriter, loc, tensor: x); |
592 | Value vecX = genAllocCopy(builder&: rewriter, loc, b: memX, tokens); |
593 | Value memY = genTensorToMemref(rewriter, loc, tensor: y); |
594 | Value vecY = genAllocCopy(builder&: rewriter, loc, b: memY, tokens); |
595 | genBlockingWait(builder&: rewriter, loc, operands: tokens); |
596 | tokens.clear(); |
597 | |
598 | // Create sparse environment and sparse matrix/dense vector handles. |
599 | Type indexTp = rewriter.getIndexType(); |
600 | Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>(); |
601 | Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); |
602 | Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); |
603 | Value token = genFirstWait(builder&: rewriter, loc); |
604 | Operation *spGenA = |
605 | genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szY, szX, |
606 | nseA, rowA, colA, valA, format, enableRT); |
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
609 | auto dvecX = rewriter.create<gpu::CreateDnTensorOp>( |
610 | loc, dnTensorHandleTp, tokenTp, token, vecX, szX); |
611 | Value dnX = dvecX.getResult(0); |
612 | token = dvecX.getAsyncToken(); |
613 | auto dvecY = rewriter.create<gpu::CreateDnTensorOp>( |
614 | loc, dnTensorHandleTp, tokenTp, token, vecY, szY); |
615 | Value dnY = dvecY.getResult(0); |
616 | token = dvecY.getAsyncToken(); |
617 | auto dnYType = llvm::cast<ShapedType>(y.getType()).getElementType(); |
618 | |
619 | // Precompute buffersize for SpMV. |
620 | auto bufferComp = rewriter.create<gpu::SpMVBufferSizeOp>( |
621 | loc, indexTp, tokenTp, token, spMatA, dnX, dnY, |
622 | /*computeType=*/dnYType); |
623 | Value bufferSz = bufferComp.getResult(0); |
624 | token = bufferComp.getAsyncToken(); |
625 | auto buf = genAllocBuffer(rewriter, loc, bufferSz, token); |
626 | Value buffer = buf.getResult(0); |
627 | token = buf.getAsyncToken(); |
628 | |
629 | // Perform the SpMV. |
630 | auto spmvComp = rewriter.create<gpu::SpMVOp>( |
631 | loc, tokenTp, token, spMatA, dnX, dnY, /*computeType=*/dnYType, buffer); |
632 | token = spmvComp.getAsyncToken(); |
633 | |
  // Copy data back to host and free all the resources.
635 | token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA) |
636 | .getAsyncToken(); |
637 | token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnX) |
638 | .getAsyncToken(); |
639 | token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnY) |
640 | .getAsyncToken(); |
641 | token = genDeallocMemRef(builder&: rewriter, loc, mem: rowA, token); |
642 | if (colA) |
643 | token = genDeallocMemRef(builder&: rewriter, loc, mem: colA, token); |
644 | token = genDeallocMemRef(builder&: rewriter, loc, mem: valA, token); |
645 | token = genDeallocMemRef(builder&: rewriter, loc, mem: buffer, token); |
646 | token = genDeallocMemRef(builder&: rewriter, loc, mem: vecX, token); |
647 | token = genCopyMemRef(builder&: rewriter, loc, dst: memY, src: vecY, token); |
648 | token = genDeallocMemRef(builder&: rewriter, loc, mem: vecY, token); |
649 | tokens.push_back(Elt: token); |
650 | genBlockingWait(builder&: rewriter, loc, operands: tokens); |
651 | tokens.clear(); |
652 | |
653 | // Done. |
654 | rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, memY); |
655 | return success(); |
656 | } |
657 | |
658 | /// Match and rewrite SpMM kernel. |
659 | static LogicalResult rewriteSpMM(PatternRewriter &rewriter, |
660 | linalg::GenericOp op, bool enableRT) { |
661 | Location loc = op.getLoc(); |
662 | Value a = op.getOperand(0); |
663 | Value b = op.getOperand(1); |
664 | Value c = op.getOperand(2); // we have C = AB |
665 | SmallVector<Value> tokens; |
666 | |
667 | // Only admissible sparse matrix format and dense matrices (no BSR). |
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType bTp = getSparseTensorType(b);
  SparseTensorType cTp = getSparseTensorType(c);
671 | auto format = getCuSparseFormat(aTp, bTp, cTp, enableRT, /*isMatVec=*/false); |
672 | if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR) |
673 | return failure(); |
674 | |
675 | // Start sparse kernel and copy data from host to device. |
676 | // a : memR/memC/memV -> rowA,colA,valA |
677 | // b : bufB -> matB |
678 | // c : bufC -> matC |
679 | Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a); |
680 | Value szm = linalg::createOrFoldDimOp(b&: rewriter, loc, val: a, dim: 0); |
681 | Value szk = linalg::createOrFoldDimOp(b&: rewriter, loc, val: a, dim: 1); |
682 | Value szn = linalg::createOrFoldDimOp(b&: rewriter, loc, val: b, dim: 1); |
683 | Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT); |
684 | Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty |
685 | Value memV = rewriter.create<ToValuesOp>(loc, a); |
686 | Value rowA = genAllocCopy(builder&: rewriter, loc, b: memR, tokens); |
687 | Value colA = memC ? genAllocCopy(builder&: rewriter, loc, b: memC, tokens) : Value(); |
688 | Value valA = genAllocCopy(builder&: rewriter, loc, b: memV, tokens); |
689 | Value bufB = genTensorToMemref(rewriter, loc, tensor: b); |
690 | Value matB = genAllocCopy(builder&: rewriter, loc, b: bufB, tokens); |
691 | Value bufC = genTensorToMemref(rewriter, loc, tensor: c); |
692 | Value matC = genAllocCopy(builder&: rewriter, loc, b: bufC, tokens); |
693 | genBlockingWait(builder&: rewriter, loc, operands: tokens); |
694 | tokens.clear(); |
695 | |
696 | // Create sparse environment and sparse matrix/dense matrix handles. |
697 | Type indexTp = rewriter.getIndexType(); |
698 | Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>(); |
699 | Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); |
700 | Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); |
701 | Value token = genFirstWait(builder&: rewriter, loc); |
702 | Operation *spGenA = |
703 | genSpMat(rewriter, loc, aTp, spMatHandleTp, tokenTp, token, szm, szk, |
704 | nseA, rowA, colA, valA, format, enableRT); |
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
707 | auto dmatB = rewriter.create<gpu::CreateDnTensorOp>( |
708 | loc, dnTensorHandleTp, tokenTp, token, matB, |
709 | SmallVector<Value>{szk, szn}); |
710 | Value dnB = dmatB.getResult(0); |
711 | token = dmatB.getAsyncToken(); |
712 | auto dmatC = rewriter.create<gpu::CreateDnTensorOp>( |
713 | loc, dnTensorHandleTp, tokenTp, token, matC, |
714 | SmallVector<Value>{szm, szn}); |
715 | Value dnC = dmatC.getResult(0); |
716 | token = dmatC.getAsyncToken(); |
717 | auto dmatCType = llvm::cast<ShapedType>(c.getType()).getElementType(); |
718 | |
719 | // Precompute buffersize for SpMM. |
720 | auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>( |
721 | loc, indexTp, tokenTp, token, spMatA, dnB, dnC, |
722 | /*computeType=*/dmatCType); |
723 | Value bufferSz = bufferComp.getResult(0); |
724 | token = bufferComp.getAsyncToken(); |
725 | auto buf = genAllocBuffer(rewriter, loc, bufferSz, token); |
726 | Value buffer = buf.getResult(0); |
727 | token = buf.getAsyncToken(); |
728 | auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType(); |
729 | |
730 | // Perform the SpMM. |
731 | auto spmmComp = rewriter.create<gpu::SpMMOp>( |
732 | loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType, buffer); |
733 | token = spmmComp.getAsyncToken(); |
734 | |
  // Copy data back to host and free all the resources.
736 | token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA) |
737 | .getAsyncToken(); |
738 | token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB) |
739 | .getAsyncToken(); |
740 | token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC) |
741 | .getAsyncToken(); |
742 | token = genDeallocMemRef(builder&: rewriter, loc, mem: rowA, token); |
743 | if (colA) |
744 | token = genDeallocMemRef(builder&: rewriter, loc, mem: colA, token); |
745 | token = genDeallocMemRef(builder&: rewriter, loc, mem: valA, token); |
746 | token = genDeallocMemRef(builder&: rewriter, loc, mem: buffer, token); |
747 | token = genDeallocMemRef(builder&: rewriter, loc, mem: matB, token); |
748 | token = genCopyMemRef(builder&: rewriter, loc, dst: bufC, src: matC, token); |
749 | token = genDeallocMemRef(builder&: rewriter, loc, mem: matC, token); |
750 | tokens.push_back(Elt: token); |
751 | genBlockingWait(builder&: rewriter, loc, operands: tokens); |
752 | tokens.clear(); |
753 | |
754 | // Done. |
755 | rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC); |
756 | return success(); |
757 | } |
758 | |
759 | // Match and rewrite SpGEMM kernel. |
760 | static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter, |
761 | linalg::GenericOp op, bool enableRT) { |
762 | Location loc = op.getLoc(); |
763 | Value a = op.getOperand(0); |
764 | Value b = op.getOperand(1); |
765 | Value c = op.getOperand(2); // we have C = AB |
766 | SmallVector<Value> tokens; |
767 | |
768 | // Only CSR <- CSR x CSR supported. |
769 | auto format = CuSparseFormat::kCSR; |
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType bTp = getSparseTensorType(b);
  SparseTensorType cTp = getSparseTensorType(c);
  if (!isAdmissibleCSR(aTp) || !isAdmissibleCSR(bTp) || !isAdmissibleCSR(cTp))
774 | return failure(); |
775 | |
776 | // Start sparse kernel and copy data from host to device. |
777 | // a : amemR/amemC/amemV -> rowA,colA,valA |
778 | // b : bmemR/bmemC/bmemV -> rowB,colB,valB |
779 | // c : materializes |
780 | auto dnCType = cTp.getElementType(); |
781 | Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a); |
782 | Value nseB = rewriter.create<NumberOfEntriesOp>(loc, b); |
783 | Value szm = linalg::createOrFoldDimOp(b&: rewriter, loc, val: a, dim: 0); |
784 | Value szk = linalg::createOrFoldDimOp(b&: rewriter, loc, val: a, dim: 1); |
785 | Value szn = linalg::createOrFoldDimOp(b&: rewriter, loc, val: b, dim: 1); |
786 | Value amemR = genFirstPosOrCrds(builder&: rewriter, loc, a, format, enableRT); |
787 | Value amemC = genSecondCrds(builder&: rewriter, loc, a, format, enableRT); // not empty |
788 | Value amemV = rewriter.create<ToValuesOp>(loc, a); |
789 | Value bmemR = genFirstPosOrCrds(builder&: rewriter, loc, a: b, format, enableRT); |
790 | Value bmemC = genSecondCrds(builder&: rewriter, loc, a: b, format, enableRT); // not empty |
791 | Value bmemV = rewriter.create<ToValuesOp>(loc, b); |
792 | Value rowA = genAllocCopy(builder&: rewriter, loc, b: amemR, tokens); |
793 | Value colA = genAllocCopy(builder&: rewriter, loc, b: amemC, tokens); |
794 | Value valA = genAllocCopy(builder&: rewriter, loc, b: amemV, tokens); |
795 | Value rowB = genAllocCopy(builder&: rewriter, loc, b: bmemR, tokens); |
796 | Value colB = genAllocCopy(builder&: rewriter, loc, b: bmemC, tokens); |
797 | Value valB = genAllocCopy(builder&: rewriter, loc, b: bmemV, tokens); |
798 | genBlockingWait(builder&: rewriter, loc, operands: tokens); |
799 | tokens.clear(); |
800 | |
801 | // Create sparse environment and sparse matrix/dense vector handles. |
802 | Type indexTp = rewriter.getIndexType(); |
803 | Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); |
804 | Type descTp = rewriter.getType<gpu::SparseSpGEMMOpHandleType>(); |
805 | Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); |
806 | Value token = genFirstWait(builder&: rewriter, loc); |
807 | Operation *spGenA = |
808 | genSpMat(builder&: rewriter, loc, aTp, handleTp: spmatHandleTp, tokenTp, token, sz1: szm, sz2: szk, |
809 | nseA, rowA, colA, valA, format, enableRT); |
810 | Value spMatA = spGenA->getResult(idx: 0); |
811 | token = spGenA->getResult(idx: 1); |
812 | Operation *spGenB = |
813 | genSpMat(builder&: rewriter, loc, aTp&: bTp, handleTp: spmatHandleTp, tokenTp, token, sz1: szk, sz2: szn, |
814 | nseA: nseB, rowA: rowB, colA: colB, valA: valB, format, enableRT); |
815 | Value spMatB = spGenB->getResult(idx: 0); |
816 | token = spGenB->getResult(idx: 1); |
817 | |
818 | // Sparse matrix C materializes (also assumes beta == 0). |
819 | Value zero = constantIndex(builder&: rewriter, loc, i: 0); |
820 | Value one = constantIndex(builder&: rewriter, loc, i: 1); |
821 | Value mplus1 = rewriter.create<arith::AddIOp>(loc, szm, one); |
822 | auto e1 = genAllocBuffer(rewriter, loc, cTp.getPosType(), mplus1, token); |
823 | Value rowC = e1.getResult(0); |
824 | token = e1.getAsyncToken(); |
825 | auto e2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), zero, token); |
826 | Value colC = e2.getResult(0); // no free needed |
827 | token = e2.getAsyncToken(); |
828 | auto e3 = genAllocBuffer(rewriter, loc, dnCType, zero, token); |
829 | Value valC = e3.getResult(0); // no free needed |
830 | token = e3.getAsyncToken(); |
831 | Operation *spGenC = |
832 | genSpMat(builder&: rewriter, loc, aTp&: cTp, handleTp: spmatHandleTp, tokenTp, token, sz1: szm, sz2: szn, |
833 | nseA: zero, rowA: rowC, colA: colC, valA: valC, format, enableRT); |
834 | Value spMatC = spGenC->getResult(idx: 0); |
835 | token = spGenC->getResult(idx: 1); |
836 | |
837 | // Precompute buffersizes for SpGEMM. |
838 | Operation *descOp = |
839 | rewriter.create<gpu::SpGEMMCreateDescrOp>(loc, descTp, tokenTp, token); |
  Value desc = descOp->getResult(0);
  token = descOp->getResult(1);
842 | Operation *work1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>( |
843 | loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, |
844 | gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero, |
845 | valC, gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION); |
  Value bufferSz1 = work1->getResult(0);
  token = work1->getResult(1);
848 | auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token); |
849 | Value buffer1 = buf1.getResult(0); |
850 | token = buf1.getAsyncToken(); |
851 | Operation *work2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>( |
852 | loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, |
853 | gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, |
854 | bufferSz1, buffer1, |
855 | gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION); |
  token = work2->getResult(1);
857 | |
858 | // Compute step. |
859 | Operation *compute1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>( |
860 | loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, |
861 | gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero, |
862 | valC, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE); |
  Value bufferSz2 = compute1->getResult(0);
  token = compute1->getResult(1);
865 | auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token); |
866 | Value buffer2 = buf2.getResult(0); |
867 | token = buf2.getAsyncToken(); |
868 | Operation *compute2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>( |
869 | loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, |
870 | gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, |
871 | bufferSz2, buffer2, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE); |
  token = compute2->getResult(1);
873 | |
874 | // Get sizes. |
875 | Operation *sizes = rewriter.create<gpu::SpMatGetSizeOp>( |
876 | loc, indexTp, indexTp, indexTp, tokenTp, token, spMatC); |
  Value nnz = sizes->getResult(2);
  token = sizes->getResult(3);
879 | auto a2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), nnz, token); |
880 | colC = a2.getResult(0); |
881 | token = a2.getAsyncToken(); |
882 | auto a3 = genAllocBuffer(rewriter, loc, dnCType, nnz, token); |
883 | valC = a3.getResult(0); |
884 | token = a3.getAsyncToken(); |
885 | |
886 | // Update C with new pointers and copy final product back into C. |
887 | Operation *update = rewriter.create<gpu::SetCsrPointersOp>( |
888 | loc, tokenTp, token, spMatC, rowC, colC, valC); |
  token = update->getResult(0);
890 | Operation *copy = rewriter.create<gpu::SpGEMMCopyOp>( |
891 | loc, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE, |
892 | gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType); |
  token = copy->getResult(0);
894 | |
895 | // Allocate buffers on host. |
896 | Value rowH = genHostBuffer(builder&: rewriter, loc, type: cTp.getPosType(), size: mplus1); |
897 | Value colH = genHostBuffer(builder&: rewriter, loc, type: cTp.getCrdType(), size: nnz); |
898 | Value valH = genHostBuffer(rewriter, loc, dnCType, nnz); |
899 | |
  // Copy data back to host and free all the resources.
901 | token = rewriter.create<gpu::SpGEMMDestroyDescrOp>(loc, tokenTp, token, desc) |
902 | .getAsyncToken(); |
903 | token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA) |
904 | .getAsyncToken(); |
905 | token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatB) |
906 | .getAsyncToken(); |
907 | token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC) |
908 | .getAsyncToken(); |
909 | token = genCopyMemRef(builder&: rewriter, loc, dst: rowH, src: rowC, token); |
910 | token = genCopyMemRef(builder&: rewriter, loc, dst: colH, src: colC, token); |
911 | token = genCopyMemRef(builder&: rewriter, loc, dst: valH, src: valC, token); |
912 | token = genDeallocMemRef(builder&: rewriter, loc, mem: rowA, token); |
913 | token = genDeallocMemRef(builder&: rewriter, loc, mem: colA, token); |
914 | token = genDeallocMemRef(builder&: rewriter, loc, mem: valA, token); |
915 | token = genDeallocMemRef(builder&: rewriter, loc, mem: rowB, token); |
916 | token = genDeallocMemRef(builder&: rewriter, loc, mem: colB, token); |
917 | token = genDeallocMemRef(builder&: rewriter, loc, mem: valB, token); |
918 | token = genDeallocMemRef(builder&: rewriter, loc, mem: rowC, token); |
919 | token = genDeallocMemRef(builder&: rewriter, loc, mem: colC, token); |
920 | token = genDeallocMemRef(builder&: rewriter, loc, mem: valC, token); |
921 | token = genDeallocMemRef(builder&: rewriter, loc, mem: buffer1, token); |
922 | token = genDeallocMemRef(builder&: rewriter, loc, mem: buffer2, token); |
923 | tokens.push_back(Elt: token); |
924 | genBlockingWait(builder&: rewriter, loc, operands: tokens); |
925 | tokens.clear(); |
926 | |
927 | // Done. |
928 | Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH); |
929 | Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH); |
930 | Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH); |
931 | rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), ValueRange{rt, ct}, |
932 | vt); |
933 | return success(); |
934 | } |
935 | |
936 | // Match and rewrite 2:4 SpMM kernel. |
937 | static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter, |
938 | linalg::GenericOp op) { |
939 | Location loc = op.getLoc(); |
940 | Value A = op.getOperand(0); |
941 | Value B = op.getOperand(1); |
942 | Value C = op.getOperand(2); // we have C = AB |
943 | SmallVector<Value> tokens; |
944 | |
945 | // The cuSparselt API currently only allows pruning and compression |
946 | // to occur on the device. So we recognize the pattern |
947 | // A' = convert A ; dense to 2:4 |
948 | // C = A'B ; 2:4 matrix mult |
949 | // and then perform compression and matrix multiplication on device. |
950 | auto cnv = A.getDefiningOp<ConvertOp>(); |
951 | assert(cnv); |
952 | A = cnv.getSource(); |
953 | |
  // All inputs should be dense tensors.
  if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C))
956 | return failure(); |
957 | |
958 | // Start sparse kernel and copy data from host to device. |
959 | // a : bufA -> matA |
960 | // b : bufB -> matB |
961 | // c : bufC -> matC |
  Value bufA = genTensorToMemref(rewriter, loc, A);
  Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
  Value bufB = genTensorToMemref(rewriter, loc, B);
  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
  Value bufC = genTensorToMemref(rewriter, loc, C);
  Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
  genBlockingWait(rewriter, loc, tokens);
969 | tokens.clear(); |
970 | |
971 | // Create sparse environment and sparse matrix/dense vector handles. |
972 | Value szm = linalg::createOrFoldDimOp(b&: rewriter, loc, val: matA, dim: 0); |
973 | Value szk = linalg::createOrFoldDimOp(b&: rewriter, loc, val: matB, dim: 0); |
974 | Value szn = linalg::createOrFoldDimOp(b&: rewriter, loc, val: matC, dim: 1); |
975 | Type indexTp = rewriter.getIndexType(); |
976 | Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>(); |
977 | Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); |
978 | Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); |
979 | Value token = genFirstWait(builder&: rewriter, loc); |
980 | Operation *spGenA = rewriter.create<gpu::Create2To4SpMatOp>( |
981 | loc, spMatHandleTp, tokenTp, token, szm, szk, |
982 | gpu::Prune2To4SpMatFlag::PRUNE_AND_CHECK, matA); |
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
985 | auto dmatB = rewriter.create<gpu::CreateDnTensorOp>( |
986 | loc, dnTensorHandleTp, tokenTp, token, matB, |
987 | SmallVector<Value>{szk, szn}); |
988 | Value dnB = dmatB.getResult(0); |
989 | token = dmatB.getAsyncToken(); |
990 | auto dmatC = rewriter.create<gpu::CreateDnTensorOp>( |
991 | loc, dnTensorHandleTp, tokenTp, token, matC, |
992 | SmallVector<Value>{szm, szn}); |
993 | Value dnC = dmatC.getResult(0); |
994 | token = dmatC.getAsyncToken(); |
995 | auto dmatCType = llvm::cast<ShapedType>(matC.getType()).getElementType(); |
996 | |
997 | // Precompute buffersize for SpMM. |
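  // Note: the 2:4 (cuSparselt) path reports three separate buffer sizes.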
998 | SmallVector<Type> bufferTypes_{indexTp, indexTp, indexTp}; |
999 | TypeRange bufferTypes(bufferTypes_); |
1000 | auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>( |
1001 | loc, bufferTypes, tokenTp, token, gpu::TransposeMode::NON_TRANSPOSE, |
1002 | gpu::TransposeMode::NON_TRANSPOSE, spMatA, dnB, dnC, |
1003 | /*computeType=*/dmatCType); |
1004 | token = bufferComp.getAsyncToken(); |
1005 | |
1006 | // Allocate buffers on host. |
1007 | Value bufferSz1 = bufferComp.getResult(0); |
1008 | auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token); |
1009 | Value buffer1 = buf1.getResult(0); |
1010 | token = buf1.getAsyncToken(); |
1011 | Value bufferSz2 = bufferComp.getResult(1); |
1012 | auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token); |
1013 | Value buffer2 = buf2.getResult(0); |
1014 | token = buf2.getAsyncToken(); |
1015 | Value bufferSz3 = bufferComp.getResult(2); |
1016 | auto buf3 = genAllocBuffer(rewriter, loc, bufferSz3, token); |
1017 | Value buffer3 = buf3.getResult(0); |
1018 | token = buf3.getAsyncToken(); |
1019 | |
1020 | // Perform the SpMM. |
1021 | auto dnCType = llvm::cast<ShapedType>(matC.getType()).getElementType(); |
1022 | auto spmmComp = rewriter.create<gpu::SpMMOp>( |
1023 | loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType, |
1024 | SmallVector<Value>{buffer1, buffer2, buffer3}); |
1025 | token = spmmComp.getAsyncToken(); |
1026 | |
1027 | // Copy data back to host and free all the resources. |
1028 | token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA) |
1029 | .getAsyncToken(); |
1030 | token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB) |
1031 | .getAsyncToken(); |
1032 | token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC) |
1033 | .getAsyncToken(); |
1034 | SmallVector<Value> newDynamicSizes; |
1035 | token = genDeallocMemRef(builder&: rewriter, loc, mem: buffer1, token); |
1036 | token = genDeallocMemRef(builder&: rewriter, loc, mem: buffer2, token); |
1037 | token = genDeallocMemRef(builder&: rewriter, loc, mem: buffer3, token); |
1038 | token = genDeallocMemRef(builder&: rewriter, loc, mem: matA, token); |
1039 | token = genDeallocMemRef(builder&: rewriter, loc, mem: matB, token); |
1040 | token = genCopyMemRef(builder&: rewriter, loc, dst: bufC, src: matC, token); |
1041 | token = genDeallocMemRef(builder&: rewriter, loc, mem: matC, token); |
1042 | tokens.push_back(Elt: token); |
1043 | genBlockingWait(builder&: rewriter, loc, operands: tokens); |
1044 | tokens.clear(); |
1045 | |
1046 | // Done. |
1047 | rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC); |
1048 | return success(); |
1049 | } |
1050 | |
1051 | /// Match and rewrite SDDMM kernel. |
1052 | static LogicalResult rewriteSDDMM(PatternRewriter &rewriter, |
1053 | linalg::GenericOp op, bool enableRT) { |
1054 | Location loc = op.getLoc(); |
1055 | Value a = op.getOperand(0); |
1056 | Value b = op.getOperand(1); |
1057 | Value c = op.getOperand(2); |
1058 | SmallVector<Value> tokens; |
1059 | |
1060 | // Only admissible sparse matrix format (no COO/CSC) and dense matrices. |
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType bTp = getSparseTensorType(b);
  SparseTensorType cTp = getSparseTensorType(c);
  auto format = getCuSparseFormat(cTp, bTp, aTp, enableRT, /*isMatVec=*/false);
1065 | if (format == CuSparseFormat::kNone || format == CuSparseFormat::kCOO || |
1066 | format == CuSparseFormat::kCSC) |
1067 | return failure(); |
1068 | |
1069 | // The SDDMM does the in-place operation. |
1070 | // Start sparse kernel and copy data from host to device. |
1071 | // a : bufA -> matA |
1072 | // b : bufB -> matB |
1073 | // c : memR/memC/memV -> rowC,colC,valC |
1074 | Value nseC = rewriter.create<NumberOfEntriesOp>(loc, c); |
1075 | Value szm = linalg::createOrFoldDimOp(b&: rewriter, loc, val: a, dim: 0); |
1076 | Value szk = linalg::createOrFoldDimOp(b&: rewriter, loc, val: a, dim: 1); |
1077 | Value szn = linalg::createOrFoldDimOp(b&: rewriter, loc, val: b, dim: 1); |
1078 | Value bufA = genTensorToMemref(rewriter, loc, tensor: a); |
1079 | Value matA = genAllocCopy(builder&: rewriter, loc, b: bufA, tokens); |
1080 | Value bufB = genTensorToMemref(rewriter, loc, tensor: b); |
1081 | Value matB = genAllocCopy(builder&: rewriter, loc, b: bufB, tokens); |
1082 | Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT); |
1083 | Value memC = genSecondCrds(rewriter, loc, c, format, enableRT); // or empty |
1084 | Value memV = rewriter.create<ToValuesOp>(loc, c); |
1085 | Value rowC = genAllocCopy(builder&: rewriter, loc, b: memR, tokens); |
1086 | Value colC = memC ? genAllocCopy(builder&: rewriter, loc, b: memC, tokens) : Value(); |
1087 | Value valC = genAllocCopy(builder&: rewriter, loc, b: memV, tokens); |
1088 | genBlockingWait(builder&: rewriter, loc, operands: tokens); |
1089 | tokens.clear(); |
1090 | |
1091 | // Create sparse environment and sparse matrix/dense matrix handles. |
1092 | Type indexTp = rewriter.getIndexType(); |
1093 | Type dnMatHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>(); |
1094 | Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>(); |
1095 | Type tokenTp = rewriter.getType<gpu::AsyncTokenType>(); |
1096 | Value token = genFirstWait(builder&: rewriter, loc); |
1097 | auto dmatA = rewriter.create<gpu::CreateDnTensorOp>( |
1098 | loc, dnMatHandleTp, tokenTp, token, matA, SmallVector<Value>{szm, szk}); |
1099 | Value dnA = dmatA.getResult(0); |
1100 | token = dmatA.getAsyncToken(); |
1101 | auto dmatB = rewriter.create<gpu::CreateDnTensorOp>( |
1102 | loc, dnMatHandleTp, tokenTp, token, matB, SmallVector<Value>{szk, szn}); |
1103 | Value dnB = dmatB.getResult(0); |
1104 | token = dmatB.getAsyncToken(); |
1105 | Operation *spGenC = |
1106 | genSpMat(rewriter, loc, cTp, spMatHandleTp, tokenTp, token, szm, szn, |
1107 | nseC, rowC, colC, valC, format, enableRT); |
  Value spMatC = spGenC->getResult(0);
  token = spGenC->getResult(1);
1110 | auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType(); |
1111 | |
1112 | // Precompute buffersize for SDDMM. |
1113 | auto bufferComp = rewriter.create<gpu::SDDMMBufferSizeOp>( |
1114 | loc, indexTp, tokenTp, token, dnA, dnB, spMatC, dnCType); |
1115 | Value bufferSz = bufferComp.getResult(0); |
1116 | token = bufferComp.getAsyncToken(); |
1117 | auto buf = genAllocBuffer(rewriter, loc, bufferSz, token); |
1118 | Value buffer = buf.getResult(0); |
1119 | token = buf.getAsyncToken(); |
1120 | |
1121 | // Perform the SDDMM. |
1122 | auto sddmmComp = rewriter.create<gpu::SDDMMOp>(loc, tokenTp, token, dnA, dnB, |
1123 | spMatC, dnCType, buffer); |
1124 | token = sddmmComp.getAsyncToken(); |
1125 | |
  // Copy data back to host and free all the resources.
1127 | token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnA) |
1128 | .getAsyncToken(); |
1129 | token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB) |
1130 | .getAsyncToken(); |
1131 | token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC) |
1132 | .getAsyncToken(); |
1133 | token = genDeallocMemRef(builder&: rewriter, loc, mem: buffer, token); |
1134 | token = genDeallocMemRef(builder&: rewriter, loc, mem: matA, token); |
1135 | token = genDeallocMemRef(builder&: rewriter, loc, mem: matB, token); |
1136 | token = genDeallocMemRef(builder&: rewriter, loc, mem: rowC, token); |
1137 | if (colC) |
1138 | token = genDeallocMemRef(builder&: rewriter, loc, mem: colC, token); |
1139 | token = genCopyMemRef(builder&: rewriter, loc, dst: memV, src: valC, token); |
1140 | token = genDeallocMemRef(builder&: rewriter, loc, mem: valC, token); |
1141 | tokens.push_back(Elt: token); |
1142 | genBlockingWait(builder&: rewriter, loc, operands: tokens); |
1143 | tokens.clear(); |
1144 | |
1145 | // Done. |
1146 | rewriter.replaceOpWithNewOp<sparse_tensor::LoadOp>(op, c); |
1147 | return success(); |
1148 | } |
1149 | |
1150 | //===----------------------------------------------------------------------===// |
1151 | // Rewriting rules for direct code generation. |
1152 | //===----------------------------------------------------------------------===// |
1153 | |
1154 | /// Proof-of-concept rewriter. This rule generates a GPU implementation |
1155 | /// for each outermost forall loop generated by the sparsifier. |
/// TODO: right now this works with parallelization-strategy=dense-outer-loop,
///       but give it its own flag in the future.
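///
/// Conceptually (a simplified sketch), a sparsifier loop of the form
///   scf.parallel (%i) = (%c0) to (%N) step (%c1) { ... }
/// is outlined into a gpu.func in which each thread strides over the
/// iteration space, while the original loop is replaced by host code that
/// copies the outlined buffers to the device, launches the kernel over
/// numThreads threads, and copies the results back afterwards.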
1158 | struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> { |
1159 | using OpRewritePattern<scf::ParallelOp>::OpRewritePattern; |
1160 | |
1161 | ForallRewriter(MLIRContext *context, unsigned nT) |
: OpRewritePattern(context), numThreads(nT) {}
1163 | |
1164 | LogicalResult matchAndRewrite(scf::ParallelOp forallOp, |
1165 | PatternRewriter &rewriter) const override { |
1166 | // Reject inadmissible loop form. |
1167 | // Essentially only accept a loop, generated by the sparsifier, |
1168 | // of the form |
1169 | // forall (i = 0; i < N; i++) |
1170 | // so that cyclic scheduling over the threads is easy. |
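// Under cyclic scheduling, thread t then (roughly) executes iterations
// t, t + T, t + 2*T, ... where T is the total number of launched threads.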
1171 | if (!forallOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()) || |
1172 | forallOp.getNumReductions() != 0 || forallOp.getNumLoops() != 1 || |
1173 | !matchPattern(forallOp.getLowerBound()[0], m_Zero()) || |
1174 | !matchPattern(forallOp.getStep()[0], m_One())) |
1175 | return failure(); |
1176 | // Collect every value that is computed outside the parallel loop. |
1177 | SetVector<Value> invariants; // stable iteration! |
1178 | forallOp->walk([&](Operation *op) { |
1179 | // Collect all values of admissible ops. |
1180 | for (OpOperand &o : op->getOpOperands()) { |
1181 | Value val = o.get(); |
1182 | Block *block; |
if (auto arg = dyn_cast<BlockArgument>(val))
1184 | block = arg.getOwner(); |
1185 | else |
1186 | block = val.getDefiningOp()->getBlock(); |
1187 | if (!forallOp.getRegion().findAncestorBlockInRegion(*block)) |
invariants.insert(val);
1189 | } |
1190 | }); |
// Outline the outside values as proper parameters. Fail when sharing a
// value between host and device is not straightforward.
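// Constants can simply be re-materialized inside the kernel, scalars can be
// passed by value as kernel arguments, and memrefs can be staged through
// device buffers; anything else is rejected.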
1193 | SmallVector<Value> constants; |
1194 | SmallVector<Value> scalars; |
1195 | SmallVector<Value> buffers; |
1196 | for (Value val : invariants) { |
1197 | Type tp = val.getType(); |
1198 | if (val.getDefiningOp<arith::ConstantOp>()) |
constants.push_back(val);
else if (isa<FloatType>(tp) || tp.isIntOrIndex())
scalars.push_back(val);
else if (isa<MemRefType>(tp))
buffers.push_back(val);
1204 | else |
1205 | return failure(); // don't know how to share |
1206 | } |
1207 | // Pass outlined non-constant values. |
1208 | // TODO: Experiment with `useHostRegistrationForOut` to see if we want to |
1209 | // keep the feature at all (either through a heuristic or compiler |
1210 | // option for gpu codegen). |
1211 | Location loc = forallOp->getLoc(); |
1212 | SmallVector<Value> args; |
1213 | SmallVector<Value> tokens; |
Value out = genParametersIn(rewriter, loc, scalars, buffers, args, tokens,
1215 | /*useHostRegistrationForOut=*/false); |
1216 | // Set up GPU module and construct GPU function. |
1217 | auto saveIp = rewriter.saveInsertionPoint(); |
1218 | ModuleOp topModule = forallOp->getParentOfType<ModuleOp>(); |
1219 | auto gpuModule = genGPUModule(rewriter, topModule); |
1220 | auto gpuFunc = genGPUFunc(rewriter, gpuModule, args); |
1221 | genGPUCode(rewriter, gpuFunc, forallOp, constants, scalars, buffers); |
// Generate code that launches the kernel asynchronously, blocking on all
// open tokens and yielding a new token for the output.
// TODO: Passing tokens into the launch op does not yet seem to be properly
// lowered by the cubin path, hence the current blocking wait.
rewriter.restoreInsertionPoint(saveIp);
genBlockingWait(rewriter, loc, tokens);
1228 | tokens.clear(); |
1229 | Value kernelToken = |
1230 | genLaunchGPUFunc(rewriter, gpuFunc, args, tokens, numThreads); |
1231 | // Finalize the outlined arguments. |
genParametersOut(rewriter, loc, out, kernelToken, scalars, buffers, args,
                 tokens);
genBlockingWait(rewriter, loc, tokens);
rewriter.eraseOp(forallOp);
1236 | return success(); |
1237 | } |
1238 | |
1239 | private: |
1240 | unsigned numThreads; |
1241 | }; |
1242 | |
1243 | //===----------------------------------------------------------------------===// |
1244 | // Rewriting rules for library recognition and code generation. |
1245 | //===----------------------------------------------------------------------===// |
1246 | |
1247 | /// Proof-of-concept rewriter. This rule recognizes certain math kernels |
1248 | /// and replaces these with corresponding calls into a sparse library. |
1249 | struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> { |
1250 | using OpRewritePattern<linalg::GenericOp>::OpRewritePattern; |
1251 | |
1252 | LinalgOpRewriter(MLIRContext *context, bool rt) |
1253 | : OpRewritePattern(context), enableRT(rt) {} |
1254 | |
1255 | LogicalResult matchAndRewrite(linalg::GenericOp op, |
1256 | PatternRewriter &rewriter) const override { |
1257 | if (op.getNumDpsInits() != 1) |
1258 | return failure(); // reject multi-output |
1259 | |
1260 | const unsigned numLoops = op.getNumLoops(); |
1261 | const unsigned numTensors = op->getNumOperands(); |
1262 | const auto iteratorTypes = op.getIteratorTypesArray(); |
1263 | SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray(); |
1264 | |
1265 | using MapList = ArrayRef<ArrayRef<AffineExpr>>; |
1266 | auto infer = [&](MapList m) { |
1267 | return AffineMap::inferFromExprList(m, op.getContext()); |
1268 | }; |
1269 | AffineExpr i, j, k; |
1270 | bindDims(getContext(), i, j, k); |
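// For example, infer({{i, k}, {k, j}, {i, j}}) yields the indexing maps
//   (i, j, k) -> (i, k), (i, j, k) -> (k, j), (i, j, k) -> (i, j)
// which can be compared directly against op.getIndexingMapsArray().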
1271 | |
// TODO: more robust patterns, transposed versions, more kernels,
1273 | // identify alpha and beta and pass them to the CUDA calls. |
1274 | |
1275 | // Recognize a SpMV kernel. |
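// That is, y(i) += A(i,j) * x(j): one parallel loop over i and one
// reduction loop over j.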
1276 | if (numLoops == 2 && numTensors == 3 && |
1277 | linalg::isParallelIterator(iteratorTypes[0]) && |
1278 | linalg::isReductionIterator(iteratorTypes[1]) && |
1279 | maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) { |
1280 | return rewriteSpMV(rewriter, op, enableRT); |
1281 | } |
1282 | |
1283 | // Recognize a SpGEMM, 2:4-SpMM, or SpMM kernel. |
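// All three share the structure C(i,j) += sum_k A(i,k) * B(k,j); the
// dispatch below selects SpGEMM when both A and B are sparse, a 2:4
// structured SpMM when A comes from a conversion into 2:4 form, and a
// plain SpMM otherwise.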
1284 | if (numLoops == 3 && numTensors == 3 && |
1285 | linalg::isParallelIterator(iteratorTypes[0]) && |
1286 | linalg::isParallelIterator(iteratorTypes[1]) && |
1287 | linalg::isReductionIterator(iteratorTypes[2]) && |
1288 | maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) { |
1289 | if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1))) |
1290 | return rewriteSpGEMM(rewriter, op, enableRT); |
1291 | if (isConversionInto24(op.getOperand(0))) |
1292 | return rewrite2To4SpMM(rewriter, op); |
1293 | return rewriteSpMM(rewriter, op, enableRT); |
1294 | } |
1295 | |
1296 | // Recognize a SDDMM kernel. |
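// Same loop structure as SpMM, but here the multiplication is sampled by
// the sparsity pattern of the output; the sparsifier essentially expresses
// this sampling through a unary op, which matchSumReductionOfMulUnary
// detects.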
1297 | if (numLoops == 3 && numTensors == 3 && |
1298 | linalg::isParallelIterator(iteratorTypes[0]) && |
1299 | linalg::isParallelIterator(iteratorTypes[1]) && |
1300 | linalg::isReductionIterator(iteratorTypes[2]) && |
1301 | maps == infer({{i, k}, {k, j}, {i, j}}) && |
1302 | matchSumReductionOfMulUnary(op)) { |
1303 | return rewriteSDDMM(rewriter, op, enableRT); |
1304 | } |
1305 | |
1306 | return failure(); |
1307 | } |
1308 | |
1309 | private: |
1310 | bool enableRT; |
1311 | }; |
1312 | |
1313 | } // namespace |
1314 | |
1315 | //===----------------------------------------------------------------------===// |
1316 | // Public method for populating GPU rewriting rules. |
1317 | // |
// Currently two sets of rewriting rules are made available. The first set
// implements direct code generation, currently by means of converting the
// outermost parallel loop into GPU threads. The second set implements
// library recognition of a set of sparse operations. Eventually, the right
// combination of these two approaches has to be found.
1323 | //===----------------------------------------------------------------------===// |
1324 | |
1325 | void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns, |
1326 | unsigned numThreads) { |
patterns.add<ForallRewriter>(patterns.getContext(), numThreads);
1328 | } |
1329 | |
1330 | void mlir::populateSparseGPULibgenPatterns(RewritePatternSet &patterns, |
1331 | bool enableRT) { |
patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT);
1333 | } |
1334 | |