1//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements lowering of vector operations to GPU dialect ops.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
14
15#include "mlir/Analysis/SliceAnalysis.h"
16#include "mlir/Analysis/TopologicalSortUtils.h"
17#include "mlir/Dialect/Affine/IR/AffineOps.h"
18#include "mlir/Dialect/Arith/IR/Arith.h"
19#include "mlir/Dialect/GPU/IR/GPUDialect.h"
20#include "mlir/Dialect/MemRef/IR/MemRef.h"
21#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
22#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
23#include "mlir/Dialect/SCF/IR/SCF.h"
24#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
25#include "mlir/Dialect/Vector/IR/VectorOps.h"
26#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
27#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
28#include "mlir/IR/Builders.h"
29#include "mlir/IR/Region.h"
30#include "mlir/Pass/Pass.h"
31#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32#include "llvm/ADT/STLExtras.h"
33#include "llvm/ADT/TypeSwitch.h"
34
35#define DEBUG_TYPE "vector-to-gpu"
36#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
37#define DBGSNL() (llvm::dbgs() << "\n")
38
39namespace mlir {
40#define GEN_PASS_DEF_CONVERTVECTORTOGPU
41#include "mlir/Conversion/Passes.h.inc"
42} // namespace mlir
43
44using namespace mlir;
45
/// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
/// AffineMap representing offsets to apply to indices, the function fills
/// `indices` with the original indices plus the offsets. The offsets are
/// applied by taking into account the permutation map of the transfer op. If
/// the `offsetMap` has dimension placeholders, those should be provided in
/// `dimValues`.
template <typename TransferOpType>
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
                           AffineMap offsetMap, ArrayRef<Value> dimValues,
                           SmallVector<Value, 4> &indices) {
  // Start from the transfer op's own indices; only the dimensions that appear
  // as plain AffineDimExprs in the permutation map get an offset added.
  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
  Location loc = xferOp.getLoc();
  unsigned offsetsIdx = 0;
  for (auto expr : xferOp.getPermutationMap().getResults()) {
    if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
      Value prevIdx = indices[dim.getPosition()];
      // Operands are `dimValues` (for offsetMap's dims) followed by the
      // previous index value, which binds to the extra dim `d0` below.
      SmallVector<OpFoldResult, 3> dims(dimValues);
      dims.push_back(Elt: prevIdx);
      AffineExpr d0 = rewriter.getAffineDimExpr(position: offsetMap.getNumDims());
      // newIndex = prevIdx + offsetMap.getResult(offsetsIdx); offsets are
      // consumed in permutation-map result order, skipping non-dim results.
      indices[dim.getPosition()] = affine::makeComposedAffineApply(
          b&: rewriter, loc, e: d0 + offsetMap.getResult(idx: offsetsIdx++), operands: dims);
      continue;
    }
  }
}
71
72// Return true if the contract op can be convert to MMA matmul.
73static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
74 bool useNvGpu) {
75 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
76 auto infer = [&](MapList m) {
77 return AffineMap::inferFromExprList(exprsList: m, context: contract.getContext());
78 };
79 AffineExpr m, n, k;
80 bindDims(ctx: contract.getContext(), exprs&: m, exprs&: n, exprs&: k);
81 auto iteratorTypes = contract.getIteratorTypes().getValue();
82 if (!(vector::isParallelIterator(attr: iteratorTypes[0]) &&
83 vector::isParallelIterator(attr: iteratorTypes[1]) &&
84 vector::isReductionIterator(attr: iteratorTypes[2])))
85 return false;
86
87 // The contract needs to represent a matmul to be able to convert to
88 // MMAMatrix matmul.
89 if (!useNvGpu &&
90 contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
91 return false;
92 if (useNvGpu &&
93 contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
94 return false;
95
96 return true;
97}
98
99// Return true if the given map represents a transposed matrix load,
100// i.e. (d0, d1, ...) -> (dn-1, dn-2).
101static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
102 MLIRContext *ctx = permutationMap.getContext();
103 // Local OpBuilder is fine here, we just build attributes.
104 OpBuilder b(ctx);
105 auto nDim = permutationMap.getNumDims();
106 AffineExpr zero = b.getAffineConstantExpr(constant: 0);
107 if (nDim < 2) {
108 // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
109 AffineExpr dim0 = b.getAffineDimExpr(position: 0);
110 return permutationMap == AffineMap::get(dimCount: 1, symbolCount: 0, results: {dim0, zero}, context: ctx);
111 }
112
113 AffineExpr innerDim = b.getAffineDimExpr(position: nDim - 1);
114 AffineExpr outerDim = b.getAffineDimExpr(position: nDim - 2);
115 // Support both transposed and transposed+broadcasted cases.
116 return permutationMap == AffineMap::get(dimCount: nDim, symbolCount: 0, results: {innerDim, outerDim}, context: ctx) ||
117 permutationMap == AffineMap::get(dimCount: nDim, symbolCount: 0, results: {innerDim, zero}, context: ctx);
118}
119
120// Return the stide for the second-to-last dimension of |type| if it is a memref
121// and has a constant stride.
122static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
123 auto memrefType = dyn_cast<MemRefType>(Val&: type);
124 if (!memrefType)
125 return false;
126 // If the memref is 0 or 1D the horizontal stride is 0.
127 if (memrefType.getRank() < 2)
128 return 0;
129 int64_t offset = 0;
130 SmallVector<int64_t, 2> strides;
131 if (failed(Result: memrefType.getStridesAndOffset(strides, offset)) ||
132 strides.back() != 1)
133 return std::nullopt;
134 int64_t stride = strides[strides.size() - 2];
135 if (stride == ShapedType::kDynamic)
136 return std::nullopt;
137 return stride;
138}
139
140// Return true if the transfer op can be converted to a MMA matrix load.
141static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
142 if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
143 readOp.getVectorType().getRank() != 2)
144 return false;
145 if (!getStaticallyKnownRowStride(type: readOp.getShapedType()))
146 return false;
147
148 // Only allow integer types if the signedness can be inferred.
149 if (readOp.getVectorType().getElementType().isInteger(width: 8))
150 if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(Val: *readOp->user_begin()) &&
151 !isa<arith::ExtUIOp>(Val: *readOp->user_begin())))
152 return false;
153
154 AffineMap map = readOp.getPermutationMap();
155 MLIRContext *ctx = readOp.getContext();
156 AffineExpr innerDim = getAffineDimExpr(position: map.getNumDims() - 1, context: ctx);
157 AffineExpr zero = getAffineConstantExpr(constant: 0, context: ctx);
158 auto broadcastInnerDim =
159 AffineMap::get(dimCount: map.getNumDims(), symbolCount: 0, results: {zero, innerDim}, context: ctx);
160 return map.isMinorIdentity() || map == broadcastInnerDim ||
161 isTransposeMatrixLoadMap(permutationMap: map);
162}
163
164// Return true if the transfer op can be converted to a MMA matrix store.
165static bool
166transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
167 // TODO: support 0-d corner case.
168 if (writeOp.getTransferRank() == 0)
169 return false;
170
171 if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
172 writeOp.getVectorType().getRank() != 2)
173 return false;
174 if (!getStaticallyKnownRowStride(type: writeOp.getShapedType()))
175 return false;
176 // TODO: Support transpose once it is added to GPU dialect ops.
177 if (!writeOp.getPermutationMap().isMinorIdentity())
178 return false;
179 return true;
180}
181
182/// Return true if the constant is a splat to a 2D vector so that it can be
183/// converted to a MMA constant matrix op.
184static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
185 auto vecType = dyn_cast<VectorType>(Val: constantOp.getType());
186 if (!vecType || vecType.getRank() != 2)
187 return false;
188 return isa<SplatElementsAttr>(Val: constantOp.getValue());
189}
190
191/// Return true if this is a broadcast from scalar to a 2D vector.
192static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
193 return broadcastOp.getResultVectorType().getRank() == 2;
194}
195
196/// Return true if this integer extend op can be folded into a contract op.
197template <typename ExtOpTy>
198static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
199 auto transferReadOp =
200 extOp.getOperand().template getDefiningOp<vector::TransferReadOp>();
201 if (!transferReadOp)
202 return false;
203 return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
204}
205
/// Floating-point extends are unconditionally foldable into the contract op.
static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
207
208/// Return the MMA elementwise enum associated with `op` if it is supported.
209/// Return `std::nullopt` otherwise.
210static std::optional<gpu::MMAElementwiseOp>
211convertElementwiseOpToMMA(Operation *op) {
212 if (isa<arith::AddFOp>(Val: op))
213 return gpu::MMAElementwiseOp::ADDF;
214 if (isa<arith::MulFOp>(Val: op))
215 return gpu::MMAElementwiseOp::MULF;
216 if (isa<arith::SubFOp>(Val: op))
217 return gpu::MMAElementwiseOp::SUBF;
218 if (isa<arith::MaximumFOp>(Val: op))
219 return gpu::MMAElementwiseOp::MAXF;
220 if (isa<arith::MinimumFOp>(Val: op))
221 return gpu::MMAElementwiseOp::MINF;
222 if (isa<arith::DivFOp>(Val: op))
223 return gpu::MMAElementwiseOp::DIVF;
224 if (isa<arith::AddIOp>(Val: op))
225 return gpu::MMAElementwiseOp::ADDI;
226 if (isa<arith::MulIOp>(Val: op))
227 return gpu::MMAElementwiseOp::MULI;
228 if (isa<arith::SubIOp>(Val: op))
229 return gpu::MMAElementwiseOp::SUBI;
230 if (isa<arith::DivSIOp>(Val: op))
231 return gpu::MMAElementwiseOp::DIVS;
232 if (isa<arith::DivUIOp>(Val: op))
233 return gpu::MMAElementwiseOp::DIVU;
234 if (isa<arith::NegFOp>(Val: op))
235 return gpu::MMAElementwiseOp::NEGATEF;
236 if (isa<arith::ExtFOp>(Val: op))
237 return gpu::MMAElementwiseOp::EXTF;
238 return std::nullopt;
239}
240
241/// Return true if the op is supported as elementwise op on MMAMatrix type.
242static bool elementwiseSupportsMMAMatrixType(Operation *op) {
243 return convertElementwiseOpToMMA(op).has_value();
244}
245
246/// Returns true if the extract strided slice op is supported with `mma.sync`
247/// path.
248static bool
249extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
250
251 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
252 nvgpu::getWarpMatrixInfo(op);
253 if (failed(Result: warpMatrixInfo))
254 return false;
255
256 FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
257 if (failed(Result: contractOp))
258 return false;
259
260 // Handle vector.extract_strided_slice on registers containing
261 // matrixB and matrixC operands. vector.extract_strided_slice op
262 // is not supported on registers containing matrixA operands.
263 if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
264 return (cast<VectorType>(Val: op->getResult(idx: 0).getType()) ==
265 cast<VectorType>(Val: (*contractOp).getRhs().getType()));
266 if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
267 return (cast<VectorType>(Val: op->getResult(idx: 0).getType()) ==
268 cast<VectorType>(Val: (*contractOp).getAcc().getType()));
269
270 return false;
271}
272
/// Return true if `op` can be part of a chain of operations converted to the
/// GPU MMA dialect (wmma path when `useNvGpu` is false, mma.sync otherwise).
static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
  // scf.for / scf.yield are structural: their vector operands get remapped
  // during conversion rather than converted themselves.
  if (isa<scf::ForOp, scf::YieldOp>(Val: op))
    return true;
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(Val: op))
    return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(op: transferRead)
                    : transferReadSupportsMMAMatrixType(readOp: transferRead);
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(Val: op))
    return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(op: transferWrite)
                    : transferWriteSupportsMMAMatrixType(writeOp: transferWrite);
  // extract_strided_slice is only supported on the mma.sync path.
  if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(Val: op))
    return useNvGpu &&
           extractStridedSliceSupportsMMAMatrixType(op: extractStridedSlice);
  if (auto contract = dyn_cast<vector::ContractionOp>(Val: op))
    return contractSupportsMMAMatrixType(contract, useNvGpu);
  if (auto constant = dyn_cast<arith::ConstantOp>(Val: op))
    return constantSupportsMMAMatrixType(constantOp: constant);
  if (auto broadcast = dyn_cast<vector::BroadcastOp>(Val: op))
    return broadcastSupportsMMAMatrixType(broadcastOp: broadcast);
  // Extension ops are foldable into the contract they feed.
  if (auto signedExtend = dyn_cast<arith::ExtSIOp>(Val: op))
    return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(extOp: signedExtend);
  if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(Val: op))
    return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(extOp: unsignedExtend);
  if (auto fpExtend = dyn_cast<arith::ExtFOp>(Val: op))
    return fpExtendSupportsMMAMatrixType(extOp: fpExtend);
  // Anything else must be a whitelisted elementwise op.
  return elementwiseSupportsMMAMatrixType(op);
}
299
/// Return an unsorted slice handling scf.for region differently than
/// `getSlice`. In scf.for we only want to include as part of the slice elements
/// that are part of the use/def chain.
static SetVector<Operation *>
getSliceContract(Operation *op,
                 const BackwardSliceOptions &backwardSliceOptions,
                 const ForwardSliceOptions &forwardSliceOptions) {
  SetVector<Operation *> slice;
  slice.insert(X: op);
  unsigned currentIndex = 0;
  SetVector<Operation *> backwardSlice;
  SetVector<Operation *> forwardSlice;
  // Fixed-point iteration: `currentIndex` walks over ops already in `slice`;
  // each op's backward and forward slices are appended, and newly discovered
  // ops are processed in turn until no new ops appear.
  while (currentIndex != slice.size()) {
    auto *currentOp = (slice)[currentIndex];
    // Compute and insert the backwardSlice starting from currentOp.
    backwardSlice.clear();
    LogicalResult result =
        getBackwardSlice(op: currentOp, backwardSlice: &backwardSlice, options: backwardSliceOptions);
    assert(result.succeeded() && "expected a backward slice");
    (void)result;
    slice.insert_range(R&: backwardSlice);

    // Compute and insert the forwardSlice starting from currentOp.
    forwardSlice.clear();
    // Special case for ForOp, we don't want to include the whole region but
    // only the value using the region arguments.
    // TODO: We should refine this to only care about the region arguments being
    // converted to matrix type.
    if (auto forOp = dyn_cast<scf::ForOp>(Val: currentOp)) {
      // Follow uses of the loop results and of the iter_args inside the body,
      // not every op nested in the region.
      for (Value forOpResult : forOp.getResults())
        getForwardSlice(root: forOpResult, forwardSlice: &forwardSlice, options: forwardSliceOptions);
      for (BlockArgument &arg : forOp.getRegionIterArgs())
        getForwardSlice(root: arg, forwardSlice: &forwardSlice, options: forwardSliceOptions);
    } else {
      getForwardSlice(op: currentOp, forwardSlice: &forwardSlice, options: forwardSliceOptions);
    }
    slice.insert_range(R&: forwardSlice);
    ++currentIndex;
  }
  return slice;
}
341
// Analyze slice of operations based on convert op to figure out if the whole
// slice can be converted to MMA operations.
static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
                                             bool useNvGpu) {
  // Backward slices follow producers of vector-typed results...
  auto hasVectorDest = [](Operation *op) {
    return llvm::any_of(Range: op->getResultTypes(), P: llvm::IsaPred<VectorType>);
  };
  BackwardSliceOptions backwardSliceOptions;
  backwardSliceOptions.filter = hasVectorDest;

  // ...forward slices follow consumers of vector-typed operands.
  auto hasVectorSrc = [](Operation *op) {
    return llvm::any_of(Range: op->getOperandTypes(), P: llvm::IsaPred<VectorType>);
  };
  ForwardSliceOptions forwardSliceOptions;
  forwardSliceOptions.filter = hasVectorSrc;

  SetVector<Operation *> opToConvert;
  // Seed a slice from every vector.contract; conversion is all-or-nothing per
  // slice.
  op->walk(callback: [&](vector::ContractionOp contract) {
    if (opToConvert.contains(key: contract.getOperation()))
      return;
    SetVector<Operation *> dependentOps =
        getSliceContract(op: contract, backwardSliceOptions, forwardSliceOptions);
    // If any instruction cannot use MMA matrix type drop the whole
    // chain. MMA matrix are stored in an opaque type so they cannot be used
    // by all operations.
    if (llvm::any_of(Range&: dependentOps, P: [useNvGpu](Operation *op) {
          if (!supportsMMaMatrixType(op, useNvGpu)) {
            LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
            return true;
          }
          return false;
        }))
      return;

    opToConvert.insert_range(R&: dependentOps);
  });
  // Sort the operations so that we can convert them in topological order.
  return topologicalSort(toSort: opToConvert);
}
381
382namespace {
// Transform contract into (m, k)x(k, n)x(m, n) form so that it can be converted
// to MMA matmul.
struct PrepareContractToGPUMMA
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();

    // Set up the parallel/reduction structure in right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [&](MapList m) {
      return AffineMap::inferFromExprList(exprsList: m, context: op.getContext());
    };
    AffineExpr m, n, k;
    bindDims(ctx: rewriter.getContext(), exprs&: m, exprs&: n, exprs&: k);
    // Permutation used by every vector.transpose inserted below (2-D swap).
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
    if (!(vector::isParallelIterator(attr: iteratorTypes[0]) &&
          vector::isParallelIterator(attr: iteratorTypes[1]) &&
          vector::isReductionIterator(attr: iteratorTypes[2])))
      return rewriter.notifyMatchFailure(arg&: op, msg: "not a gemm contraction");
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    // Each case below identifies one of the eight possible indexing-map
    // layouts and inserts vector.transpose ops and/or swaps the operands so
    // the contraction ends up in canonical (m,k) x (k,n) -> (m,n) form.
    //
    // This is the classical row-major matmul, nothing to do.
    if (maps == infer({{m, k}, {k, n}, {m, n}}))
      return rewriter.notifyMatchFailure(arg&: op, msg: "contraction already prepared");
    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // B is transposed: fix it up.
      rhs = rewriter.create<vector::TransposeOp>(location: loc, args&: rhs, args: perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      // A is transposed: fix it up.
      lhs = rewriter.create<vector::TransposeOp>(location: loc, args&: lhs, args: perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      // Both A and B are transposed.
      rhs = rewriter.create<vector::TransposeOp>(location: loc, args&: rhs, args: perm);
      lhs = rewriter.create<vector::TransposeOp>(location: loc, args&: lhs, args: perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // Result is transposed: swap operands, then transpose both.
      std::swap(a&: rhs, b&: lhs);
      rhs = rewriter.create<vector::TransposeOp>(location: loc, args&: rhs, args: perm);
      lhs = rewriter.create<vector::TransposeOp>(location: loc, args&: lhs, args: perm);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      // Result and B transposed: swap operands, transpose the new B.
      std::swap(a&: rhs, b&: lhs);
      rhs = rewriter.create<vector::TransposeOp>(location: loc, args&: rhs, args: perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      // Result and A transposed: swap operands, transpose the new A.
      std::swap(a&: lhs, b&: rhs);
      lhs = rewriter.create<vector::TransposeOp>(location: loc, args&: lhs, args: perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      // Everything transposed: a plain operand swap is sufficient.
      std::swap(a&: lhs, b&: rhs);
    } else {
      // TODO: llvm_unreachable ?
      return rewriter.notifyMatchFailure(arg&: op, msg: "unexpected contraction case");
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, args&: lhs, args&: rhs, args&: res,
        args: rewriter.getAffineMapArrayAttr(values: infer({{m, k}, {k, n}, {m, n}})),
        args: op.getIteratorTypes());
    return success();
  }
};
444
// Fold transpose op into the transfer read op. NVGPU mma.sync op only supports
// row-, column-, and row-major layout for matrixA, matrixB, and matrixC,
// respectively. We can fold the transpose operation when loading the data from
// Shared Memory to registers.
struct CombineTransferReadOpTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    // Look through integer extend ops.
    Value source = op.getVector();
    Type resultType = op.getType();
    // Note: if none of the getDefiningOp calls match, the last assignment
    // leaves `extOp` null, so it is always initialized past this if.
    Operation *extOp;
    if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
        (extOp = source.getDefiningOp<arith::ExtUIOp>()) ||
        (extOp = source.getDefiningOp<arith::ExtFOp>())) {
      source = extOp->getOperand(idx: 0);
      // The new transfer read produces the pre-extension (narrow) element
      // type; the extend is re-applied after the read further below.
      resultType =
          VectorType::get(shape: cast<VectorType>(Val&: resultType).getShape(),
                          elementType: cast<VectorType>(Val: source.getType()).getElementType());
    }

    auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
    if (!transferReadOp)
      return rewriter.notifyMatchFailure(arg&: op, msg: "no transfer read");

    // TODO: support 0-d corner case.
    if (transferReadOp.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(arg&: op, msg: "0-D transfer read");

    if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(arg&: op, msg: "not inbounds transfer read");

    // Fold the transpose into the read by composing its permutation with the
    // read's existing permutation map.
    AffineMap permutationMap =
        AffineMap::getPermutationMap(permutation: op.getPermutation(), context: op.getContext());
    AffineMap newMap =
        permutationMap.compose(map: transferReadOp.getPermutationMap());

    auto loc = op.getLoc();
    Value result =
        rewriter
            .create<vector::TransferReadOp>(
                location: loc, args&: resultType, args: transferReadOp.getBase(),
                args: transferReadOp.getIndices(), args: AffineMapAttr::get(value: newMap),
                args: transferReadOp.getPadding(), args: transferReadOp.getMask(),
                args: transferReadOp.getInBoundsAttr())
            .getResult();

    // Fuse through the integer extend op.
    if (extOp) {
      if (isa<arith::ExtSIOp>(Val: extOp))
        result = rewriter.create<arith::ExtSIOp>(location: loc, args: op.getType(), args&: result)
                     .getResult();
      else if (isa<arith::ExtUIOp>(Val: extOp))
        result = rewriter.create<arith::ExtUIOp>(location: loc, args: op.getType(), args&: result)
                     .getResult();
      else
        result = rewriter.create<arith::ExtFOp>(location: loc, args: op.getType(), args&: result)
                     .getResult();
    }

    rewriter.replaceOp(op, newValues: result);
    return success();
  }
};
511
512} // namespace
513
514// MMA types have different layout based on how they are used in matmul ops.
515// Figure the right layout to use by looking at op uses.
516// TODO: Change the GPU dialect to abstract the layout at the this level and
517// only care about it during lowering to NVVM.
518static const char *inferFragType(Operation *op) {
519 // We can have arith.ext ops before reaching contract ops. See through them
520 // and other kinds of elementwise ops.
521 if (op->hasOneUse()) {
522 Operation *userOp = *op->user_begin();
523 if (userOp->hasTrait<OpTrait::Elementwise>())
524 return inferFragType(op: userOp);
525 }
526
527 for (Operation *users : op->getUsers()) {
528 auto contract = dyn_cast<vector::ContractionOp>(Val: users);
529 if (!contract)
530 continue;
531 assert(op->getNumResults() == 1);
532 if (contract.getLhs() == op->getResult(idx: 0))
533 return "AOp";
534 if (contract.getRhs() == op->getResult(idx: 0))
535 return "BOp";
536 }
537 return "COp";
538}
539
/// Convert a vector.transfer_read into a gpu.subgroup_mma_load_matrix and
/// record the produced MMA matrix value in `valueMapping`.
static LogicalResult
convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
                      llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
  assert(transferReadSupportsMMAMatrixType(op) &&
         "expected convertible operation");

  std::optional<int64_t> stride =
      getStaticallyKnownRowStride(type: op.getShapedType());
  if (!stride.has_value()) {
    LLVM_DEBUG(DBGS() << "no stride\n");
    return rewriter.notifyMatchFailure(arg&: op, msg: "no stride");
  }

  AffineMap map = op.getPermutationMap();
  bool isTranspose = isTransposeMatrixLoadMap(permutationMap: map);

  // Handle broadcast by setting the stride to 0.
  // (The broadcast dim is result 0 normally, result 1 when transposed.)
  if (auto cstExpr = dyn_cast<AffineConstantExpr>(Val: map.getResult(idx: isTranspose))) {
    assert(cstExpr.getValue() == 0);
    stride = 0;
  }

  Value mappingResult = op.getResult();
  auto elType = op.getVectorType().getElementType();
  const char *fragType = inferFragType(op);
  if (op->hasOneUse()) {
    auto *user = *op->user_begin();
    // Infer the signedness of the mma type from the integer extend.
    if (isa<arith::ExtSIOp, arith::ExtUIOp>(Val: user)) {
      elType = IntegerType::get(
          context: op.getContext(), width: cast<IntegerType>(Val&: elType).getWidth(),
          signedness: isa<arith::ExtSIOp>(Val: user) ? IntegerType::Signed
                                    : IntegerType::Unsigned);
      // Map the extend's result instead of the read's so the extend is
      // folded into the load.
      mappingResult = user->getResult(idx: 0);
    }
  }
  gpu::MMAMatrixType type =
      gpu::MMAMatrixType::get(shape: op.getVectorType().getShape(), elementType: elType, operand: fragType);
  Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
      location: op.getLoc(), args&: type, args: op.getBase(), args: op.getIndices(),
      args: rewriter.getIndexAttr(value: *stride),
      args: isTranspose ? rewriter.getUnitAttr() : UnitAttr());
  valueMapping[mappingResult] = load;

  LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
  return success();
}
591
/// Convert a vector.transfer_write into a gpu.subgroup_mma_store_matrix of the
/// previously-converted matrix value found in `valueMapping`.
static LogicalResult
convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
                       llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  assert(transferWriteSupportsMMAMatrixType(op));
  std::optional<int64_t> stride =
      getStaticallyKnownRowStride(type: op.getShapedType());
  if (!stride.has_value()) {
    LLVM_DEBUG(DBGS() << "no stride\n");
    return rewriter.notifyMatchFailure(arg&: op, msg: "no stride");
  }

  // The stored vector must have been converted to an MMA matrix already.
  auto it = valueMapping.find(Val: op.getVector());
  if (it == valueMapping.end()) {
    LLVM_DEBUG(DBGS() << "no mapping\n");
    return rewriter.notifyMatchFailure(arg&: op, msg: "no mapping");
  }

  Value matrix = it->second;
  auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
      location: op.getLoc(), args&: matrix, args: op.getBase(), args: op.getIndices(),
      args: rewriter.getIndexAttr(value: *stride), /*transpose=*/args: UnitAttr());
  (void)store;

  LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");

  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  rewriter.eraseOp(op);
  return success();
}
624
625/// Returns the vector type which represents a matrix fragment.
626static VectorType
627getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
628 SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
629 regInfo.elementsPerRegister};
630 Type elType = regInfo.registerLLVMType;
631 if (auto vecType = dyn_cast<VectorType>(Val&: elType))
632 elType = vecType.getElementType();
633 return VectorType::get(shape, elementType: elType);
634}
635
/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static LogicalResult
convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
                         llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(Result: warpMatrixInfo)) {
    LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
    return rewriter.notifyMatchFailure(arg&: op, msg: "no warpMatrixInfo");
  }

  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(type: *warpMatrixInfo);
  if (failed(Result: regInfo)) {
    LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
    return rewriter.notifyMatchFailure(arg&: op, msg: "not mma sync reg info");
  }

  VectorType vectorType = getMmaSyncVectorOperandType(regInfo: *regInfo);
  auto dense = dyn_cast<SplatElementsAttr>(Val: op.getValue());
  if (!dense) {
    LLVM_DEBUG(DBGS() << "not a splat\n");
    return rewriter.notifyMatchFailure(arg&: op, msg: "not a splat");
  }

  // Re-splat the scalar value over the per-thread fragment shape and record
  // the distributed constant in the mapping.
  Value result = rewriter.create<arith::ConstantOp>(
      location: op.getLoc(), args&: vectorType,
      args: DenseElementsAttr::get(type: vectorType, values: dense.getSplatValue<Attribute>()));
  valueMapping[op.getResult()] = result;
  return success();
}
670
671/// Check if the loaded matrix operand requires transposed.
672/// Transposed Map Example:
673/// Example 1 : (..., d0, d1) -> (d1 * 1, d0 * 2)
674/// Example 2 : (d0, d1, d2, d3) -> (d3, d2)
675/// The code below checks if the output 2D is transposed using a generalized
676/// version : (d0, d1, dn, ..., dm, ...) -> (dm, dn)
677/// Returns : true; if m > n, false o.w.
678static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
679 mlir::AffineMap map = op.getPermutationMap();
680
681 if (map.getNumResults() != 2) {
682 LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
683 "is not a 2d operand\n");
684 return failure();
685 }
686
687 // Output 2D matrix dimensions in the order of d0, d1.
688 mlir::AffineExpr dM = map.getResult(idx: 0);
689 mlir::AffineExpr dN = map.getResult(idx: 1);
690
691 // Find the position of these expressions in the input.
692 auto exprM = dyn_cast<AffineDimExpr>(Val&: dM);
693 auto exprN = dyn_cast<AffineDimExpr>(Val&: dN);
694
695 if (!exprM || !exprN) {
696 LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
697 "expressions, then transpose cannot be determined.\n");
698 return failure();
699 }
700
701 return exprM.getPosition() > exprN.getPosition();
702}
703
/// Lower a vector.transfer_read to an nvgpu.ldmatrix op, distributing the
/// loaded fragment across the lanes of the warp. Fails (with a match-failure)
/// when the read cannot be expressed as an ldmatrix.
static LogicalResult
creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                             llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);
  Location loc = op->getLoc();

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(Result: warpMatrixInfo)) {
    LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
    return rewriter.notifyMatchFailure(arg&: op, msg: "no warpMatrixInfo");
  }

  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(type: *warpMatrixInfo);
  if (failed(Result: regInfo)) {
    LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
    return rewriter.notifyMatchFailure(arg&: op, msg: "not mma sync reg info");
  }

  FailureOr<bool> transpose = isTransposed(op);
  if (failed(Result: transpose)) {
    LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
    return rewriter.notifyMatchFailure(
        arg&: op, msg: "Op should likely not be converted to a nvgpu.ldmatrix call.");
  }

  FailureOr<nvgpu::LdMatrixParams> params =
      nvgpu::getLdMatrixParams(type: *warpMatrixInfo, transpose: *transpose);

  if (failed(Result: params)) {
    LLVM_DEBUG(
        DBGS()
        << "failed to convert vector.transfer_read to ldmatrix. "
        << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
    return rewriter.notifyMatchFailure(
        arg&: op, msg: "failed to convert vector.transfer_read to ldmatrix; this op "
                 "likely should not be converted to a nvgpu.ldmatrix call.");
  }

  // Adjust the load offset.
  // Each lane loads from a lane-dependent address: compute the lane's matrix
  // coordinate and fold it into the transfer indices.
  auto laneId = rewriter.create<gpu::LaneIdOp>(location: loc, /*upperBound=*/args: nullptr);
  FailureOr<AffineMap> offsets =
      nvgpu::getLaneIdToLdMatrixMatrixCoord(builder&: rewriter, loc, params: *params);
  if (failed(Result: offsets)) {
    LLVM_DEBUG(DBGS() << "no offsets\n");
    return rewriter.notifyMatchFailure(arg&: op, msg: "no offsets");
  }

  VectorType vectorType = getMmaSyncVectorOperandType(regInfo: *regInfo);

  SmallVector<Value, 4> indices;
  getXferIndices<vector::TransferReadOp>(rewriter, xferOp: op, offsetMap: *offsets, dimValues: {laneId},
                                         indices);

  nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
      location: loc, args&: vectorType, args: op.getBase(), args&: indices, args&: *transpose, args&: params->numTiles);
  valueMapping[op] = newOp->getResult(idx: 0);
  return success();
}
765
766static LogicalResult
767createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
768 llvm::DenseMap<Value, Value> &valueMapping) {
769 OpBuilder::InsertionGuard g(rewriter);
770 rewriter.setInsertionPoint(op);
771
772 Location loc = op.getLoc();
773 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
774 nvgpu::getWarpMatrixInfo(op);
775 if (failed(Result: warpMatrixInfo))
776 return rewriter.notifyMatchFailure(arg&: op, msg: "no warpMatrixInfo");
777 FailureOr<nvgpu::FragmentElementInfo> regInfo =
778 nvgpu::getMmaSyncRegisterType(type: *warpMatrixInfo);
779 if (failed(Result: regInfo)) {
780 return rewriter.notifyMatchFailure(
781 arg&: op, msg: "Failed to deduce register fragment type during "
782 "conversion to distributed non-ldmatrix compatible load");
783 }
784
785 Value laneId = rewriter.create<gpu::LaneIdOp>(location: loc, /*upperBound=*/args: nullptr);
786
787 // This is the individual element type.
788 Type loadedElType = regInfo->registerLLVMType;
789 VectorType vectorType = getMmaSyncVectorOperandType(regInfo: *regInfo);
790
791 Value fill = rewriter.create<arith::ConstantOp>(
792 location: op.getLoc(), args: vectorType.getElementType(),
793 args: rewriter.getZeroAttr(type: vectorType.getElementType()));
794 Value result =
795 rewriter.create<vector::SplatOp>(location: op.getLoc(), args&: fill, args&: vectorType);
796
797 bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
798
799 // If we are not transposing, then we can use vectorized loads. Otherwise, we
800 // must load each element individually.
801 if (!isTransposeLoad) {
802 if (!isa<VectorType>(Val: loadedElType)) {
803 loadedElType = VectorType::get(shape: {1}, elementType: loadedElType);
804 }
805
806 for (int i = 0; i < vectorType.getShape()[0]; i++) {
807 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
808 builder&: rewriter, loc: op.getLoc(), fragmentType: *warpMatrixInfo);
809 if (failed(Result: coords))
810 return rewriter.notifyMatchFailure(arg&: op, msg: "no coords");
811
812 Value logicalValueId = rewriter.create<arith::ConstantOp>(
813 location: loc, args: rewriter.getIndexType(),
814 args: rewriter.getIndexAttr(value: i * regInfo->elementsPerRegister));
815 SmallVector<Value, 4> newIndices;
816 getXferIndices<vector::TransferReadOp>(
817 rewriter, xferOp: op, offsetMap: *coords, dimValues: {laneId, logicalValueId}, indices&: newIndices);
818
819 Value el = rewriter.create<vector::LoadOp>(location: loc, args&: loadedElType,
820 args: op.getBase(), args&: newIndices);
821 result = rewriter.create<vector::InsertOp>(location: loc, args&: el, args&: result, args&: i);
822 }
823 } else {
824 if (auto vecType = dyn_cast<VectorType>(Val&: loadedElType)) {
825 loadedElType = vecType.getElementType();
826 }
827 for (int i = 0; i < vectorType.getShape()[0]; i++) {
828 for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
829 innerIdx++) {
830
831 Value logicalValueId = rewriter.create<arith::ConstantOp>(
832 location: loc, args: rewriter.getIndexType(),
833 args: rewriter.getIndexAttr(value: i * regInfo->elementsPerRegister + innerIdx));
834 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
835 builder&: rewriter, loc: op.getLoc(), fragmentType: *warpMatrixInfo);
836 if (failed(Result: coords))
837 return rewriter.notifyMatchFailure(arg&: op, msg: "no coords");
838
839 SmallVector<Value, 4> newIndices;
840 getXferIndices<vector::TransferReadOp>(
841 rewriter, xferOp: op, offsetMap: *coords, dimValues: {laneId, logicalValueId}, indices&: newIndices);
842 Value el = rewriter.create<memref::LoadOp>(location: op.getLoc(), args&: loadedElType,
843 args: op.getBase(), args&: newIndices);
844 result = rewriter.create<vector::InsertOp>(
845 location: op.getLoc(), args&: el, args&: result, args: ArrayRef<int64_t>{i, innerIdx});
846 }
847 }
848 }
849
850 valueMapping[op.getResult()] = result;
851 return success();
852}
853
854/// Return true if this is a shared memory memref type.
855static bool isSharedMemory(MemRefType type) {
856 auto addressSpace =
857 dyn_cast_or_null<gpu::AddressSpaceAttr>(Val: type.getMemorySpace());
858 return addressSpace &&
859 addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
860}
861
862/// Converts a `vector.transfer_read` operation directly to either a
863/// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
864/// used when converting to `nvgpu.mma.sync` operations.
865static LogicalResult
866convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
867 llvm::DenseMap<Value, Value> &valueMapping) {
868 OpBuilder::InsertionGuard g(rewriter);
869 rewriter.setInsertionPoint(op);
870
871 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
872 nvgpu::getWarpMatrixInfo(op);
873 if (failed(Result: warpMatrixInfo))
874 return rewriter.notifyMatchFailure(arg&: op, msg: "no warpMatrixInfo");
875
876 bool isLdMatrixCompatible =
877 isSharedMemory(type: cast<MemRefType>(Val: op.getBase().getType())) &&
878 nvgpu::inferTileWidthInBits(type: *warpMatrixInfo) == 128;
879
880 VectorType vecTy = op.getVectorType();
881 int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
882
883 // When we are transposing the B operand, ldmatrix will only work if we have
884 // at least 8 rows to read and the width to read for the transpose is 128
885 // bits.
886 if (!op.getPermutationMap().isMinorIdentity() &&
887 (bitWidth != 16 || vecTy.getDimSize(idx: 1) < 8 ||
888 vecTy.getDimSize(idx: 0) * bitWidth < 128))
889 isLdMatrixCompatible = false;
890
891 if (!isLdMatrixCompatible)
892 return createNonLdMatrixLoads(rewriter, op, valueMapping);
893
894 return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
895}
896
897static LogicalResult
898convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
899 llvm::DenseMap<Value, Value> &valueMapping) {
900 OpBuilder::InsertionGuard g(rewriter);
901 rewriter.setInsertionPoint(op);
902
903 Location loc = op->getLoc();
904 auto it = valueMapping.find(Val: op.getVector());
905 if (it == valueMapping.end())
906 return rewriter.notifyMatchFailure(arg&: op, msg: "no mapping");
907 Value matrix = it->second;
908
909 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
910 nvgpu::getWarpMatrixInfo(op);
911 if (failed(Result: warpMatrixInfo))
912 return rewriter.notifyMatchFailure(arg&: op, msg: "no warpMatrixInfo");
913 FailureOr<nvgpu::FragmentElementInfo> regInfo =
914 nvgpu::getMmaSyncRegisterType(type: *warpMatrixInfo);
915 if (failed(Result: regInfo))
916 return rewriter.notifyMatchFailure(arg&: op, msg: "not mma sync reg info");
917
918 VectorType vectorType = getMmaSyncVectorOperandType(regInfo: *regInfo);
919 Value laneId = rewriter.create<gpu::LaneIdOp>(location: loc, /*upperBound=*/args: nullptr);
920
921 for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
922 Value logicalValueId = rewriter.create<arith::ConstantOp>(
923 location: loc, args: rewriter.getIndexType(),
924 args: rewriter.getIndexAttr(value: i * regInfo->elementsPerRegister));
925 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
926 builder&: rewriter, loc: op.getLoc(), fragmentType: *warpMatrixInfo);
927 if (failed(Result: coords))
928 return rewriter.notifyMatchFailure(arg&: op, msg: "no coords");
929
930 Value el =
931 rewriter.create<vector::ExtractOp>(location: loc, args&: matrix, args: ArrayRef<int64_t>{i});
932 SmallVector<Value, 4> newIndices;
933 getXferIndices<vector::TransferWriteOp>(
934 rewriter, xferOp: op, offsetMap: *coords, dimValues: {laneId, logicalValueId}, indices&: newIndices);
935 rewriter.create<vector::StoreOp>(location: loc, args&: el, args: op.getBase(), args&: newIndices);
936 }
937
938 LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
939 rewriter.eraseOp(op);
940 return success();
941}
942
943static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
944 SmallVectorImpl<int64_t> &results) {
945 for (auto attr : arrayAttr)
946 results.push_back(Elt: cast<IntegerAttr>(Val&: attr).getInt());
947}
948
949static LogicalResult
950convertExtractStridedSlice(RewriterBase &rewriter,
951 vector::ExtractStridedSliceOp op,
952 llvm::DenseMap<Value, Value> &valueMapping) {
953 OpBuilder::InsertionGuard g(rewriter);
954 rewriter.setInsertionPoint(op);
955
956 Location loc = op->getLoc();
957
958 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
959 nvgpu::getWarpMatrixInfo(op);
960 if (failed(Result: warpMatrixInfo))
961 return rewriter.notifyMatchFailure(arg&: op, msg: "no warpMatrixInfo");
962
963 FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
964 nvgpu::getMmaSyncRegisterType(type: *warpMatrixInfo);
965 if (failed(Result: mmaSyncFragmentInfo))
966 return rewriter.notifyMatchFailure(arg&: op, msg: "no mmaSyncFragmentInfo");
967
968 // Find the vector.transer_read whose result vector is being sliced.
969 auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
970 if (!transferReadOp)
971 return rewriter.notifyMatchFailure(arg&: op, msg: "no transfer read");
972
973 warpMatrixInfo = nvgpu::getWarpMatrixInfo(op: transferReadOp);
974 if (failed(Result: warpMatrixInfo))
975 return rewriter.notifyMatchFailure(arg&: op, msg: "no warpMatrixInfo");
976
977 FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
978 nvgpu::getMmaSyncRegisterType(type: *warpMatrixInfo);
979 if (failed(Result: ldFragmentInfo))
980 return rewriter.notifyMatchFailure(arg&: op, msg: "no ldFragmentInfo");
981
982 assert(
983 (mmaSyncFragmentInfo->elementsPerRegister ==
984 ldFragmentInfo->elementsPerRegister) &&
985 "Number of elements per register should be same for load and mma.sync");
986
987 // Create vector.extract_strided_slice op for thread-owned fragments.
988 std::array<int64_t, 2> strides = {1,
989 1}; // stride for extract slice is always 1.
990 std::array<int64_t, 2> sliceShape = {
991 mmaSyncFragmentInfo->numRegistersPerFragment,
992 mmaSyncFragmentInfo->elementsPerRegister};
993 auto it = valueMapping.find(Val: transferReadOp);
994 if (it == valueMapping.end())
995 return rewriter.notifyMatchFailure(arg&: op, msg: "no mapping");
996 auto sourceVector = it->second;
997
998 // offset and sizes at warp-level of onwership.
999 SmallVector<int64_t> offsets;
1000 populateFromInt64AttrArray(arrayAttr: op.getOffsets(), results&: offsets);
1001
1002 SmallVector<int64_t> sizes;
1003 populateFromInt64AttrArray(arrayAttr: op.getSizes(), results&: sizes);
1004 ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
1005
1006 // Compute offset in vector registers. Note that the mma.sync vector registers
1007 // are shaped as numberOfFragments x numberOfRegistersPerfFragment. The vector
1008 // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
1009 std::array<int64_t, 2> sliceOffset = {0, 0};
1010
1011 if (offsets[0] && offsets[1])
1012 return op->emitError() << "Slicing fragments in 2D is not supported. ";
1013 if (offsets[0])
1014 sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
1015 else if (offsets[1])
1016 sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
1017
1018 Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
1019 location: loc, args&: sourceVector, args&: sliceOffset, args&: sliceShape, args&: strides);
1020
1021 valueMapping[op] = newOp;
1022 return success();
1023}
1024
1025static LogicalResult
1026convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
1027 llvm::DenseMap<Value, Value> &valueMapping) {
1028 OpBuilder::InsertionGuard g(rewriter);
1029 rewriter.setInsertionPoint(op);
1030
1031 auto itA = valueMapping.find(Val: op.getLhs());
1032 auto itB = valueMapping.find(Val: op.getRhs());
1033 auto itC = valueMapping.find(Val: op.getAcc());
1034 if (itA == valueMapping.end() || itB == valueMapping.end() ||
1035 itC == valueMapping.end())
1036 return rewriter.notifyMatchFailure(arg&: op, msg: "no mapping");
1037 Value opA = itA->second, opB = itB->second, opC = itC->second;
1038 Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
1039 location: op.getLoc(), args: opC.getType(), args&: opA, args&: opB, args&: opC, /*a_transpose=*/args: UnitAttr(),
1040 /*b_transpose=*/args: UnitAttr());
1041 valueMapping[op.getResult()] = matmul;
1042 return success();
1043}
1044
1045static LogicalResult
1046convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
1047 llvm::DenseMap<Value, Value> &valueMapping) {
1048 OpBuilder::InsertionGuard g(rewriter);
1049 rewriter.setInsertionPoint(op);
1050
1051 auto itA = valueMapping.find(Val: op.getLhs());
1052 auto itB = valueMapping.find(Val: op.getRhs());
1053 auto itC = valueMapping.find(Val: op.getAcc());
1054 if (itA == valueMapping.end() || itB == valueMapping.end() ||
1055 itC == valueMapping.end())
1056 return rewriter.notifyMatchFailure(arg&: op, msg: "no mapping");
1057 Value opA = itA->second, opB = itB->second, opC = itC->second;
1058 int64_t m = cast<VectorType>(Val: op.getLhs().getType()).getShape()[0];
1059 int64_t n = cast<VectorType>(Val: op.getRhs().getType()).getShape()[0];
1060 int64_t k = cast<VectorType>(Val: op.getLhs().getType()).getShape()[1];
1061 Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
1062 location: op.getLoc(), args&: opA, args&: opB, args&: opC, args: rewriter.getI64ArrayAttr(values: {m, n, k}));
1063 valueMapping[op.getResult()] = matmul;
1064 return success();
1065}
1066
1067/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
1068static LogicalResult
1069convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
1070 llvm::DenseMap<Value, Value> &valueMapping) {
1071 OpBuilder::InsertionGuard g(rewriter);
1072 rewriter.setInsertionPoint(op);
1073
1074 assert(constantSupportsMMAMatrixType(op));
1075
1076 auto splat =
1077 cast<SplatElementsAttr>(Val: op.getValue()).getSplatValue<TypedAttr>();
1078 auto scalarConstant =
1079 rewriter.create<arith::ConstantOp>(location: op.getLoc(), args: splat.getType(), args&: splat);
1080 const char *fragType = inferFragType(op);
1081 auto vecType = cast<VectorType>(Val: op.getType());
1082 gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1083 shape: vecType.getShape(), elementType: vecType.getElementType(), operand: llvm::StringRef(fragType));
1084 auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1085 location: op.getLoc(), args&: type, args&: scalarConstant);
1086 valueMapping[op.getResult()] = matrix;
1087 return success();
1088}
1089
1090/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
1091static LogicalResult
1092convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
1093 llvm::DenseMap<Value, Value> &valueMapping) {
1094 OpBuilder::InsertionGuard g(rewriter);
1095 rewriter.setInsertionPoint(op);
1096
1097 assert(broadcastSupportsMMAMatrixType(op));
1098
1099 const char *fragType = inferFragType(op);
1100 auto vecType = op.getResultVectorType();
1101 gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1102 shape: vecType.getShape(), elementType: vecType.getElementType(), operand: llvm::StringRef(fragType));
1103 auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1104 location: op.getLoc(), args&: type, args: op.getSource());
1105 valueMapping[op.getResult()] = matrix;
1106 return success();
1107}
1108
1109// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
1110// updated and needs to be updated separately for the loop to be correct.
1111static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
1112 scf::ForOp loop,
1113 ValueRange newInitArgs) {
1114 OpBuilder::InsertionGuard g(rewriter);
1115 rewriter.setInsertionPoint(loop);
1116
1117 // Create a new loop before the existing one, with the extra operands.
1118 rewriter.setInsertionPoint(loop);
1119 auto operands = llvm::to_vector<4>(Range: loop.getInitArgs());
1120 llvm::append_range(C&: operands, R&: newInitArgs);
1121 scf::ForOp newLoop = rewriter.create<scf::ForOp>(
1122 location: loop.getLoc(), args: loop.getLowerBound(), args: loop.getUpperBound(), args: loop.getStep(),
1123 args&: operands);
1124 rewriter.eraseBlock(block: newLoop.getBody());
1125
1126 newLoop.getRegion().getBlocks().splice(
1127 where: newLoop.getRegion().getBlocks().begin(), L2&: loop.getRegion().getBlocks());
1128 for (Value operand : newInitArgs)
1129 newLoop.getBody()->addArgument(type: operand.getType(), loc: operand.getLoc());
1130
1131 for (auto it : llvm::zip(t: loop.getResults(), u: newLoop.getResults().take_front(
1132 n: loop.getNumResults())))
1133 rewriter.replaceAllUsesWith(from: std::get<0>(t&: it), to: std::get<1>(t&: it));
1134
1135 LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
1136 LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
1137 LLVM_DEBUG(DBGS() << "erase: " << loop);
1138
1139 rewriter.eraseOp(op: loop);
1140 return newLoop;
1141}
1142
1143static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
1144 llvm::DenseMap<Value, Value> &valueMapping) {
1145 OpBuilder::InsertionGuard g(rewriter);
1146 rewriter.setInsertionPoint(op);
1147
1148 SmallVector<Value> newOperands;
1149 SmallVector<std::pair<size_t, size_t>> argMapping;
1150 for (const auto &operand : llvm::enumerate(First: op.getInitArgs())) {
1151 auto it = valueMapping.find(Val: operand.value());
1152 if (it == valueMapping.end()) {
1153 LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
1154 continue;
1155 }
1156 argMapping.push_back(Elt: std::make_pair(
1157 x: operand.index(), y: op.getInitArgs().size() + newOperands.size()));
1158 newOperands.push_back(Elt: it->second);
1159 }
1160
1161 scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, loop: op, newInitArgs: newOperands);
1162 Block &loopBody = *newForOp.getBody();
1163 for (auto mapping : argMapping) {
1164 valueMapping[newForOp.getResult(i: mapping.first)] =
1165 newForOp.getResult(i: mapping.second);
1166 valueMapping[loopBody.getArgument(i: mapping.first +
1167 newForOp.getNumInductionVars())] =
1168 loopBody.getArgument(i: mapping.second + newForOp.getNumInductionVars());
1169 }
1170
1171 LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
1172 return success();
1173}
1174
1175static LogicalResult
1176convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
1177 llvm::DenseMap<Value, Value> &valueMapping) {
1178 OpBuilder::InsertionGuard g(rewriter);
1179 rewriter.setInsertionPoint(op);
1180
1181 auto loop = cast<scf::ForOp>(Val: op->getParentOp());
1182 auto yieldOperands = llvm::to_vector<4>(Range: op.getOperands());
1183 for (const auto &operand : llvm::enumerate(First: op.getOperands())) {
1184 auto it = valueMapping.find(Val: operand.value());
1185 if (it == valueMapping.end())
1186 continue;
1187 // Replace the yield of old value with the for op argument to make it easier
1188 // to remove the dead code.
1189 yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
1190 yieldOperands.push_back(Elt: it->second);
1191 }
1192 rewriter.create<scf::YieldOp>(location: op.getLoc(), args&: yieldOperands);
1193
1194 LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
1195 rewriter.eraseOp(op);
1196 return success();
1197}
1198
1199/// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
1200static LogicalResult
1201convertElementwiseOp(RewriterBase &rewriter, Operation *op,
1202 gpu::MMAElementwiseOp opType,
1203 llvm::DenseMap<Value, Value> &valueMapping) {
1204 OpBuilder::InsertionGuard g(rewriter);
1205 rewriter.setInsertionPoint(op);
1206
1207 SmallVector<Value> matrixOperands;
1208 for (Value operand : op->getOperands()) {
1209 auto it = valueMapping.find(Val: operand);
1210 if (it == valueMapping.end())
1211 return rewriter.notifyMatchFailure(arg&: op, msg: "no mapping");
1212 matrixOperands.push_back(Elt: it->second);
1213 }
1214 auto resultType = cast<gpu::MMAMatrixType>(Val: matrixOperands[0].getType());
1215 if (opType == gpu::MMAElementwiseOp::EXTF) {
1216 // The floating point extension case has a different result type.
1217 auto vectorType = cast<VectorType>(Val: op->getResultTypes()[0]);
1218 resultType = gpu::MMAMatrixType::get(shape: resultType.getShape(),
1219 elementType: vectorType.getElementType(),
1220 operand: resultType.getOperand());
1221 }
1222
1223 Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
1224 location: op->getLoc(), args&: resultType, args&: matrixOperands, args&: opType);
1225 valueMapping[op->getResult(idx: 0)] = newOp;
1226 return success();
1227}
1228
1229void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
1230 bool useNvGpu) {
1231 if (!useNvGpu) {
1232 patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
1233 arg: patterns.getContext());
1234 return;
1235 }
1236 vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
1237 patterns.add<CombineTransferReadOpTranspose>(arg: patterns.getContext());
1238}
1239
1240LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
1241 Operation *rootOp) {
1242 SetVector<Operation *> ops = getOpToConvert(op: rootOp, /*useNvGpu=*/false);
1243 llvm::DenseMap<Value, Value> valueMapping;
1244
1245 auto globalRes = LogicalResult::success();
1246 for (Operation *op : ops) {
1247 LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
1248 // Apparently callers do not want to early exit on failure here.
1249 auto res = LogicalResult::success();
1250 if (auto transferRead = dyn_cast<vector::TransferReadOp>(Val: op)) {
1251 res = convertTransferReadOp(rewriter, op: transferRead, valueMapping);
1252 } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(Val: op)) {
1253 res = convertTransferWriteOp(rewriter, op: transferWrite, valueMapping);
1254 } else if (auto contractOp = dyn_cast<vector::ContractionOp>(Val: op)) {
1255 res = convertContractOp(rewriter, op: contractOp, valueMapping);
1256 } else if (auto constantOp = dyn_cast<arith::ConstantOp>(Val: op)) {
1257 res = convertConstantOp(rewriter, op: constantOp, valueMapping);
1258 } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(Val: op)) {
1259 res = convertBroadcastOp(rewriter, op: broadcastOp, valueMapping);
1260 } else if (auto forOp = dyn_cast<scf::ForOp>(Val: op)) {
1261 res = convertForOp(rewriter, op: forOp, valueMapping);
1262 } else if (auto yieldOp = dyn_cast<scf::YieldOp>(Val: op)) {
1263 res = convertYieldOp(rewriter, op: yieldOp, valueMapping);
1264 } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
1265 res = convertElementwiseOp(rewriter, op, opType: *elementwiseType, valueMapping);
1266 }
1267 if (failed(Result: res))
1268 globalRes = failure();
1269 }
1270 return globalRes;
1271}
1272
1273LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
1274 Operation *rootOp) {
1275 SetVector<Operation *> ops = getOpToConvert(op: rootOp, /*useNvGpu=*/true);
1276 llvm::DenseMap<Value, Value> valueMapping;
1277 for (Operation *op : ops) {
1278 if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
1279 .Case(caseFn: [&](vector::TransferReadOp transferReadOp) {
1280 return convertTransferReadToLoads(rewriter, op: transferReadOp,
1281 valueMapping);
1282 })
1283 .Case(caseFn: [&](vector::TransferWriteOp transferWriteOp) {
1284 return convertTransferWriteToStores(rewriter, op: transferWriteOp,
1285 valueMapping);
1286 })
1287 .Case(caseFn: [&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
1288 return convertExtractStridedSlice(rewriter, op: extractStridedSliceOp,
1289 valueMapping);
1290 })
1291 .Case(caseFn: [&](vector::ContractionOp contractionOp) {
1292 return convertContractOpToMmaSync(rewriter, op: contractionOp,
1293 valueMapping);
1294 })
1295 .Case(caseFn: [&](scf::ForOp forOp) {
1296 return convertForOp(rewriter, op: forOp, valueMapping);
1297 })
1298 .Case(caseFn: [&](scf::YieldOp yieldOp) {
1299 return convertYieldOp(rewriter, op: yieldOp, valueMapping);
1300 })
1301 .Case(caseFn: [&](arith::ConstantOp constOp) {
1302 return convertConstantOpMmaSync(rewriter, op: constOp, valueMapping);
1303 })
1304 .Default(defaultFn: [&](Operation *op) {
1305 return op->emitError() << "unhandled vector to mma type: " << *op;
1306 })
1307 .failed()) {
1308 return op->emitOpError()
1309 << "failed to convert op during vector-to-nvgpu conversion";
1310 }
1311 }
1312 return success();
1313}
1314
1315namespace {
1316
1317struct ConvertVectorToGPUPass
1318 : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
1319
1320 explicit ConvertVectorToGPUPass(bool useNvGpu_) {
1321 useNvGpu.setValue(V: useNvGpu_);
1322 }
1323
1324 void runOnOperation() override {
1325 RewritePatternSet patterns(&getContext());
1326 populatePrepareVectorToMMAPatterns(patterns, useNvGpu: useNvGpu.getValue());
1327 if (failed(Result: applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns))))
1328 return signalPassFailure();
1329
1330 IRRewriter rewriter(&getContext());
1331 if (useNvGpu) {
1332 if (failed(
1333 Result: convertVectorToNVVMCompatibleMMASync(rewriter, rootOp: getOperation())))
1334 return signalPassFailure();
1335 return;
1336 }
1337 (void)convertVectorToMMAOps(rewriter, rootOp: getOperation());
1338 }
1339};
1340
1341} // namespace
1342
1343std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
1344 return std::make_unique<ConvertVectorToGPUPass>(args&: useNvGpu);
1345}
1346

source code of mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp