//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to GPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"

#include <type_traits>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Region.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"

#define DEBUG_TYPE "vector-to-gpu"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

/// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
/// AffineMap representing offsets to apply to indices, the function fills
/// `indices` with the original indices plus the offsets. The offsets are
/// applied by taking into account the permutation map of the transfer op. If
/// the `offsetMap` has dimension placeholders, those should be provided in
/// `dimValues`.
template <typename TransferOpType>
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
                           AffineMap offsetMap, ArrayRef<Value> dimValues,
                           SmallVector<Value, 4> &indices) {
  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
  Location loc = xferOp.getLoc();
  unsigned offsetsIdx = 0;
  for (auto expr : xferOp.getPermutationMap().getResults()) {
    if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
      Value prevIdx = indices[dim.getPosition()];
      SmallVector<OpFoldResult, 3> dims(dimValues.begin(), dimValues.end());
      dims.push_back(prevIdx);
      AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
      indices[dim.getPosition()] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
      continue;
    }
  }
}
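
// Illustrative example for getXferIndices (a sketch, not part of the original
// file): with dimValues = {%laneId}, offsetMap = (d0) -> (d0 mod 8,
// d0 floordiv 8), and a minor identity permutation map over two dims, the
// original indices [%i, %j] become
// [%i + %laneId mod 8, %j + %laneId floordiv 8].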

// Return true if the contract op can be converted to MMA matmul.
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
                                           bool useNvGpu) {
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [&](MapList m) {
    return AffineMap::inferFromExprList(m, contract.getContext());
  };
  AffineExpr m, n, k;
  bindDims(contract.getContext(), m, n, k);
  auto iteratorTypes = contract.getIteratorTypes().getValue();
  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
        vector::isParallelIterator(iteratorTypes[1]) &&
        vector::isReductionIterator(iteratorTypes[2])))
    return false;

  // The contract needs to represent a matmul to be able to convert to
  // MMAMatrix matmul.
  if (!useNvGpu &&
      contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
    return false;
  if (useNvGpu &&
      contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
    return false;

  return true;
}
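
// Illustrative contraction accepted when `useNvGpu` is false (a sketch, not
// part of the original file):
//
//   %d = vector.contract {
//          indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
//                           affine_map<(m, n, k) -> (k, n)>,
//                           affine_map<(m, n, k) -> (m, n)>],
//          iterator_types = ["parallel", "parallel", "reduction"],
//          kind = #vector.kind<add>}
//        %a, %b, %c : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>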

// Return true if the given map represents a transposed matrix load,
// i.e. (d0, d1, ...) -> (dn-1, dn-2).
static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
  MLIRContext *ctx = permutationMap.getContext();
  // Local OpBuilder is fine here, we just build attributes.
  OpBuilder b(ctx);
  auto nDim = permutationMap.getNumDims();
  AffineExpr zero = b.getAffineConstantExpr(0);
  if (nDim < 2) {
    // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
    AffineExpr dim0 = b.getAffineDimExpr(0);
    return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
  }

  AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
  AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
  // Support both transposed and transposed+broadcasted cases.
  return permutationMap == AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
         permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
}

// Return the stride for the second-to-last dimension of |type| if it is a
// memref and has a constant stride.
static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
  auto memrefType = dyn_cast<MemRefType>(type);
  if (!memrefType)
    return std::nullopt;
  // If the memref is 0 or 1D the horizontal stride is 0.
  if (memrefType.getRank() < 2)
    return 0;
  int64_t offset = 0;
  SmallVector<int64_t, 2> strides;
  if (failed(getStridesAndOffset(memrefType, strides, offset)) ||
      strides.back() != 1)
    return std::nullopt;
  int64_t stride = strides[strides.size() - 2];
  if (stride == ShapedType::kDynamic)
    return std::nullopt;
  return stride;
}
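
// For example (illustrative, not part of the original file): a
// memref<16x32xf16> with the default identity layout has strides [32, 1], so
// the row stride returned above is 32; a memref whose leading stride is
// dynamic returns std::nullopt.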

// Return true if the transfer op can be converted to an MMA matrix load.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
      readOp.getVectorType().getRank() != 2)
    return false;
  if (!getStaticallyKnownRowStride(readOp.getShapedType()))
    return false;

  // Only allow integer types if the signedness can be inferred.
  if (readOp.getVectorType().getElementType().isInteger(8))
    if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
                                 !isa<arith::ExtUIOp>(*readOp->user_begin())))
      return false;

  AffineMap map = readOp.getPermutationMap();
  MLIRContext *ctx = readOp.getContext();
  AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
  AffineExpr zero = getAffineConstantExpr(0, ctx);
  auto broadcastInnerDim =
      AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
  return map.isMinorIdentity() || map == broadcastInnerDim ||
         isTransposeMatrixLoadMap(map);
}
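
// Illustrative transfer_read that satisfies the check above (a sketch, not
// part of the original file; %mem, %c0 and %pad are assumed values):
//
//   %v = vector.transfer_read %mem[%c0, %c0], %pad {in_bounds = [true, true]}
//          : memref<32x32xf16>, vector<16x16xf16>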

// Return true if the transfer op can be converted to an MMA matrix store.
static bool
transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
  // TODO: support 0-d corner case.
  if (writeOp.getTransferRank() == 0)
    return false;

  if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
      writeOp.getVectorType().getRank() != 2)
    return false;
  if (!getStaticallyKnownRowStride(writeOp.getShapedType()))
    return false;
  // TODO: Support transpose once it is added to GPU dialect ops.
  if (!writeOp.getPermutationMap().isMinorIdentity())
    return false;
  return true;
}

/// Return true if the constant is a splat to a 2D vector so that it can be
/// converted to an MMA constant matrix op.
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
  auto vecType = dyn_cast<VectorType>(constantOp.getType());
  if (!vecType || vecType.getRank() != 2)
    return false;
  return isa<SplatElementsAttr>(constantOp.getValue());
}

/// Return true if this is a broadcast from scalar to a 2D vector.
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
  return broadcastOp.getResultVectorType().getRank() == 2;
}

/// Return true if this integer extend op can be folded into a contract op.
template <typename ExtOpTy>
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
  if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
    return false;
  return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
}

static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }

/// Return the MMA elementwise enum associated with `op` if it is supported.
/// Return `std::nullopt` otherwise.
static std::optional<gpu::MMAElementwiseOp>
convertElementwiseOpToMMA(Operation *op) {
  if (isa<arith::AddFOp>(op))
    return gpu::MMAElementwiseOp::ADDF;
  if (isa<arith::MulFOp>(op))
    return gpu::MMAElementwiseOp::MULF;
  if (isa<arith::SubFOp>(op))
    return gpu::MMAElementwiseOp::SUBF;
  if (isa<arith::MaximumFOp>(op))
    return gpu::MMAElementwiseOp::MAXF;
  if (isa<arith::MinimumFOp>(op))
    return gpu::MMAElementwiseOp::MINF;
  if (isa<arith::DivFOp>(op))
    return gpu::MMAElementwiseOp::DIVF;
  if (isa<arith::AddIOp>(op))
    return gpu::MMAElementwiseOp::ADDI;
  if (isa<arith::MulIOp>(op))
    return gpu::MMAElementwiseOp::MULI;
  if (isa<arith::SubIOp>(op))
    return gpu::MMAElementwiseOp::SUBI;
  if (isa<arith::DivSIOp>(op))
    return gpu::MMAElementwiseOp::DIVS;
  if (isa<arith::DivUIOp>(op))
    return gpu::MMAElementwiseOp::DIVU;
  if (isa<arith::NegFOp>(op))
    return gpu::MMAElementwiseOp::NEGATEF;
  if (isa<arith::ExtFOp>(op))
    return gpu::MMAElementwiseOp::EXTF;
  return std::nullopt;
}

/// Return true if the op is supported as elementwise op on MMAMatrix type.
static bool elementwiseSupportsMMAMatrixType(Operation *op) {
  return convertElementwiseOpToMMA(op).has_value();
}

/// Returns true if the extract strided slice op is supported with `mma.sync`
/// path.
static bool
extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return false;

  FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
  if (failed(contractOp))
    return false;

  // Handle vector.extract_strided_slice on registers containing
  // matrixB and matrixC operands. vector.extract_strided_slice op
  // is not supported on registers containing matrixA operands.
  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
    return (cast<VectorType>(op->getResult(0).getType()) ==
            cast<VectorType>((*contractOp).getRhs().getType()));
  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
    return (cast<VectorType>(op->getResult(0).getType()) ==
            cast<VectorType>((*contractOp).getAcc().getType()));

  return false;
}

static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
  if (isa<scf::ForOp, scf::YieldOp>(op))
    return true;
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
    return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
                    : transferReadSupportsMMAMatrixType(transferRead);
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
    return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferWrite)
                    : transferWriteSupportsMMAMatrixType(transferWrite);
  if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
    return useNvGpu &&
           extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
  if (auto contract = dyn_cast<vector::ContractionOp>(op))
    return contractSupportsMMAMatrixType(contract, useNvGpu);
  if (auto constant = dyn_cast<arith::ConstantOp>(op))
    return constantSupportsMMAMatrixType(constant);
  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
    return broadcastSupportsMMAMatrixType(broadcast);
  if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
    return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
  if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
    return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
  if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
    return fpExtendSupportsMMAMatrixType(fpExtend);
  return elementwiseSupportsMMAMatrixType(op);
}

/// Return an unsorted slice, handling scf.for regions differently than
/// `getSlice`. In scf.for we only want to include as part of the slice the
/// values that are part of the use/def chain.
static SetVector<Operation *>
getSliceContract(Operation *op,
                 const BackwardSliceOptions &backwardSliceOptions,
                 const ForwardSliceOptions &forwardSliceOptions) {
  SetVector<Operation *> slice;
  slice.insert(op);
  unsigned currentIndex = 0;
  SetVector<Operation *> backwardSlice;
  SetVector<Operation *> forwardSlice;
  while (currentIndex != slice.size()) {
    auto *currentOp = (slice)[currentIndex];
    // Compute and insert the backwardSlice starting from currentOp.
    backwardSlice.clear();
    getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
    slice.insert(backwardSlice.begin(), backwardSlice.end());

    // Compute and insert the forwardSlice starting from currentOp.
    forwardSlice.clear();
    // Special case for ForOp: we don't want to include the whole region but
    // only the values using the region arguments.
    // TODO: We should refine this to only care about the region arguments
    // being converted to matrix type.
    if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
      for (Value forOpResult : forOp.getResults())
        getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
      for (BlockArgument &arg : forOp.getRegionIterArgs())
        getForwardSlice(arg, &forwardSlice, forwardSliceOptions);
    } else {
      getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
    }
    slice.insert(forwardSlice.begin(), forwardSlice.end());
    ++currentIndex;
  }
  return slice;
}

// Analyze the slice of operations rooted at each contraction op to figure out
// if the whole slice can be converted to MMA operations.
static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
                                             bool useNvGpu) {
  auto hasVectorDest = [](Operation *op) {
    return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
  };
  BackwardSliceOptions backwardSliceOptions;
  backwardSliceOptions.filter = hasVectorDest;

  auto hasVectorSrc = [](Operation *op) {
    return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
  };
  ForwardSliceOptions forwardSliceOptions;
  forwardSliceOptions.filter = hasVectorSrc;

  SetVector<Operation *> opToConvert;
  op->walk([&](vector::ContractionOp contract) {
    if (opToConvert.contains(contract.getOperation()))
      return;
    SetVector<Operation *> dependentOps =
        getSliceContract(contract, backwardSliceOptions, forwardSliceOptions);
    // If any instruction cannot use the MMA matrix type, drop the whole
    // chain. MMA matrices are stored in an opaque type so they cannot be used
    // by all operations.
    if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
          if (!supportsMMaMatrixType(op, useNvGpu)) {
            LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
            return true;
          }
          return false;
        }))
      return;

    opToConvert.insert(dependentOps.begin(), dependentOps.end());
  });
  // Sort the operations so that we can convert them in topological order.
  return topologicalSort(opToConvert);
}

namespace {
// Transform contract into (m, k)x(k, n)x(m, n) form so that it can be
// converted to MMA matmul.
struct PrepareContractToGPUMMA
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();

    // Set up the parallel/reduction structure in right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [&](MapList m) {
      return AffineMap::inferFromExprList(m, op.getContext());
    };
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
    if (!(vector::isParallelIterator(iteratorTypes[0]) &&
          vector::isParallelIterator(iteratorTypes[1]) &&
          vector::isReductionIterator(iteratorTypes[2])))
      return rewriter.notifyMatchFailure(op, "not a gemm contraction");
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    // This is the classical row-major matmul, nothing to do.
    if (maps == infer({{m, k}, {k, n}, {m, n}}))
      return rewriter.notifyMatchFailure(op, "contraction already prepared");
    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      // TODO: llvm_unreachable ?
      return rewriter.notifyMatchFailure(op, "unexpected contraction case");
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res,
        rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
        op.getIteratorTypes());
    return success();
  }
};
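
// Illustrative effect of PrepareContractToGPUMMA (a sketch, not part of the
// original file): a contraction with indexing maps
//   [(m, n, k) -> (m, k), (m, n, k) -> (n, k), (m, n, k) -> (m, n)]
// gets its RHS routed through a vector.transpose with permutation [1, 0], and
// the rewritten contraction uses the canonical
//   [(m, n, k) -> (m, k), (m, n, k) -> (k, n), (m, n, k) -> (m, n)]
// maps expected by the MMA matmul lowering.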

// Fold transpose op into the transfer read op. The NVGPU mma.sync op only
// supports row-, column-, and row-major layouts for matrixA, matrixB, and
// matrixC, respectively. We can fold the transpose operation when loading the
// data from Shared Memory to registers.
struct CombineTransferReadOpTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    // Look through integer and floating point extend ops.
    Value source = op.getVector();
    Type resultType = op.getType();
    Operation *extOp;
    if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
        (extOp = source.getDefiningOp<arith::ExtUIOp>()) ||
        (extOp = source.getDefiningOp<arith::ExtFOp>())) {
      source = extOp->getOperand(0);
      resultType =
          VectorType::get(cast<VectorType>(resultType).getShape(),
                          cast<VectorType>(source.getType()).getElementType());
    }

    auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
    if (!transferReadOp)
      return rewriter.notifyMatchFailure(op, "no transfer read");

    // TODO: support 0-d corner case.
    if (transferReadOp.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-D transfer read");

    if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(op, "not inbounds transfer read");

    AffineMap permutationMap =
        AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
    AffineMap newMap =
        permutationMap.compose(transferReadOp.getPermutationMap());

    auto loc = op.getLoc();
    Value result =
        rewriter
            .create<vector::TransferReadOp>(
                loc, resultType, transferReadOp.getSource(),
                transferReadOp.getIndices(), AffineMapAttr::get(newMap),
                transferReadOp.getPadding(), transferReadOp.getMask(),
                transferReadOp.getInBoundsAttr())
            .getResult();

    // Fuse through the integer or floating point extend op.
    if (extOp) {
      if (isa<arith::ExtSIOp>(extOp))
        result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
                     .getResult();
      else if (isa<arith::ExtUIOp>(extOp))
        result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
                     .getResult();
      else
        result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result)
                     .getResult();
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
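
// Illustrative before/after for CombineTransferReadOpTranspose (a sketch, not
// part of the original file):
//
//   %r = vector.transfer_read %mem[%i, %j], %pad {in_bounds = [true, true]}
//          : memref<32x32xf16>, vector<16x16xf16>
//   %t = vector.transpose %r, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
//
// becomes a single transfer_read whose permutation map is the composition
// affine_map<(d0, d1) -> (d1, d0)>, with the transpose erased.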

} // namespace

// MMA types have different layout based on how they are used in matmul ops.
// Figure the right layout to use by looking at op uses.
// TODO: Change the GPU dialect to abstract the layout at this level and
// only care about it during lowering to NVVM.
static const char *inferFragType(Operation *op) {
  for (Operation *users : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(users);
    if (!contract)
      continue;
    assert(op->getNumResults() == 1);
    if (contract.getLhs() == op->getResult(0))
      return "AOp";
    if (contract.getRhs() == op->getResult(0))
      return "BOp";
  }
  return "COp";
}

static LogicalResult
convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
                      llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
  assert(transferReadSupportsMMAMatrixType(op) &&
         "expected convertible operation");

  std::optional<int64_t> stride =
      getStaticallyKnownRowStride(op.getShapedType());
  if (!stride.has_value()) {
    LLVM_DEBUG(DBGS() << "no stride\n");
    return rewriter.notifyMatchFailure(op, "no stride");
  }

  AffineMap map = op.getPermutationMap();
  bool isTranspose = isTransposeMatrixLoadMap(map);

  // Handle broadcast by setting the stride to 0.
  if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
    assert(cstExpr.getValue() == 0);
    stride = 0;
  }

  Value mappingResult = op.getResult();
  auto elType = op.getVectorType().getElementType();
  const char *fragType = inferFragType(op);
  if (op->hasOneUse()) {
    auto *user = *op->user_begin();
    // Infer the signedness of the mma type from the integer extend.
    bool isSignedExtend = isa<arith::ExtSIOp>(user);
    if (isSignedExtend || isa<arith::ExtUIOp>(user)) {
      elType = IntegerType::get(
          op.getContext(), cast<IntegerType>(elType).getWidth(),
          isSignedExtend ? IntegerType::Signed : IntegerType::Unsigned);
      mappingResult = user->getResult(0);
      fragType = inferFragType(user);
    }
  }
  gpu::MMAMatrixType type =
      gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
  Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
      op.getLoc(), type, op.getSource(), op.getIndices(),
      rewriter.getIndexAttr(*stride),
      isTranspose ? rewriter.getUnitAttr() : UnitAttr());
  valueMapping[mappingResult] = load;

  LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
  return success();
}
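
// Illustrative output of convertTransferReadOp for a row-major A-operand load
// with a row stride of 32 (a sketch, not part of the original file):
//
//   %m = gpu.subgroup_mma_load_matrix %mem[%i, %j] {leadDimension = 32 : index}
//          : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">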

static LogicalResult
convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
                       llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  assert(transferWriteSupportsMMAMatrixType(op));
  std::optional<int64_t> stride =
      getStaticallyKnownRowStride(op.getShapedType());
  if (!stride.has_value()) {
    LLVM_DEBUG(DBGS() << "no stride\n");
    return rewriter.notifyMatchFailure(op, "no stride");
  }

  auto it = valueMapping.find(op.getVector());
  if (it == valueMapping.end()) {
    LLVM_DEBUG(DBGS() << "no mapping\n");
    return rewriter.notifyMatchFailure(op, "no mapping");
  }

  Value matrix = it->second;
  auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
      op.getLoc(), matrix, op.getSource(), op.getIndices(),
      rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
  (void)store;

  LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");

  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  rewriter.eraseOp(op);
  return success();
}

/// Returns the vector type which represents a matrix fragment.
static VectorType
getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
  SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
                             regInfo.elementsPerRegister};
  Type elType = regInfo.registerLLVMType;
  if (auto vecType = dyn_cast<VectorType>(elType))
    elType = vecType.getElementType();
  return VectorType::get(shape, elType);
}
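
// For instance (illustrative, not part of the original file): if a fragment is
// described as 4 registers holding vector<2xf16> each, the per-thread operand
// type returned above is vector<4x2xf16>.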

/// Convert a 2D splat ConstantOp to a splat constant of the distributed
/// mma.sync operand vector type.
static LogicalResult
convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
                         llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo)) {
    LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  }

  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
    return rewriter.notifyMatchFailure(op, "not mma sync reg info");
  }

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
  auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
  if (!dense) {
    LLVM_DEBUG(DBGS() << "not a splat\n");
    return rewriter.notifyMatchFailure(op, "not a splat");
  }

  Value result = rewriter.create<arith::ConstantOp>(
      op.getLoc(), vectorType,
      DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
  valueMapping[op.getResult()] = result;
  return success();
}

/// Check if the loaded matrix operand requires a transposed load.
/// Transposed Map Example:
/// Example 1 : (..., d0, d1) -> (d1 * 1, d0 * 2)
/// Example 2 : (d0, d1, d2, d3) -> (d3, d2)
/// The code below checks if the output 2D is transposed using a generalized
/// version : (d0, d1, dn, ..., dm, ...) -> (dm, dn)
/// Returns true if m > n, false otherwise.
static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
  mlir::AffineMap map = op.getPermutationMap();

  if (map.getNumResults() != 2) {
    LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
                         "is not a 2d operand\n");
    return failure();
  }

  // Output 2D matrix dimensions in the order of d0, d1.
  mlir::AffineExpr dM = map.getResult(0);
  mlir::AffineExpr dN = map.getResult(1);

  // Find the position of these expressions in the input.
  auto exprM = dyn_cast<AffineDimExpr>(dM);
  auto exprN = dyn_cast<AffineDimExpr>(dN);

  if (!exprM || !exprN) {
    LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
                         "expressions, so the transpose cannot be "
                         "determined.\n");
    return failure();
  }

  return exprM.getPosition() > exprN.getPosition();
}

static LogicalResult
creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                             llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);
  Location loc = op->getLoc();

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo)) {
    LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  }

  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
    return rewriter.notifyMatchFailure(op, "not mma sync reg info");
  }

  FailureOr<bool> transpose = isTransposed(op);
  if (failed(transpose)) {
    LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
    return rewriter.notifyMatchFailure(
        op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
  }

  FailureOr<nvgpu::LdMatrixParams> params =
      nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);

  if (failed(params)) {
    LLVM_DEBUG(
        DBGS()
        << "failed to convert vector.transfer_read to ldmatrix. "
        << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
    return rewriter.notifyMatchFailure(
        op, "failed to convert vector.transfer_read to ldmatrix; this op "
            "likely should not be converted to a nvgpu.ldmatrix call.");
  }

  // Adjust the load offset.
  auto laneId = rewriter.create<gpu::LaneIdOp>(loc);
  FailureOr<AffineMap> offsets =
      nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
  if (failed(offsets)) {
    LLVM_DEBUG(DBGS() << "no offsets\n");
    return rewriter.notifyMatchFailure(op, "no offsets");
  }

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);

  SmallVector<Value, 4> indices;
  getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
                                         indices);

  nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
      loc, vectorType, op.getSource(), indices, *transpose, params->numTiles);
  valueMapping[op] = newOp->getResult(0);
  return success();
}

static LogicalResult
createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                       llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  Location loc = op.getLoc();
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    return rewriter.notifyMatchFailure(
        op, "Failed to deduce register fragment type during "
            "conversion to distributed non-ldmatrix compatible load");
  }

  Value laneId = rewriter.create<gpu::LaneIdOp>(loc);
  SmallVector<Value, 4> elements;

  // This is the individual element type.
  Type loadedElType = regInfo->registerLLVMType;
  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);

  Value fill = rewriter.create<arith::ConstantOp>(
      op.getLoc(), vectorType.getElementType(),
      rewriter.getZeroAttr(vectorType.getElementType()));
  Value result =
      rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);

  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();

  // If we are not transposing, then we can use vectorized loads. Otherwise, we
  // must load each element individually.
  if (!isTransposeLoad) {
    if (!isa<VectorType>(loadedElType)) {
      loadedElType = VectorType::get({1}, loadedElType);
    }

    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
          rewriter, op.getLoc(), *warpMatrixInfo);
      if (failed(coords))
        return rewriter.notifyMatchFailure(op, "no coords");

      Value logicalValueId = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexType(),
          rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
      SmallVector<Value, 4> newIndices;
      getXferIndices<vector::TransferReadOp>(
          rewriter, op, *coords, {laneId, logicalValueId}, newIndices);

      Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
                                                 op.getSource(), newIndices);
      result = rewriter.create<vector::InsertOp>(loc, el, result, i);
    }
  } else {
    if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
      loadedElType = vecType.getElementType();
    }
    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
           innerIdx++) {

        Value logicalValueId = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getIndexType(),
            rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
        FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
            rewriter, op.getLoc(), *warpMatrixInfo);
        if (failed(coords))
          return rewriter.notifyMatchFailure(op, "no coords");

        SmallVector<Value, 4> newIndices;
        getXferIndices<vector::TransferReadOp>(
            rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
        Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
                                                   op.getSource(), newIndices);
        result = rewriter.create<vector::InsertOp>(
            op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
      }
    }
  }

  valueMapping[op.getResult()] = result;
  return success();
}

/// Return true if this is a shared memory memref type.
static bool isSharedMemory(MemRefType type) {
  auto addressSpace =
      dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
  return addressSpace &&
         addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
}
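
// For example (illustrative, not part of the original file): a type such as
//   memref<64x64xf16, #gpu.address_space<workgroup>>
// is considered shared memory, while a plain memref<64x64xf16> is not.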

/// Converts a `vector.transfer_read` operation directly to either a
/// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
/// used when converting to `nvgpu.mma.sync` operations.
static LogicalResult
convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");

  bool isLdMatrixCompatible =
      isSharedMemory(cast<MemRefType>(op.getSource().getType())) &&
      nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;

  VectorType vecTy = op.getVectorType();
  int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();

  // When we are transposing the B operand, ldmatrix will only work if we have
  // at least 8 rows to read and the width to read for the transpose is 128
  // bits.
  if (!op.getPermutationMap().isMinorIdentity() &&
      (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
       vecTy.getDimSize(0) * bitWidth < 128))
    isLdMatrixCompatible = false;

  if (!isLdMatrixCompatible)
    return createNonLdMatrixLoads(rewriter, op, valueMapping);

  return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
}

static LogicalResult
convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
                             llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  Location loc = op->getLoc();
  auto it = valueMapping.find(op.getVector());
  if (it == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  Value matrix = it->second;

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo))
    return rewriter.notifyMatchFailure(op, "not mma sync reg info");

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
  Value laneId = rewriter.create<gpu::LaneIdOp>(loc);

  for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
    Value logicalValueId = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIndexType(),
        rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
    FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
        rewriter, op.getLoc(), *warpMatrixInfo);
    if (failed(coords))
      return rewriter.notifyMatchFailure(op, "no coords");

    Value el =
        rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
    SmallVector<Value, 4> newIndices;
    getXferIndices<vector::TransferWriteOp>(
        rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
    rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
  }

  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  rewriter.eraseOp(op);
  return success();
}

static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
                                       SmallVectorImpl<int64_t> &results) {
  for (auto attr : arrayAttr)
    results.push_back(cast<IntegerAttr>(attr).getInt());
}

static LogicalResult
convertExtractStridedSlice(RewriterBase &rewriter,
                           vector::ExtractStridedSliceOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  Location loc = op->getLoc();

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");

  FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(mmaSyncFragmentInfo))
    return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");

  // Find the vector.transfer_read whose result vector is being sliced.
  auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!transferReadOp)
    return rewriter.notifyMatchFailure(op, "no transfer read");

  warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");

  FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(ldFragmentInfo))
    return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");

  assert(
      (mmaSyncFragmentInfo->elementsPerRegister ==
       ldFragmentInfo->elementsPerRegister) &&
      "Number of elements per register should be same for load and mma.sync");

  // Create vector.extract_strided_slice op for thread-owned fragments.
  std::array<int64_t, 2> strides = {1,
                                    1}; // stride for extract slice is always 1.
  std::array<int64_t, 2> sliceShape = {
      mmaSyncFragmentInfo->numRegistersPerFragment,
      mmaSyncFragmentInfo->elementsPerRegister};
  auto it = valueMapping.find(transferReadOp);
  if (it == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  auto sourceVector = it->second;

  // Offsets and sizes at warp-level of ownership.
  SmallVector<int64_t> offsets;
  populateFromInt64AttrArray(op.getOffsets(), offsets);

  SmallVector<int64_t> sizes;
  populateFromInt64AttrArray(op.getSizes(), sizes);
  ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();

  // Compute offset in vector registers. Note that the mma.sync vector
  // registers are shaped as numberOfFragments x numberOfRegistersPerFragment.
  // The vector registers can only be sliced along numberOfFragments, i.e.,
  // sliceOffset[0].
  std::array<int64_t, 2> sliceOffset = {0, 0};

  if (offsets[0] && offsets[1])
    return op->emitError() << "Slicing fragments in 2D is not supported. ";
  if (offsets[0])
    sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
  else if (offsets[1])
    sliceOffset[0] = (warpVectorShape[1] / offsets[1]);

  Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
      loc, sourceVector, sliceOffset, sliceShape, strides);

  valueMapping[op] = newOp;
  return success();
}

static LogicalResult
convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
                  llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  auto itA = valueMapping.find(op.getLhs());
  auto itB = valueMapping.find(op.getRhs());
  auto itC = valueMapping.find(op.getAcc());
  if (itA == valueMapping.end() || itB == valueMapping.end() ||
      itC == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  Value opA = itA->second, opB = itB->second, opC = itC->second;
  Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
      op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
      /*b_transpose=*/UnitAttr());
  valueMapping[op.getResult()] = matmul;
  return success();
}
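
// Illustrative result of convertContractOp (a sketch, not part of the
// original file):
//
//   %d = gpu.subgroup_mma_compute %a, %b, %c
//          : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
//          -> !gpu.mma_matrix<16x16xf16, "COp">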

static LogicalResult
convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  auto itA = valueMapping.find(op.getLhs());
  auto itB = valueMapping.find(op.getRhs());
  auto itC = valueMapping.find(op.getAcc());
  if (itA == valueMapping.end() || itB == valueMapping.end() ||
      itC == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  Value opA = itA->second, opB = itB->second, opC = itC->second;
  int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
  int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
  int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
  Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
      op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
  valueMapping[op.getResult()] = matmul;
  return success();
}
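
// Illustrative result of convertContractOpToMmaSync for an m16n8k16 f16 tile
// (a sketch, not part of the original file; fragment shapes follow the NVGPU
// dialect documentation):
//
//   %d = nvgpu.mma.sync(%a, %b, %c) {mmaShape = [16, 8, 16]}
//          : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>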

/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static LogicalResult
convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
                  llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  assert(constantSupportsMMAMatrixType(op));

  auto splat =
      cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
  auto scalarConstant =
      rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
  const char *fragType = inferFragType(op);
  auto vecType = cast<VectorType>(op.getType());
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
      op.getLoc(), type, scalarConstant);
  valueMapping[op.getResult()] = matrix;
  return success();
}

/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static LogicalResult
convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
                   llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  assert(broadcastSupportsMMAMatrixType(op));

  const char *fragType = inferFragType(op);
  auto vecType = op.getResultVectorType();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
      op.getLoc(), type, op.getSource());
  valueMapping[op.getResult()] = matrix;
  return success();
}

// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
// updated and needs to be updated separately for the loop to be correct.
static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
                                               scf::ForOp loop,
                                               ValueRange newInitArgs) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(loop);

  // Create a new loop before the existing one, with the extra operands.
  rewriter.setInsertionPoint(loop);
  auto operands = llvm::to_vector<4>(loop.getInitArgs());
  llvm::append_range(operands, newInitArgs);
  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
      operands);
  rewriter.eraseBlock(newLoop.getBody());

  newLoop.getRegion().getBlocks().splice(
      newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
  for (Value operand : newInitArgs)
    newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());

  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
                                                  loop.getNumResults())))
    rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));

  LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
  LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
  LLVM_DEBUG(DBGS() << "erase: " << loop);

  rewriter.eraseOp(loop);
  return newLoop;
}

static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
                                  llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  SmallVector<Value> newOperands;
  SmallVector<std::pair<size_t, size_t>> argMapping;
  for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end()) {
      LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
      continue;
    }
    argMapping.push_back(std::make_pair(
        operand.index(), op.getInitArgs().size() + newOperands.size()));
    newOperands.push_back(it->second);
  }

  scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
  Block &loopBody = *newForOp.getBody();
  for (auto mapping : argMapping) {
    valueMapping[newForOp.getResult(mapping.first)] =
        newForOp.getResult(mapping.second);
    valueMapping[loopBody.getArgument(mapping.first +
                                      newForOp.getNumInductionVars())] =
        loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
  }

  LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
  return success();
}
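
// Illustrative effect of convertForOp together with convertYieldOp below
// (a sketch, not part of the original file): an scf.for carrying a
// vector<16x16xf16> accumulator as an iter_arg gains an extra iter_arg of
// type !gpu.mma_matrix<16x16xf16, "COp">; the original vector iter_arg is
// yielded unchanged, so it becomes dead and can be removed by DCE.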

static LogicalResult
convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
               llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  auto loop = cast<scf::ForOp>(op->getParentOp());
  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
  for (const auto &operand : llvm::enumerate(op.getOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;
    // Replace the yield of the old value with the for op argument to make it
    // easier to remove the dead code.
    yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
    yieldOperands.push_back(it->second);
  }
  rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);

  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  rewriter.eraseOp(op);
  return success();
}

/// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static LogicalResult
convertElementwiseOp(RewriterBase &rewriter, Operation *op,
                     gpu::MMAElementwiseOp opType,
                     llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  SmallVector<Value> matrixOperands;
  for (Value operand : op->getOperands()) {
    auto it = valueMapping.find(operand);
    if (it == valueMapping.end())
      return rewriter.notifyMatchFailure(op, "no mapping");
    matrixOperands.push_back(it->second);
  }
  auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
  if (opType == gpu::MMAElementwiseOp::EXTF) {
    // The floating point extension case has a different result type.
    auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
    resultType = gpu::MMAMatrixType::get(resultType.getShape(),
                                         vectorType.getElementType(),
                                         resultType.getOperand());
  }

  Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
      op->getLoc(), resultType, matrixOperands, opType);
  valueMapping[op->getResult(0)] = newOp;
  return success();
}

void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
                                              bool useNvGpu) {
  if (!useNvGpu) {
    patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
        patterns.getContext());
    return;
  }
  vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
  patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
}

LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
                                          Operation *rootOp) {
  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
  llvm::DenseMap<Value, Value> valueMapping;

  auto globalRes = LogicalResult::success();
  for (Operation *op : ops) {
    LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
    // Apparently callers do not want to early exit on failure here.
    auto res = LogicalResult::success();
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
      res = convertTransferReadOp(rewriter, transferRead, valueMapping);
    } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
      res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
      res = convertContractOp(rewriter, contractOp, valueMapping);
    } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
      res = convertConstantOp(rewriter, constantOp, valueMapping);
    } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
      res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
    } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      res = convertForOp(rewriter, forOp, valueMapping);
    } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
      res = convertYieldOp(rewriter, yieldOp, valueMapping);
    } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
      res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
    }
    if (failed(res))
      globalRes = failure();
  }
  return globalRes;
}

LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
                                                         Operation *rootOp) {
  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
  llvm::DenseMap<Value, Value> valueMapping;
  for (Operation *op : ops) {
    if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
            .Case([&](vector::TransferReadOp transferReadOp) {
              return convertTransferReadToLoads(rewriter, transferReadOp,
                                                valueMapping);
            })
            .Case([&](vector::TransferWriteOp transferWriteOp) {
              return convertTransferWriteToStores(rewriter, transferWriteOp,
                                                  valueMapping);
            })
            .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
              return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
                                                valueMapping);
            })
            .Case([&](vector::ContractionOp contractionOp) {
              return convertContractOpToMmaSync(rewriter, contractionOp,
                                                valueMapping);
            })
            .Case([&](scf::ForOp forOp) {
              return convertForOp(rewriter, forOp, valueMapping);
            })
            .Case([&](scf::YieldOp yieldOp) {
              return convertYieldOp(rewriter, yieldOp, valueMapping);
            })
            .Case([&](arith::ConstantOp constOp) {
              return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
            })
            .Default([&](Operation *op) {
              return op->emitError() << "unhandled vector to mma type: " << *op;
            })
            .failed()) {
      return op->emitOpError()
             << "failed to convert op during vector-to-nvgpu conversion";
    }
  }
  return success();
}

namespace {

struct ConvertVectorToGPUPass
    : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {

  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
    useNvGpu.setValue(useNvGpu_);
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
    if (failed(
            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();

    IRRewriter rewriter(&getContext());
    if (useNvGpu) {
      if (failed(
              convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
        return signalPassFailure();
      return;
    }
    (void)convertVectorToMMAOps(rewriter, getOperation());
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
}
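
// Example of registering the pass on a pass manager (an illustrative sketch,
// not part of the original file; assumes the pass runs nested on func.func):
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::func::FuncOp>(
//       mlir::createConvertVectorToGPUPass(/*useNvGpu=*/true));
//   if (failed(pm.run(module)))
//     return failure();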