1 | //===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements lowering of vector operations to GPU dialect ops. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Conversion/VectorToGPU/VectorToGPU.h" |
14 | |
15 | #include <type_traits> |
16 | |
17 | #include "mlir/Analysis/SliceAnalysis.h" |
18 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
19 | #include "mlir/Dialect/Arith/IR/Arith.h" |
20 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
21 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
22 | #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" |
23 | #include "mlir/Dialect/NVGPU/Utils/MMAUtils.h" |
24 | #include "mlir/Dialect/SCF/IR/SCF.h" |
25 | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
26 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
27 | #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" |
28 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h" |
29 | #include "mlir/IR/Builders.h" |
30 | #include "mlir/IR/BuiltinOps.h" |
31 | #include "mlir/IR/Region.h" |
32 | #include "mlir/Pass/Pass.h" |
33 | #include "mlir/Support/LogicalResult.h" |
34 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
35 | #include "mlir/Transforms/Passes.h" |
36 | #include "llvm/ADT/STLExtras.h" |
37 | #include "llvm/ADT/TypeSwitch.h" |
38 | |
39 | #define DEBUG_TYPE "vector-to-gpu" |
40 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
41 | #define DBGSNL() (llvm::dbgs() << "\n") |
42 | |
43 | namespace mlir { |
44 | #define GEN_PASS_DEF_CONVERTVECTORTOGPU |
45 | #include "mlir/Conversion/Passes.h.inc" |
46 | } // namespace mlir |
47 | |
48 | using namespace mlir; |
49 | |
50 | /// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an |
51 | /// AffineMap representing offsets to apply to indices, the function fills |
52 | /// `indices` with the original indices plus the offsets. The offsets are |
53 | /// applied by taking into account the permutation map of the transfer op. If |
54 | /// the `offsetMap` has dimension placeholders, those should be provided in |
55 | /// `dimValues`. |
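/// For example, with the identity permutation map (d0, d1) -> (d0, d1), an
/// `offsetMap` (d0) -> (d0 floordiv 4, d0 mod 4), and `dimValues` = {%laneId},
/// the new indices are %idx0 + %laneId floordiv 4 and %idx1 + %laneId mod 4
/// (illustrative SSA names), each built with a composed affine.apply.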
56 | template <typename TransferOpType> |
57 | static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp, |
58 | AffineMap offsetMap, ArrayRef<Value> dimValues, |
59 | SmallVector<Value, 4> &indices) { |
60 | indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end()); |
61 | Location loc = xferOp.getLoc(); |
62 | unsigned offsetsIdx = 0; |
63 | for (auto expr : xferOp.getPermutationMap().getResults()) { |
64 | if (auto dim = dyn_cast<AffineDimExpr>(expr)) { |
65 | Value prevIdx = indices[dim.getPosition()]; |
66 | SmallVector<OpFoldResult, 3> dims(dimValues.begin(), dimValues.end()); |
67 | dims.push_back(prevIdx); |
      AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
      indices[dim.getPosition()] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
71 | continue; |
72 | } |
73 | } |
74 | } |
75 | |
// Return true if the contract op can be converted to an MMA matmul.
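// For example, with iterator types (parallel, parallel, reduction), the
// default (wmma) path accepts indexing maps {(m, k), (k, n), (m, n)}, while
// the nvgpu mma.sync path accepts the RHS-transposed {(m, k), (n, k), (m, n)}.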
77 | static bool contractSupportsMMAMatrixType(vector::ContractionOp contract, |
78 | bool useNvGpu) { |
79 | using MapList = ArrayRef<ArrayRef<AffineExpr>>; |
80 | auto infer = [&](MapList m) { |
81 | return AffineMap::inferFromExprList(m, contract.getContext()); |
82 | }; |
83 | AffineExpr m, n, k; |
84 | bindDims(contract.getContext(), m, n, k); |
85 | auto iteratorTypes = contract.getIteratorTypes().getValue(); |
  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
        vector::isParallelIterator(iteratorTypes[1]) &&
        vector::isReductionIterator(iteratorTypes[2])))
89 | return false; |
90 | |
91 | // The contract needs to represent a matmul to be able to convert to |
92 | // MMAMatrix matmul. |
93 | if (!useNvGpu && |
94 | contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}})) |
95 | return false; |
96 | if (useNvGpu && |
97 | contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}})) |
98 | return false; |
99 | |
100 | return true; |
101 | } |
102 | |
103 | // Return true if the given map represents a transposed matrix load, |
104 | // i.e. (d0, d1, ...) -> (dn-1, dn-2). |
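// For example, affine_map<(d0, d1) -> (d1, d0)> and, with a broadcast,
// affine_map<(d0, d1) -> (d1, 0)> are both treated as transposed loads.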
105 | static bool isTransposeMatrixLoadMap(AffineMap permutationMap) { |
106 | MLIRContext *ctx = permutationMap.getContext(); |
107 | // Local OpBuilder is fine here, we just build attributes. |
108 | OpBuilder b(ctx); |
109 | auto nDim = permutationMap.getNumDims(); |
  AffineExpr zero = b.getAffineConstantExpr(0);
111 | if (nDim < 2) { |
112 | // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>. |
    AffineExpr dim0 = b.getAffineDimExpr(0);
    return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
115 | } |
116 | |
  AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
  AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
  // Support both transposed and transposed+broadcasted cases.
  return permutationMap ==
             AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
         permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
122 | } |
123 | |
// Return the stride for the second-to-last dimension of |type| if it is a
// memref and has a constant stride.
126 | static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) { |
127 | auto memrefType = dyn_cast<MemRefType>(type); |
128 | if (!memrefType) |
    return std::nullopt;
130 | // If the memref is 0 or 1D the horizontal stride is 0. |
131 | if (memrefType.getRank() < 2) |
132 | return 0; |
133 | int64_t offset = 0; |
134 | SmallVector<int64_t, 2> strides; |
135 | if (failed(getStridesAndOffset(memrefType, strides, offset)) || |
136 | strides.back() != 1) |
137 | return std::nullopt; |
138 | int64_t stride = strides[strides.size() - 2]; |
139 | if (stride == ShapedType::kDynamic) |
140 | return std::nullopt; |
141 | return stride; |
142 | } |
143 | |
144 | // Return true if the transfer op can be converted to a MMA matrix load. |
145 | static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) { |
146 | if (readOp.getMask() || readOp.hasOutOfBoundsDim() || |
147 | readOp.getVectorType().getRank() != 2) |
148 | return false; |
149 | if (!getStaticallyKnownRowStride(readOp.getShapedType())) |
150 | return false; |
151 | |
152 | // Only allow integer types if the signedness can be inferred. |
153 | if (readOp.getVectorType().getElementType().isInteger(8)) |
154 | if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) && |
155 | !isa<arith::ExtUIOp>(*readOp->user_begin()))) |
156 | return false; |
157 | |
158 | AffineMap map = readOp.getPermutationMap(); |
159 | MLIRContext *ctx = readOp.getContext(); |
  AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
  AffineExpr zero = getAffineConstantExpr(0, ctx);
  auto broadcastInnerDim =
      AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
  return map.isMinorIdentity() || map == broadcastInnerDim ||
         isTransposeMatrixLoadMap(map);
166 | } |
167 | |
168 | // Return true if the transfer op can be converted to a MMA matrix store. |
169 | static bool |
170 | transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) { |
171 | // TODO: support 0-d corner case. |
172 | if (writeOp.getTransferRank() == 0) |
173 | return false; |
174 | |
175 | if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() || |
176 | writeOp.getVectorType().getRank() != 2) |
177 | return false; |
178 | if (!getStaticallyKnownRowStride(writeOp.getShapedType())) |
179 | return false; |
180 | // TODO: Support transpose once it is added to GPU dialect ops. |
181 | if (!writeOp.getPermutationMap().isMinorIdentity()) |
182 | return false; |
183 | return true; |
184 | } |
185 | |
186 | /// Return true if the constant is a splat to a 2D vector so that it can be |
187 | /// converted to a MMA constant matrix op. |
188 | static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) { |
189 | auto vecType = dyn_cast<VectorType>(constantOp.getType()); |
190 | if (!vecType || vecType.getRank() != 2) |
191 | return false; |
192 | return isa<SplatElementsAttr>(constantOp.getValue()); |
193 | } |
194 | |
195 | /// Return true if this is a broadcast from scalar to a 2D vector. |
196 | static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) { |
197 | return broadcastOp.getResultVectorType().getRank() == 2; |
198 | } |
199 | |
200 | /// Return true if this integer extend op can be folded into a contract op. |
201 | template <typename ExtOpTy> |
202 | static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) { |
203 | if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp())) |
204 | return false; |
205 | return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>); |
206 | } |
207 | |
208 | static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; } |
209 | |
210 | /// Return the MMA elementwise enum associated with `op` if it is supported. |
211 | /// Return `std::nullopt` otherwise. |
212 | static std::optional<gpu::MMAElementwiseOp> |
213 | convertElementwiseOpToMMA(Operation *op) { |
214 | if (isa<arith::AddFOp>(op)) |
215 | return gpu::MMAElementwiseOp::ADDF; |
216 | if (isa<arith::MulFOp>(op)) |
217 | return gpu::MMAElementwiseOp::MULF; |
218 | if (isa<arith::SubFOp>(op)) |
219 | return gpu::MMAElementwiseOp::SUBF; |
220 | if (isa<arith::MaximumFOp>(op)) |
221 | return gpu::MMAElementwiseOp::MAXF; |
222 | if (isa<arith::MinimumFOp>(op)) |
223 | return gpu::MMAElementwiseOp::MINF; |
224 | if (isa<arith::DivFOp>(op)) |
225 | return gpu::MMAElementwiseOp::DIVF; |
226 | if (isa<arith::AddIOp>(op)) |
227 | return gpu::MMAElementwiseOp::ADDI; |
228 | if (isa<arith::MulIOp>(op)) |
229 | return gpu::MMAElementwiseOp::MULI; |
230 | if (isa<arith::SubIOp>(op)) |
231 | return gpu::MMAElementwiseOp::SUBI; |
232 | if (isa<arith::DivSIOp>(op)) |
233 | return gpu::MMAElementwiseOp::DIVS; |
234 | if (isa<arith::DivUIOp>(op)) |
235 | return gpu::MMAElementwiseOp::DIVU; |
236 | if (isa<arith::NegFOp>(op)) |
237 | return gpu::MMAElementwiseOp::NEGATEF; |
238 | if (isa<arith::ExtFOp>(op)) |
239 | return gpu::MMAElementwiseOp::EXTF; |
240 | return std::nullopt; |
241 | } |
242 | |
243 | /// Return true if the op is supported as elementwise op on MMAMatrix type. |
244 | static bool elementwiseSupportsMMAMatrixType(Operation *op) { |
245 | return convertElementwiseOpToMMA(op).has_value(); |
246 | } |
247 | |
248 | /// Returns true if the extract strided slice op is supported with `mma.sync` |
249 | /// path. |
250 | static bool |
extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
252 | |
253 | FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = |
254 | nvgpu::getWarpMatrixInfo(op); |
255 | if (failed(warpMatrixInfo)) |
256 | return false; |
257 | |
258 | FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op); |
259 | if (failed(contractOp)) |
260 | return false; |
261 | |
262 | // Handle vector.extract_strided_slice on registers containing |
263 | // matrixB and matrixC operands. vector.extract_strided_slice op |
264 | // is not supported on registers containing matrixA operands. |
265 | if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B) |
266 | return (cast<VectorType>(op->getResult(0).getType()) == |
267 | cast<VectorType>((*contractOp).getRhs().getType())); |
268 | if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C) |
269 | return (cast<VectorType>(op->getResult(0).getType()) == |
270 | cast<VectorType>((*contractOp).getAcc().getType())); |
271 | |
272 | return false; |
273 | } |
274 | |
275 | static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) { |
276 | if (isa<scf::ForOp, scf::YieldOp>(op)) |
277 | return true; |
278 | if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) |
279 | return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead) |
280 | : transferReadSupportsMMAMatrixType(transferRead); |
281 | if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) |
282 | return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferWrite) |
283 | : transferWriteSupportsMMAMatrixType(transferWrite); |
284 | if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op)) |
285 | return useNvGpu && |
286 | extractStridedSliceSupportsMMAMatrixType(extractStridedSlice); |
287 | if (auto contract = dyn_cast<vector::ContractionOp>(op)) |
288 | return contractSupportsMMAMatrixType(contract, useNvGpu); |
289 | if (auto constant = dyn_cast<arith::ConstantOp>(op)) |
290 | return constantSupportsMMAMatrixType(constant); |
291 | if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) |
292 | return broadcastSupportsMMAMatrixType(broadcast); |
293 | if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op)) |
294 | return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend); |
295 | if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op)) |
296 | return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend); |
297 | if (auto fpExtend = dyn_cast<arith::ExtFOp>(op)) |
298 | return fpExtendSupportsMMAMatrixType(fpExtend); |
299 | return elementwiseSupportsMMAMatrixType(op); |
300 | } |
301 | |
302 | /// Return an unsorted slice handling scf.for region differently than |
303 | /// `getSlice`. In scf.for we only want to include as part of the slice elements |
304 | /// that are part of the use/def chain. |
305 | static SetVector<Operation *> |
306 | getSliceContract(Operation *op, |
307 | const BackwardSliceOptions &backwardSliceOptions, |
308 | const ForwardSliceOptions &forwardSliceOptions) { |
309 | SetVector<Operation *> slice; |
  slice.insert(op);
311 | unsigned currentIndex = 0; |
312 | SetVector<Operation *> backwardSlice; |
313 | SetVector<Operation *> forwardSlice; |
314 | while (currentIndex != slice.size()) { |
315 | auto *currentOp = (slice)[currentIndex]; |
316 | // Compute and insert the backwardSlice starting from currentOp. |
317 | backwardSlice.clear(); |
    getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
    slice.insert(backwardSlice.begin(), backwardSlice.end());
320 | |
321 | // Compute and insert the forwardSlice starting from currentOp. |
322 | forwardSlice.clear(); |
    // Special case for ForOp: we don't want to include the whole region, only
    // the values that use the region arguments.
    // TODO: We should refine this to only care about the region arguments
    // being converted to matrix type.
327 | if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) { |
328 | for (Value forOpResult : forOp.getResults()) |
329 | getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions); |
330 | for (BlockArgument &arg : forOp.getRegionIterArgs()) |
331 | getForwardSlice(arg, &forwardSlice, forwardSliceOptions); |
332 | } else { |
      getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
    }
    slice.insert(forwardSlice.begin(), forwardSlice.end());
336 | ++currentIndex; |
337 | } |
338 | return slice; |
339 | } |
340 | |
341 | // Analyze slice of operations based on convert op to figure out if the whole |
342 | // slice can be converted to MMA operations. |
343 | static SetVector<Operation *> getOpToConvert(mlir::Operation *op, |
344 | bool useNvGpu) { |
345 | auto hasVectorDest = [](Operation *op) { |
    return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
347 | }; |
348 | BackwardSliceOptions backwardSliceOptions; |
349 | backwardSliceOptions.filter = hasVectorDest; |
350 | |
351 | auto hasVectorSrc = [](Operation *op) { |
    return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
353 | }; |
354 | ForwardSliceOptions forwardSliceOptions; |
355 | forwardSliceOptions.filter = hasVectorSrc; |
356 | |
357 | SetVector<Operation *> opToConvert; |
358 | op->walk([&](vector::ContractionOp contract) { |
    if (opToConvert.contains(contract.getOperation()))
360 | return; |
361 | SetVector<Operation *> dependentOps = |
362 | getSliceContract(contract, backwardSliceOptions, forwardSliceOptions); |
    // If any instruction cannot use the MMA matrix type, drop the whole
    // chain. MMA matrices are stored in an opaque type, so they cannot be
    // used by all operations.
    if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
          if (!supportsMMaMatrixType(op, useNvGpu)) {
            LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
369 | return true; |
370 | } |
371 | return false; |
372 | })) |
373 | return; |
374 | |
    opToConvert.insert(dependentOps.begin(), dependentOps.end());
376 | }); |
377 | // Sort the operations so that we can convert them in topological order. |
  return topologicalSort(opToConvert);
379 | } |
380 | |
381 | namespace { |
382 | // Transform contract into (m, k)x(k, n)x(m, n) form so that it can be converted |
383 | // to MMA matmul. |
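// For example, a contraction with indexing maps {(k, m), (k, n), (m, n)} is
// rewritten by transposing the LHS so the maps become the canonical row-major
// matmul form {(m, k), (k, n), (m, n)}.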
384 | struct PrepareContractToGPUMMA |
385 | : public OpRewritePattern<vector::ContractionOp> { |
386 | using OpRewritePattern<vector::ContractionOp>::OpRewritePattern; |
387 | |
388 | LogicalResult matchAndRewrite(vector::ContractionOp op, |
389 | PatternRewriter &rewriter) const override { |
390 | Location loc = op.getLoc(); |
391 | Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc(); |
392 | |
393 | // Set up the parallel/reduction structure in right form. |
394 | using MapList = ArrayRef<ArrayRef<AffineExpr>>; |
395 | auto infer = [&](MapList m) { |
396 | return AffineMap::inferFromExprList(m, op.getContext()); |
397 | }; |
398 | AffineExpr m, n, k; |
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
    if (!(vector::isParallelIterator(iteratorTypes[0]) &&
          vector::isParallelIterator(iteratorTypes[1]) &&
          vector::isReductionIterator(iteratorTypes[2])))
      return rewriter.notifyMatchFailure(op, "not a gemm contraction");
407 | // |
408 | // Two outer parallel, one inner reduction (matmat flavor). |
409 | // |
410 | // This is the classical row-major matmul, nothing to do. |
411 | if (maps == infer({{m, k}, {k, n}, {m, n}})) |
412 | return rewriter.notifyMatchFailure(op, "contraction already prepared" ); |
413 | if (maps == infer({{m, k}, {n, k}, {m, n}})) { |
414 | rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); |
415 | } else if (maps == infer({{k, m}, {k, n}, {m, n}})) { |
416 | lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); |
417 | } else if (maps == infer({{k, m}, {n, k}, {m, n}})) { |
418 | rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); |
419 | lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); |
420 | } else if (maps == infer({{m, k}, {k, n}, {n, m}})) { |
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
432 | } else { |
433 | // TODO: llvm_unreachable ? |
434 | return rewriter.notifyMatchFailure(op, "unexpected contraction case" ); |
435 | } |
436 | rewriter.replaceOpWithNewOp<vector::ContractionOp>( |
437 | op, lhs, rhs, res, |
        rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
439 | op.getIteratorTypes()); |
440 | return success(); |
441 | } |
442 | }; |
443 | |
444 | // Fold transpose op into the transfer read op. Nvgpu mma.sync op only supports |
445 | // row-, column-, and row-major layout for matrixA, matrixB, and matrixC, |
446 | // respectively. We can fold the transpose operation when loading the data from |
447 | // Shared Memory to registers. |
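// For example, a vector.transpose with permutation [1, 0] of a
// vector.transfer_read whose permutation map is (d0, d1) -> (d1, d0) folds
// into a single transfer_read with the identity map (d0, d1) -> (d0, d1).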
448 | struct CombineTransferReadOpTranspose final |
449 | : public OpRewritePattern<vector::TransposeOp> { |
450 | using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; |
451 | |
452 | LogicalResult matchAndRewrite(vector::TransposeOp op, |
453 | PatternRewriter &rewriter) const override { |
    // Look through integer and floating point extend ops.
455 | Value source = op.getVector(); |
456 | Type resultType = op.getType(); |
457 | Operation *extOp; |
458 | if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) || |
459 | (extOp = source.getDefiningOp<arith::ExtUIOp>()) || |
460 | (extOp = source.getDefiningOp<arith::ExtFOp>())) { |
      source = extOp->getOperand(0);
462 | resultType = |
463 | VectorType::get(cast<VectorType>(resultType).getShape(), |
464 | cast<VectorType>(source.getType()).getElementType()); |
465 | } |
466 | |
467 | auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>(); |
468 | if (!transferReadOp) |
469 | return rewriter.notifyMatchFailure(op, "no transfer read" ); |
470 | |
471 | // TODO: support 0-d corner case. |
472 | if (transferReadOp.getTransferRank() == 0) |
473 | return rewriter.notifyMatchFailure(op, "0-D transfer read" ); |
474 | |
475 | if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim()) |
476 | return rewriter.notifyMatchFailure(op, "not inbounds transfer read" ); |
477 | |
478 | AffineMap permutationMap = |
479 | AffineMap::getPermutationMap(op.getPermutation(), op.getContext()); |
480 | AffineMap newMap = |
481 | permutationMap.compose(transferReadOp.getPermutationMap()); |
482 | |
483 | auto loc = op.getLoc(); |
484 | Value result = |
485 | rewriter |
486 | .create<vector::TransferReadOp>( |
487 | loc, resultType, transferReadOp.getSource(), |
488 | transferReadOp.getIndices(), AffineMapAttr::get(newMap), |
489 | transferReadOp.getPadding(), transferReadOp.getMask(), |
490 | transferReadOp.getInBoundsAttr()) |
491 | .getResult(); |
492 | |
    // Fuse through the extend op.
494 | if (extOp) { |
495 | if (isa<arith::ExtSIOp>(extOp)) |
496 | result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result) |
497 | .getResult(); |
498 | else if (isa<arith::ExtUIOp>(extOp)) |
499 | result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result) |
500 | .getResult(); |
501 | else |
502 | result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result) |
503 | .getResult(); |
504 | } |
505 | |
506 | rewriter.replaceOp(op, result); |
507 | return success(); |
508 | } |
509 | }; |
510 | |
511 | } // namespace |
512 | |
513 | // MMA types have different layout based on how they are used in matmul ops. |
514 | // Figure the right layout to use by looking at op uses. |
// TODO: Change the GPU dialect to abstract the layout at this level and
516 | // only care about it during lowering to NVVM. |
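// For example, a value feeding the LHS of a vector.contract gets the "AOp"
// layout, a value feeding the RHS gets "BOp", and anything else (accumulators
// or values with no contract user) defaults to "COp".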
517 | static const char *inferFragType(Operation *op) { |
518 | for (Operation *users : op->getUsers()) { |
519 | auto contract = dyn_cast<vector::ContractionOp>(users); |
520 | if (!contract) |
521 | continue; |
522 | assert(op->getNumResults() == 1); |
    if (contract.getLhs() == op->getResult(0))
      return "AOp";
    if (contract.getRhs() == op->getResult(0))
      return "BOp";
  }
  return "COp";
529 | } |
530 | |
531 | static LogicalResult |
532 | convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, |
533 | llvm::DenseMap<Value, Value> &valueMapping) { |
534 | OpBuilder::InsertionGuard g(rewriter); |
535 | rewriter.setInsertionPoint(op); |
536 | |
537 | assert(op.getTransferRank() > 0 && "unexpected 0-d transfer" ); |
538 | assert(transferReadSupportsMMAMatrixType(op) && |
539 | "expected convertible operation" ); |
540 | |
541 | std::optional<int64_t> stride = |
542 | getStaticallyKnownRowStride(op.getShapedType()); |
543 | if (!stride.has_value()) { |
544 | LLVM_DEBUG(DBGS() << "no stride\n" ); |
545 | return rewriter.notifyMatchFailure(op, "no stride" ); |
546 | } |
547 | |
548 | AffineMap map = op.getPermutationMap(); |
  bool isTranspose = isTransposeMatrixLoadMap(map);
550 | |
551 | // Handle broadcast by setting the stride to 0. |
552 | if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) { |
553 | assert(cstExpr.getValue() == 0); |
554 | stride = 0; |
555 | } |
556 | |
557 | Value mappingResult = op.getResult(); |
558 | auto elType = op.getVectorType().getElementType(); |
559 | const char *fragType = inferFragType(op); |
560 | if (op->hasOneUse()) { |
561 | auto *user = *op->user_begin(); |
562 | // Infer the signedness of the mma type from the integer extend. |
563 | bool isSignedExtend = isa<arith::ExtSIOp>(user); |
564 | if (isSignedExtend || isa<arith::ExtUIOp>(user)) { |
565 | elType = IntegerType::get( |
566 | op.getContext(), cast<IntegerType>(elType).getWidth(), |
567 | isSignedExtend ? IntegerType::Signed : IntegerType::Unsigned); |
568 | mappingResult = user->getResult(0); |
569 | fragType = inferFragType(user); |
570 | } |
571 | } |
572 | gpu::MMAMatrixType type = |
      gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
574 | Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>( |
575 | op.getLoc(), type, op.getSource(), op.getIndices(), |
576 | rewriter.getIndexAttr(*stride), |
577 | isTranspose ? rewriter.getUnitAttr() : UnitAttr()); |
578 | valueMapping[mappingResult] = load; |
579 | |
580 | LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n" ); |
581 | return success(); |
582 | } |
583 | |
584 | static LogicalResult |
585 | convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, |
586 | llvm::DenseMap<Value, Value> &valueMapping) { |
587 | OpBuilder::InsertionGuard g(rewriter); |
588 | rewriter.setInsertionPoint(op); |
589 | |
590 | assert(transferWriteSupportsMMAMatrixType(op)); |
591 | std::optional<int64_t> stride = |
592 | getStaticallyKnownRowStride(op.getShapedType()); |
593 | if (!stride.has_value()) { |
594 | LLVM_DEBUG(DBGS() << "no stride\n" ); |
595 | return rewriter.notifyMatchFailure(op, "no stride" ); |
596 | } |
597 | |
598 | auto it = valueMapping.find(op.getVector()); |
599 | if (it == valueMapping.end()) { |
600 | LLVM_DEBUG(DBGS() << "no mapping\n" ); |
601 | return rewriter.notifyMatchFailure(op, "no mapping" ); |
602 | } |
603 | |
604 | Value matrix = it->second; |
605 | auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>( |
606 | op.getLoc(), matrix, op.getSource(), op.getIndices(), |
607 | rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr()); |
608 | (void)store; |
609 | |
  LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");

  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  rewriter.eraseOp(op);
614 | return success(); |
615 | } |
616 | |
617 | /// Returns the vector type which represents a matrix fragment. |
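/// For example, a fragment with 2 registers of 2 f16 elements each maps to
/// vector<2x2xf16>; a register type that is itself a vector contributes only
/// its element type.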
618 | static VectorType |
619 | getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo ®Info) { |
620 | SmallVector<int64_t> shape{regInfo.numRegistersPerFragment, |
621 | regInfo.elementsPerRegister}; |
622 | Type elType = regInfo.registerLLVMType; |
623 | if (auto vecType = dyn_cast<VectorType>(elType)) |
624 | elType = vecType.getElementType(); |
625 | return VectorType::get(shape, elType); |
626 | } |
627 | |
628 | /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op. |
629 | static LogicalResult |
630 | convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op, |
631 | llvm::DenseMap<Value, Value> &valueMapping) { |
632 | OpBuilder::InsertionGuard g(rewriter); |
633 | rewriter.setInsertionPoint(op); |
634 | |
635 | FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = |
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo)) {
    LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
640 | } |
641 | |
642 | FailureOr<nvgpu::FragmentElementInfo> regInfo = |
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
    return rewriter.notifyMatchFailure(op, "not mma sync reg info");
647 | } |
648 | |
649 | VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); |
650 | auto dense = dyn_cast<SplatElementsAttr>(op.getValue()); |
651 | if (!dense) { |
652 | LLVM_DEBUG(DBGS() << "not a splat\n" ); |
653 | return rewriter.notifyMatchFailure(op, "not a splat" ); |
654 | } |
655 | |
656 | Value result = rewriter.create<arith::ConstantOp>( |
657 | op.getLoc(), vectorType, |
658 | DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>())); |
659 | valueMapping[op.getResult()] = result; |
660 | return success(); |
661 | } |
662 | |
/// Check if the loaded matrix operand requires a transpose.
/// Transposed map examples:
///   Example 1: (..., d0, d1) -> (d1 * 1, d0 * 2)
///   Example 2: (d0, d1, d2, d3) -> (d3, d2)
/// The code below checks if the 2D output is transposed using a generalized
/// version: (d0, d1, dn, ..., dm, ...) -> (dm, dn)
/// Returns true if m > n, false otherwise.
670 | static FailureOr<bool> isTransposed(vector::TransferReadOp op) { |
671 | mlir::AffineMap map = op.getPermutationMap(); |
672 | |
673 | if (map.getNumResults() != 2) { |
674 | LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` " |
675 | "is not a 2d operand\n" ); |
676 | return failure(); |
677 | } |
678 | |
679 | // Output 2D matrix dimensions in the order of d0, d1. |
  mlir::AffineExpr dM = map.getResult(0);
  mlir::AffineExpr dN = map.getResult(1);
682 | |
683 | // Find the position of these expressions in the input. |
  auto exprM = dyn_cast<AffineDimExpr>(dM);
  auto exprN = dyn_cast<AffineDimExpr>(dN);
686 | |
687 | if (!exprM || !exprN) { |
688 | LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim " |
689 | "expressions, then transpose cannot be determined.\n" ); |
690 | return failure(); |
691 | } |
692 | |
693 | return exprM.getPosition() > exprN.getPosition(); |
694 | } |
695 | |
696 | static LogicalResult |
createLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
698 | llvm::DenseMap<Value, Value> &valueMapping) { |
699 | OpBuilder::InsertionGuard g(rewriter); |
700 | rewriter.setInsertionPoint(op); |
701 | Location loc = op->getLoc(); |
702 | |
703 | FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = |
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo)) {
    LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
708 | } |
709 | |
710 | FailureOr<nvgpu::FragmentElementInfo> regInfo = |
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
    return rewriter.notifyMatchFailure(op, "not mma sync reg info");
715 | } |
716 | |
717 | FailureOr<bool> transpose = isTransposed(op); |
  if (failed(transpose)) {
    LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
    return rewriter.notifyMatchFailure(
        op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
722 | } |
723 | |
724 | FailureOr<nvgpu::LdMatrixParams> params = |
725 | nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose); |
726 | |
727 | if (failed(params)) { |
728 | LLVM_DEBUG( |
729 | DBGS() |
730 | << "failed to convert vector.transfer_read to ldmatrix. " |
731 | << "Op should likely not be converted to a nvgpu.ldmatrix call.\n" ); |
732 | return rewriter.notifyMatchFailure( |
733 | op, "failed to convert vector.transfer_read to ldmatrix; this op " |
734 | "likely should not be converted to a nvgpu.ldmatrix call." ); |
735 | } |
736 | |
737 | // Adjust the load offset. |
738 | auto laneId = rewriter.create<gpu::LaneIdOp>(loc); |
739 | FailureOr<AffineMap> offsets = |
      nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
  if (failed(offsets)) {
    LLVM_DEBUG(DBGS() << "no offsets\n");
    return rewriter.notifyMatchFailure(op, "no offsets");
744 | } |
745 | |
746 | VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); |
747 | |
748 | SmallVector<Value, 4> indices; |
749 | getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId}, |
750 | indices); |
751 | |
752 | nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>( |
753 | loc, vectorType, op.getSource(), indices, *transpose, params->numTiles); |
754 | valueMapping[op] = newOp->getResult(0); |
755 | return success(); |
756 | } |
757 | |
758 | static LogicalResult |
759 | createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, |
760 | llvm::DenseMap<Value, Value> &valueMapping) { |
761 | OpBuilder::InsertionGuard g(rewriter); |
762 | rewriter.setInsertionPoint(op); |
763 | |
764 | Location loc = op.getLoc(); |
765 | FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = |
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
772 | return rewriter.notifyMatchFailure( |
773 | op, "Failed to deduce register fragment type during " |
774 | "conversion to distributed non-ldmatrix compatible load" ); |
775 | } |
776 | |
777 | Value laneId = rewriter.create<gpu::LaneIdOp>(loc); |
778 | SmallVector<Value, 4> elements; |
779 | |
780 | // This is the individual element type. |
781 | Type loadedElType = regInfo->registerLLVMType; |
782 | VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); |
783 | |
784 | Value fill = rewriter.create<arith::ConstantOp>( |
785 | op.getLoc(), vectorType.getElementType(), |
786 | rewriter.getZeroAttr(vectorType.getElementType())); |
787 | Value result = |
788 | rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType); |
789 | |
790 | bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity(); |
791 | |
792 | // If we are not transposing, then we can use vectorized loads. Otherwise, we |
793 | // must load each element individually. |
794 | if (!isTransposeLoad) { |
    if (!isa<VectorType>(loadedElType)) {
796 | loadedElType = VectorType::get({1}, loadedElType); |
797 | } |
798 | |
799 | for (int i = 0; i < vectorType.getShape()[0]; i++) { |
800 | FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord( |
          rewriter, op.getLoc(), *warpMatrixInfo);
      if (failed(coords))
        return rewriter.notifyMatchFailure(op, "no coords");
804 | |
805 | Value logicalValueId = rewriter.create<arith::ConstantOp>( |
806 | loc, rewriter.getIndexType(), |
807 | rewriter.getIndexAttr(i * regInfo->elementsPerRegister)); |
808 | SmallVector<Value, 4> newIndices; |
809 | getXferIndices<vector::TransferReadOp>( |
810 | rewriter, op, *coords, {laneId, logicalValueId}, newIndices); |
811 | |
812 | Value el = rewriter.create<vector::LoadOp>(loc, loadedElType, |
813 | op.getSource(), newIndices); |
814 | result = rewriter.create<vector::InsertOp>(loc, el, result, i); |
815 | } |
816 | } else { |
817 | if (auto vecType = dyn_cast<VectorType>(loadedElType)) { |
818 | loadedElType = vecType.getElementType(); |
819 | } |
820 | for (int i = 0; i < vectorType.getShape()[0]; i++) { |
821 | for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1]; |
822 | innerIdx++) { |
823 | |
824 | Value logicalValueId = rewriter.create<arith::ConstantOp>( |
825 | loc, rewriter.getIndexType(), |
826 | rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx)); |
827 | FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord( |
            rewriter, op.getLoc(), *warpMatrixInfo);
        if (failed(coords))
          return rewriter.notifyMatchFailure(op, "no coords");
831 | |
832 | SmallVector<Value, 4> newIndices; |
833 | getXferIndices<vector::TransferReadOp>( |
834 | rewriter, op, *coords, {laneId, logicalValueId}, newIndices); |
835 | Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType, |
836 | op.getSource(), newIndices); |
837 | result = rewriter.create<vector::InsertOp>( |
838 | op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx}); |
839 | } |
840 | } |
841 | } |
842 | |
843 | valueMapping[op.getResult()] = result; |
844 | return success(); |
845 | } |
846 | |
847 | /// Return true if this is a shared memory memref type. |
848 | static bool isSharedMemory(MemRefType type) { |
849 | auto addressSpace = |
850 | dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace()); |
851 | return addressSpace && |
852 | addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace(); |
853 | } |
854 | |
855 | /// Converts a `vector.transfer_read` operation directly to either a |
856 | /// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be |
857 | /// used when converting to `nvgpu.mma.sync` operations. |
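/// For example, a non-transposing read from workgroup (shared) memory whose
/// inferred tile width is 128 bits is a candidate for nvgpu.ldmatrix; reads
/// that do not meet the ldmatrix constraints fall back to per-lane loads.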
858 | static LogicalResult |
859 | convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op, |
860 | llvm::DenseMap<Value, Value> &valueMapping) { |
861 | OpBuilder::InsertionGuard g(rewriter); |
862 | rewriter.setInsertionPoint(op); |
863 | |
864 | FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = |
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
868 | |
869 | bool isLdMatrixCompatible = |
870 | isSharedMemory(cast<MemRefType>(op.getSource().getType())) && |
      nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
872 | |
873 | VectorType vecTy = op.getVectorType(); |
874 | int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth(); |
875 | |
876 | // When we are transposing the B operand, ldmatrix will only work if we have |
877 | // at least 8 rows to read and the width to read for the transpose is 128 |
878 | // bits. |
879 | if (!op.getPermutationMap().isMinorIdentity() && |
880 | (bitWidth != 16 || vecTy.getDimSize(1) < 8 || |
881 | vecTy.getDimSize(0) * bitWidth < 128)) |
882 | isLdMatrixCompatible = false; |
883 | |
884 | if (!isLdMatrixCompatible) |
885 | return createNonLdMatrixLoads(rewriter, op, valueMapping); |
886 | |
  return createLdMatrixCompatibleLoads(rewriter, op, valueMapping);
888 | } |
889 | |
890 | static LogicalResult |
891 | convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, |
892 | llvm::DenseMap<Value, Value> &valueMapping) { |
893 | OpBuilder::InsertionGuard g(rewriter); |
894 | rewriter.setInsertionPoint(op); |
895 | |
896 | Location loc = op->getLoc(); |
897 | auto it = valueMapping.find(op.getVector()); |
898 | if (it == valueMapping.end()) |
899 | return rewriter.notifyMatchFailure(op, "no mapping" ); |
900 | Value matrix = it->second; |
901 | |
902 | FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = |
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo))
    return rewriter.notifyMatchFailure(op, "not mma sync reg info");
910 | |
911 | VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); |
912 | Value laneId = rewriter.create<gpu::LaneIdOp>(loc); |
913 | |
914 | for (unsigned i = 0; i < vectorType.getShape()[0]; i++) { |
915 | Value logicalValueId = rewriter.create<arith::ConstantOp>( |
916 | loc, rewriter.getIndexType(), |
917 | rewriter.getIndexAttr(i * regInfo->elementsPerRegister)); |
918 | FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord( |
        rewriter, op.getLoc(), *warpMatrixInfo);
    if (failed(coords))
      return rewriter.notifyMatchFailure(op, "no coords");
922 | |
923 | Value el = |
924 | rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i}); |
925 | SmallVector<Value, 4> newIndices; |
926 | getXferIndices<vector::TransferWriteOp>( |
927 | rewriter, op, *coords, {laneId, logicalValueId}, newIndices); |
928 | rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices); |
929 | } |
930 | |
  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  rewriter.eraseOp(op);
933 | return success(); |
934 | } |
935 | |
936 | static void populateFromInt64AttrArray(ArrayAttr arrayAttr, |
937 | SmallVectorImpl<int64_t> &results) { |
938 | for (auto attr : arrayAttr) |
939 | results.push_back(cast<IntegerAttr>(attr).getInt()); |
940 | } |
941 | |
942 | static LogicalResult |
convertExtractStridedSlice(RewriterBase &rewriter,
944 | vector::ExtractStridedSliceOp op, |
945 | llvm::DenseMap<Value, Value> &valueMapping) { |
946 | OpBuilder::InsertionGuard g(rewriter); |
947 | rewriter.setInsertionPoint(op); |
948 | |
949 | Location loc = op->getLoc(); |
950 | |
951 | FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = |
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");

  FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(mmaSyncFragmentInfo))
    return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
960 | |
  // Find the vector.transfer_read whose result vector is being sliced.
  auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!transferReadOp)
    return rewriter.notifyMatchFailure(op, "no transfer read");
965 | |
  warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");

  FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(ldFragmentInfo))
    return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");
974 | |
975 | assert( |
976 | (mmaSyncFragmentInfo->elementsPerRegister == |
977 | ldFragmentInfo->elementsPerRegister) && |
978 | "Number of elements per register should be same for load and mma.sync" ); |
979 | |
980 | // Create vector.extract_strided_slice op for thread-owned fragments. |
981 | std::array<int64_t, 2> strides = {1, |
982 | 1}; // stride for extract slice is always 1. |
983 | std::array<int64_t, 2> sliceShape = { |
984 | mmaSyncFragmentInfo->numRegistersPerFragment, |
985 | mmaSyncFragmentInfo->elementsPerRegister}; |
986 | auto it = valueMapping.find(transferReadOp); |
987 | if (it == valueMapping.end()) |
988 | return rewriter.notifyMatchFailure(op, "no mapping" ); |
989 | auto sourceVector = it->second; |
990 | |
  // Offsets and sizes at warp-level of ownership.
992 | SmallVector<int64_t> offsets; |
993 | populateFromInt64AttrArray(op.getOffsets(), offsets); |
994 | |
995 | SmallVector<int64_t> sizes; |
996 | populateFromInt64AttrArray(op.getSizes(), sizes); |
997 | ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape(); |
998 | |
999 | // Compute offset in vector registers. Note that the mma.sync vector registers |
  // are shaped as numberOfFragments x numberOfRegistersPerFragment. The vector
1001 | // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0]. |
1002 | std::array<int64_t, 2> sliceOffset = {0, 0}; |
1003 | |
1004 | if (offsets[0] && offsets[1]) |
1005 | return op->emitError() << "Slicing fragments in 2D is not supported. " ; |
1006 | if (offsets[0]) |
1007 | sliceOffset[0] = (warpVectorShape[0] / offsets[0]); |
1008 | else if (offsets[1]) |
1009 | sliceOffset[0] = (warpVectorShape[1] / offsets[1]); |
1010 | |
1011 | Value newOp = rewriter.create<vector::ExtractStridedSliceOp>( |
1012 | loc, sourceVector, sliceOffset, sliceShape, strides); |
1013 | |
1014 | valueMapping[op] = newOp; |
1015 | return success(); |
1016 | } |
1017 | |
1018 | static LogicalResult |
1019 | convertContractOp(RewriterBase &rewriter, vector::ContractionOp op, |
1020 | llvm::DenseMap<Value, Value> &valueMapping) { |
1021 | OpBuilder::InsertionGuard g(rewriter); |
1022 | rewriter.setInsertionPoint(op); |
1023 | |
1024 | auto itA = valueMapping.find(op.getLhs()); |
1025 | auto itB = valueMapping.find(op.getRhs()); |
1026 | auto itC = valueMapping.find(op.getAcc()); |
1027 | if (itA == valueMapping.end() || itB == valueMapping.end() || |
1028 | itC == valueMapping.end()) |
1029 | return rewriter.notifyMatchFailure(op, "no mapping" ); |
1030 | Value opA = itA->second, opB = itB->second, opC = itC->second; |
1031 | Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>( |
1032 | op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(), |
1033 | /*b_transpose=*/UnitAttr()); |
1034 | valueMapping[op.getResult()] = matmul; |
1035 | return success(); |
1036 | } |
1037 | |
1038 | static LogicalResult |
1039 | convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op, |
1040 | llvm::DenseMap<Value, Value> &valueMapping) { |
1041 | OpBuilder::InsertionGuard g(rewriter); |
1042 | rewriter.setInsertionPoint(op); |
1043 | |
1044 | auto itA = valueMapping.find(op.getLhs()); |
1045 | auto itB = valueMapping.find(op.getRhs()); |
1046 | auto itC = valueMapping.find(op.getAcc()); |
1047 | if (itA == valueMapping.end() || itB == valueMapping.end() || |
1048 | itC == valueMapping.end()) |
1049 | return rewriter.notifyMatchFailure(op, "no mapping" ); |
1050 | Value opA = itA->second, opB = itB->second, opC = itC->second; |
1051 | int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0]; |
1052 | int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0]; |
1053 | int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1]; |
1054 | Value matmul = rewriter.create<nvgpu::MmaSyncOp>( |
1055 | op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k})); |
1056 | valueMapping[op.getResult()] = matmul; |
1057 | return success(); |
1058 | } |
1059 | |
1060 | /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op. |
1061 | static LogicalResult |
1062 | convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op, |
1063 | llvm::DenseMap<Value, Value> &valueMapping) { |
1064 | OpBuilder::InsertionGuard g(rewriter); |
1065 | rewriter.setInsertionPoint(op); |
1066 | |
1067 | assert(constantSupportsMMAMatrixType(op)); |
1068 | |
1069 | auto splat = |
1070 | cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>(); |
1071 | auto scalarConstant = |
1072 | rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat); |
1073 | const char *fragType = inferFragType(op); |
1074 | auto vecType = cast<VectorType>(op.getType()); |
1075 | gpu::MMAMatrixType type = gpu::MMAMatrixType::get( |
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1077 | auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>( |
1078 | op.getLoc(), type, scalarConstant); |
1079 | valueMapping[op.getResult()] = matrix; |
1080 | return success(); |
1081 | } |
1082 | |
1083 | /// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op. |
1084 | static LogicalResult |
1085 | convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op, |
1086 | llvm::DenseMap<Value, Value> &valueMapping) { |
1087 | OpBuilder::InsertionGuard g(rewriter); |
1088 | rewriter.setInsertionPoint(op); |
1089 | |
1090 | assert(broadcastSupportsMMAMatrixType(op)); |
1091 | |
1092 | const char *fragType = inferFragType(op); |
1093 | auto vecType = op.getResultVectorType(); |
1094 | gpu::MMAMatrixType type = gpu::MMAMatrixType::get( |
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1096 | auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>( |
1097 | op.getLoc(), type, op.getSource()); |
1098 | valueMapping[op.getResult()] = matrix; |
1099 | return success(); |
1100 | } |
1101 | |
1102 | // Replace ForOp with a new ForOp with extra operands. The YieldOp is not |
1103 | // updated and needs to be updated separately for the loop to be correct. |
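// For example, rewriting scf.for %i ... iter_args(%a = %init) with one extra
// operand yields scf.for %i ... iter_args(%a = %init, %b = %new); uses of the
// old loop's results are redirected to the leading results of the new loop.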
1104 | static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, |
1105 | scf::ForOp loop, |
1106 | ValueRange newInitArgs) { |
1107 | OpBuilder::InsertionGuard g(rewriter); |
1108 | rewriter.setInsertionPoint(loop); |
1109 | |
1110 | // Create a new loop before the existing one, with the extra operands. |
1111 | rewriter.setInsertionPoint(loop); |
1112 | auto operands = llvm::to_vector<4>(loop.getInitArgs()); |
1113 | llvm::append_range(operands, newInitArgs); |
1114 | scf::ForOp newLoop = rewriter.create<scf::ForOp>( |
1115 | loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), |
1116 | operands); |
  rewriter.eraseBlock(newLoop.getBody());
1118 | |
1119 | newLoop.getRegion().getBlocks().splice( |
1120 | newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks()); |
1121 | for (Value operand : newInitArgs) |
1122 | newLoop.getBody()->addArgument(operand.getType(), operand.getLoc()); |
1123 | |
1124 | for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front( |
1125 | loop.getNumResults()))) |
1126 | rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); |
1127 | |
1128 | LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n" ); |
1129 | LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n" ); |
1130 | LLVM_DEBUG(DBGS() << "erase: " << loop); |
1131 | |
  rewriter.eraseOp(loop);
1133 | return newLoop; |
1134 | } |
1135 | |
1136 | static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, |
1137 | llvm::DenseMap<Value, Value> &valueMapping) { |
1138 | OpBuilder::InsertionGuard g(rewriter); |
1139 | rewriter.setInsertionPoint(op); |
1140 | |
1141 | SmallVector<Value> newOperands; |
1142 | SmallVector<std::pair<size_t, size_t>> argMapping; |
1143 | for (const auto &operand : llvm::enumerate(op.getInitArgs())) { |
1144 | auto it = valueMapping.find(operand.value()); |
1145 | if (it == valueMapping.end()) { |
1146 | LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n" ); |
1147 | continue; |
1148 | } |
1149 | argMapping.push_back(std::make_pair( |
1150 | operand.index(), op.getInitArgs().size() + newOperands.size())); |
1151 | newOperands.push_back(it->second); |
1152 | } |
1153 | |
1154 | scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands); |
1155 | Block &loopBody = *newForOp.getBody(); |
1156 | for (auto mapping : argMapping) { |
1157 | valueMapping[newForOp.getResult(mapping.first)] = |
1158 | newForOp.getResult(mapping.second); |
    valueMapping[loopBody.getArgument(mapping.first +
                                      newForOp.getNumInductionVars())] =
        loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
1162 | } |
1163 | |
1164 | LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n" ); |
1165 | return success(); |
1166 | } |
1167 | |
1168 | static LogicalResult |
1169 | convertYieldOp(RewriterBase &rewriter, scf::YieldOp op, |
1170 | llvm::DenseMap<Value, Value> &valueMapping) { |
1171 | OpBuilder::InsertionGuard g(rewriter); |
1172 | rewriter.setInsertionPoint(op); |
1173 | |
1174 | auto loop = cast<scf::ForOp>(op->getParentOp()); |
1175 | auto yieldOperands = llvm::to_vector<4>(op.getOperands()); |
1176 | for (const auto &operand : llvm::enumerate(op.getOperands())) { |
1177 | auto it = valueMapping.find(operand.value()); |
1178 | if (it == valueMapping.end()) |
1179 | continue; |
1180 | // Replace the yield of old value with the for op argument to make it easier |
1181 | // to remove the dead code. |
1182 | yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()]; |
1183 | yieldOperands.push_back(it->second); |
1184 | } |
1185 | rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands); |
1186 | |
  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  rewriter.eraseOp(op);
1189 | return success(); |
1190 | } |
1191 | |
1192 | /// Convert an elementwise op to the equivalent elementwise op on MMA matrix. |
1193 | static LogicalResult |
1194 | convertElementwiseOp(RewriterBase &rewriter, Operation *op, |
1195 | gpu::MMAElementwiseOp opType, |
1196 | llvm::DenseMap<Value, Value> &valueMapping) { |
1197 | OpBuilder::InsertionGuard g(rewriter); |
1198 | rewriter.setInsertionPoint(op); |
1199 | |
1200 | SmallVector<Value> matrixOperands; |
1201 | for (Value operand : op->getOperands()) { |
    auto it = valueMapping.find(operand);
    if (it == valueMapping.end())
      return rewriter.notifyMatchFailure(op, "no mapping");
    matrixOperands.push_back(it->second);
  }
  auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
1208 | if (opType == gpu::MMAElementwiseOp::EXTF) { |
1209 | // The floating point extension case has a different result type. |
1210 | auto vectorType = cast<VectorType>(op->getResultTypes()[0]); |
    resultType = gpu::MMAMatrixType::get(resultType.getShape(),
                                         vectorType.getElementType(),
                                         resultType.getOperand());
1214 | } |
1215 | |
1216 | Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>( |
1217 | op->getLoc(), resultType, matrixOperands, opType); |
  valueMapping[op->getResult(0)] = newOp;
1219 | return success(); |
1220 | } |
1221 | |
1222 | void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, |
1223 | bool useNvGpu) { |
1224 | if (!useNvGpu) { |
1225 | patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>( |
        patterns.getContext());
1227 | return; |
1228 | } |
1229 | vector::populateVectorContractCanonicalizeMatmulToMMT(patterns); |
  patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
1231 | } |
1232 | |
1233 | LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter, |
1234 | Operation *rootOp) { |
  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
1236 | llvm::DenseMap<Value, Value> valueMapping; |
1237 | |
1238 | auto globalRes = LogicalResult::success(); |
1239 | for (Operation *op : ops) { |
1240 | LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n" ); |
1241 | // Apparently callers do not want to early exit on failure here. |
1242 | auto res = LogicalResult::success(); |
1243 | if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) { |
1244 | res = convertTransferReadOp(rewriter, transferRead, valueMapping); |
1245 | } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) { |
1246 | res = convertTransferWriteOp(rewriter, transferWrite, valueMapping); |
1247 | } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) { |
1248 | res = convertContractOp(rewriter, contractOp, valueMapping); |
1249 | } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) { |
1250 | res = convertConstantOp(rewriter, constantOp, valueMapping); |
1251 | } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) { |
1252 | res = convertBroadcastOp(rewriter, broadcastOp, valueMapping); |
1253 | } else if (auto forOp = dyn_cast<scf::ForOp>(op)) { |
1254 | res = convertForOp(rewriter, forOp, valueMapping); |
1255 | } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) { |
1256 | res = convertYieldOp(rewriter, yieldOp, valueMapping); |
1257 | } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) { |
1258 | res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping); |
1259 | } |
    if (failed(res))
1261 | globalRes = failure(); |
1262 | } |
1263 | return globalRes; |
1264 | } |
1265 | |
1266 | LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter, |
1267 | Operation *rootOp) { |
  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
1269 | llvm::DenseMap<Value, Value> valueMapping; |
1270 | for (Operation *op : ops) { |
1271 | if (llvm::TypeSwitch<Operation *, LogicalResult>(op) |
            .Case([&](vector::TransferReadOp transferReadOp) {
1273 | return convertTransferReadToLoads(rewriter, transferReadOp, |
1274 | valueMapping); |
1275 | }) |
            .Case([&](vector::TransferWriteOp transferWriteOp) {
1277 | return convertTransferWriteToStores(rewriter, transferWriteOp, |
1278 | valueMapping); |
1279 | }) |
            .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
1281 | return convertExtractStridedSlice(rewriter, extractStridedSliceOp, |
1282 | valueMapping); |
1283 | }) |
            .Case([&](vector::ContractionOp contractionOp) {
1285 | return convertContractOpToMmaSync(rewriter, contractionOp, |
1286 | valueMapping); |
1287 | }) |
            .Case([&](scf::ForOp forOp) {
1289 | return convertForOp(rewriter, forOp, valueMapping); |
1290 | }) |
            .Case([&](scf::YieldOp yieldOp) {
1292 | return convertYieldOp(rewriter, yieldOp, valueMapping); |
1293 | }) |
            .Case([&](arith::ConstantOp constOp) {
1295 | return convertConstantOpMmaSync(rewriter, constOp, valueMapping); |
1296 | }) |
            .Default([&](Operation *op) {
1298 | return op->emitError() << "unhandled vector to mma type: " << *op; |
1299 | }) |
1300 | .failed()) { |
1301 | return op->emitOpError() |
1302 | << "failed to convert op during vector-to-nvgpu conversion" ; |
1303 | } |
1304 | } |
1305 | return success(); |
1306 | } |
1307 | |
1308 | namespace { |
1309 | |
1310 | struct ConvertVectorToGPUPass |
1311 | : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> { |
1312 | |
1313 | explicit ConvertVectorToGPUPass(bool useNvGpu_) { |
1314 | useNvGpu.setValue(useNvGpu_); |
1315 | } |
1316 | |
1317 | void runOnOperation() override { |
1318 | RewritePatternSet patterns(&getContext()); |
1319 | populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue()); |
1320 | if (failed( |
1321 | applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) |
1322 | return signalPassFailure(); |
1323 | |
1324 | IRRewriter rewriter(&getContext()); |
1325 | if (useNvGpu) { |
1326 | if (failed( |
1327 | convertVectorToNVVMCompatibleMMASync(rewriter, getOperation()))) |
1328 | return signalPassFailure(); |
1329 | return; |
1330 | } |
1331 | (void)convertVectorToMMAOps(rewriter, getOperation()); |
1332 | } |
1333 | }; |
1334 | |
1335 | } // namespace |
1336 | |
1337 | std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) { |
  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
1339 | } |
1340 | |