1//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements lowering of vector operations to GPU dialect ops.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
14
15#include <type_traits>
16
17#include "mlir/Analysis/SliceAnalysis.h"
18#include "mlir/Analysis/TopologicalSortUtils.h"
19#include "mlir/Dialect/Affine/IR/AffineOps.h"
20#include "mlir/Dialect/Arith/IR/Arith.h"
21#include "mlir/Dialect/GPU/IR/GPUDialect.h"
22#include "mlir/Dialect/MemRef/IR/MemRef.h"
23#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
24#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
25#include "mlir/Dialect/SCF/IR/SCF.h"
26#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
27#include "mlir/Dialect/Vector/IR/VectorOps.h"
28#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
29#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
30#include "mlir/IR/Builders.h"
31#include "mlir/IR/BuiltinOps.h"
32#include "mlir/IR/Region.h"
33#include "mlir/Pass/Pass.h"
34#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
35#include "mlir/Transforms/Passes.h"
36#include "llvm/ADT/STLExtras.h"
37#include "llvm/ADT/TypeSwitch.h"
38
39#define DEBUG_TYPE "vector-to-gpu"
40#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
41#define DBGSNL() (llvm::dbgs() << "\n")
42
43namespace mlir {
44#define GEN_PASS_DEF_CONVERTVECTORTOGPU
45#include "mlir/Conversion/Passes.h.inc"
46} // namespace mlir
47
48using namespace mlir;
49
50/// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
51/// AffineMap representing offsets to apply to indices, the function fills
52/// `indices` with the original indices plus the offsets. The offsets are
53/// applied by taking into account the permutation map of the transfer op. If
54/// the `offsetMap` has dimension placeholders, those should be provided in
55/// `dimValues`.
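/// For illustration (with a hypothetical offset map): given an offset map
/// (lane) -> (lane floordiv 4) and `dimValues` = {%laneId}, a permuted
/// dimension whose original index is %i is rewritten to an affine.apply
/// computing %i + (%laneId floordiv 4).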
56template <typename TransferOpType>
57static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
58 AffineMap offsetMap, ArrayRef<Value> dimValues,
59 SmallVector<Value, 4> &indices) {
60 indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
61 Location loc = xferOp.getLoc();
62 unsigned offsetsIdx = 0;
63 for (auto expr : xferOp.getPermutationMap().getResults()) {
64 if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
65 Value prevIdx = indices[dim.getPosition()];
66 SmallVector<OpFoldResult, 3> dims(dimValues);
67 dims.push_back(prevIdx);
      AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
      indices[dim.getPosition()] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
71 continue;
72 }
73 }
74}
75
// Return true if the contract op can be converted to an MMA matmul.
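// With iterators [parallel, parallel, reduction] bound to (m, n, k), the
// default wmma path expects the indexing maps {(m, k), (k, n), (m, n)}, while
// the nvgpu mma.sync path expects the B operand transposed, i.e.
// {(m, k), (n, k), (m, n)}.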
77static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
78 bool useNvGpu) {
79 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
80 auto infer = [&](MapList m) {
81 return AffineMap::inferFromExprList(m, contract.getContext());
82 };
83 AffineExpr m, n, k;
84 bindDims(contract.getContext(), m, n, k);
85 auto iteratorTypes = contract.getIteratorTypes().getValue();
  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
        vector::isParallelIterator(iteratorTypes[1]) &&
        vector::isReductionIterator(iteratorTypes[2])))
89 return false;
90
91 // The contract needs to represent a matmul to be able to convert to
92 // MMAMatrix matmul.
93 if (!useNvGpu &&
94 contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
95 return false;
96 if (useNvGpu &&
97 contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
98 return false;
99
100 return true;
101}
102
103// Return true if the given map represents a transposed matrix load,
104// i.e. (d0, d1, ...) -> (dn-1, dn-2).
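// For example, with two dimensions both affine_map<(d0, d1) -> (d1, d0)> and
// the transposed+broadcast form affine_map<(d0, d1) -> (d1, 0)> are accepted.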
105static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
106 MLIRContext *ctx = permutationMap.getContext();
107 // Local OpBuilder is fine here, we just build attributes.
108 OpBuilder b(ctx);
109 auto nDim = permutationMap.getNumDims();
  AffineExpr zero = b.getAffineConstantExpr(0);
  if (nDim < 2) {
    // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
    AffineExpr dim0 = b.getAffineDimExpr(0);
    return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
  }

  AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
  AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
  // Support both transposed and transposed+broadcasted cases.
  return permutationMap == AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
         permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
122}
123
// Return the stride of the second-to-last dimension of |type| if it is a
// memref and has a constant stride.
126static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
127 auto memrefType = dyn_cast<MemRefType>(type);
128 if (!memrefType)
    return std::nullopt;
130 // If the memref is 0 or 1D the horizontal stride is 0.
131 if (memrefType.getRank() < 2)
132 return 0;
133 int64_t offset = 0;
134 SmallVector<int64_t, 2> strides;
135 if (failed(memrefType.getStridesAndOffset(strides, offset)) ||
136 strides.back() != 1)
137 return std::nullopt;
138 int64_t stride = strides[strides.size() - 2];
139 if (stride == ShapedType::kDynamic)
140 return std::nullopt;
141 return stride;
142}
143
// Return true if the transfer op can be converted to an MMA matrix load.
145static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
146 if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
147 readOp.getVectorType().getRank() != 2)
148 return false;
149 if (!getStaticallyKnownRowStride(readOp.getShapedType()))
150 return false;
151
152 // Only allow integer types if the signedness can be inferred.
153 if (readOp.getVectorType().getElementType().isInteger(8))
154 if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
155 !isa<arith::ExtUIOp>(*readOp->user_begin())))
156 return false;
157
158 AffineMap map = readOp.getPermutationMap();
159 MLIRContext *ctx = readOp.getContext();
  AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
  AffineExpr zero = getAffineConstantExpr(0, ctx);
  auto broadcastInnerDim =
      AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
  return map.isMinorIdentity() || map == broadcastInnerDim ||
         isTransposeMatrixLoadMap(map);
166}
167
// Return true if the transfer op can be converted to an MMA matrix store.
169static bool
170transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
171 // TODO: support 0-d corner case.
172 if (writeOp.getTransferRank() == 0)
173 return false;
174
175 if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
176 writeOp.getVectorType().getRank() != 2)
177 return false;
178 if (!getStaticallyKnownRowStride(writeOp.getShapedType()))
179 return false;
180 // TODO: Support transpose once it is added to GPU dialect ops.
181 if (!writeOp.getPermutationMap().isMinorIdentity())
182 return false;
183 return true;
184}
185
186/// Return true if the constant is a splat to a 2D vector so that it can be
/// converted to an MMA constant matrix op.
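/// For example, `arith.constant dense<0.0> : vector<16x16xf16>` is a splat,
/// whereas a dense constant with distinct values is not.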
188static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
189 auto vecType = dyn_cast<VectorType>(constantOp.getType());
190 if (!vecType || vecType.getRank() != 2)
191 return false;
192 return isa<SplatElementsAttr>(constantOp.getValue());
193}
194
195/// Return true if this is a broadcast from scalar to a 2D vector.
196static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
197 return broadcastOp.getResultVectorType().getRank() == 2;
198}
199
200/// Return true if this integer extend op can be folded into a contract op.
201template <typename ExtOpTy>
202static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
203 auto transferReadOp =
204 extOp.getOperand().template getDefiningOp<vector::TransferReadOp>();
205 if (!transferReadOp)
206 return false;
207 return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
208}
209
210static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
211
212/// Return the MMA elementwise enum associated with `op` if it is supported.
213/// Return `std::nullopt` otherwise.
214static std::optional<gpu::MMAElementwiseOp>
215convertElementwiseOpToMMA(Operation *op) {
216 if (isa<arith::AddFOp>(op))
217 return gpu::MMAElementwiseOp::ADDF;
218 if (isa<arith::MulFOp>(op))
219 return gpu::MMAElementwiseOp::MULF;
220 if (isa<arith::SubFOp>(op))
221 return gpu::MMAElementwiseOp::SUBF;
222 if (isa<arith::MaximumFOp>(op))
223 return gpu::MMAElementwiseOp::MAXF;
224 if (isa<arith::MinimumFOp>(op))
225 return gpu::MMAElementwiseOp::MINF;
226 if (isa<arith::DivFOp>(op))
227 return gpu::MMAElementwiseOp::DIVF;
228 if (isa<arith::AddIOp>(op))
229 return gpu::MMAElementwiseOp::ADDI;
230 if (isa<arith::MulIOp>(op))
231 return gpu::MMAElementwiseOp::MULI;
232 if (isa<arith::SubIOp>(op))
233 return gpu::MMAElementwiseOp::SUBI;
234 if (isa<arith::DivSIOp>(op))
235 return gpu::MMAElementwiseOp::DIVS;
236 if (isa<arith::DivUIOp>(op))
237 return gpu::MMAElementwiseOp::DIVU;
238 if (isa<arith::NegFOp>(op))
239 return gpu::MMAElementwiseOp::NEGATEF;
240 if (isa<arith::ExtFOp>(op))
241 return gpu::MMAElementwiseOp::EXTF;
242 return std::nullopt;
243}
244
245/// Return true if the op is supported as elementwise op on MMAMatrix type.
246static bool elementwiseSupportsMMAMatrixType(Operation *op) {
247 return convertElementwiseOpToMMA(op).has_value();
248}
249
/// Returns true if the extract strided slice op is supported on the
/// `mma.sync` path.
252static bool
253extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
254
255 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
256 nvgpu::getWarpMatrixInfo(op);
257 if (failed(warpMatrixInfo))
258 return false;
259
260 FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
261 if (failed(contractOp))
262 return false;
263
264 // Handle vector.extract_strided_slice on registers containing
265 // matrixB and matrixC operands. vector.extract_strided_slice op
266 // is not supported on registers containing matrixA operands.
267 if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
268 return (cast<VectorType>(op->getResult(0).getType()) ==
269 cast<VectorType>((*contractOp).getRhs().getType()));
270 if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
271 return (cast<VectorType>(op->getResult(0).getType()) ==
272 cast<VectorType>((*contractOp).getAcc().getType()));
273
274 return false;
275}
276
277static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
278 if (isa<scf::ForOp, scf::YieldOp>(op))
279 return true;
280 if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
281 return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
282 : transferReadSupportsMMAMatrixType(transferRead);
283 if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
284 return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferWrite)
285 : transferWriteSupportsMMAMatrixType(transferWrite);
286 if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
287 return useNvGpu &&
288 extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
289 if (auto contract = dyn_cast<vector::ContractionOp>(op))
290 return contractSupportsMMAMatrixType(contract, useNvGpu);
291 if (auto constant = dyn_cast<arith::ConstantOp>(op))
292 return constantSupportsMMAMatrixType(constant);
293 if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
294 return broadcastSupportsMMAMatrixType(broadcast);
295 if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
296 return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
297 if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
298 return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
299 if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
300 return fpExtendSupportsMMAMatrixType(fpExtend);
301 return elementwiseSupportsMMAMatrixType(op);
302}
303
/// Return an unsorted slice, handling the scf.for region differently than
/// `getSlice`: for scf.for we only want to include as part of the slice the
/// elements that are part of the use/def chain.
307static SetVector<Operation *>
308getSliceContract(Operation *op,
309 const BackwardSliceOptions &backwardSliceOptions,
310 const ForwardSliceOptions &forwardSliceOptions) {
311 SetVector<Operation *> slice;
312 slice.insert(X: op);
313 unsigned currentIndex = 0;
314 SetVector<Operation *> backwardSlice;
315 SetVector<Operation *> forwardSlice;
316 while (currentIndex != slice.size()) {
317 auto *currentOp = (slice)[currentIndex];
318 // Compute and insert the backwardSlice starting from currentOp.
319 backwardSlice.clear();
320 LogicalResult result =
321 getBackwardSlice(op: currentOp, backwardSlice: &backwardSlice, options: backwardSliceOptions);
322 assert(result.succeeded() && "expected a backward slice");
323 (void)result;
324 slice.insert_range(R&: backwardSlice);
325
326 // Compute and insert the forwardSlice starting from currentOp.
327 forwardSlice.clear();
    // Special case for ForOp: we don't want to include the whole region, but
    // only the values using the region arguments.
330 // TODO: We should refine this to only care about the region arguments being
331 // converted to matrix type.
332 if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
333 for (Value forOpResult : forOp.getResults())
334 getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
335 for (BlockArgument &arg : forOp.getRegionIterArgs())
336 getForwardSlice(arg, &forwardSlice, forwardSliceOptions);
337 } else {
338 getForwardSlice(op: currentOp, forwardSlice: &forwardSlice, options: forwardSliceOptions);
339 }
340 slice.insert_range(R&: forwardSlice);
341 ++currentIndex;
342 }
343 return slice;
344}
345
// Analyze the slice of operations rooted at each vector.contraction op to
// figure out if the whole slice can be converted to MMA operations.
348static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
349 bool useNvGpu) {
350 auto hasVectorDest = [](Operation *op) {
351 return llvm::any_of(Range: op->getResultTypes(), P: llvm::IsaPred<VectorType>);
352 };
353 BackwardSliceOptions backwardSliceOptions;
354 backwardSliceOptions.filter = hasVectorDest;
355
356 auto hasVectorSrc = [](Operation *op) {
357 return llvm::any_of(Range: op->getOperandTypes(), P: llvm::IsaPred<VectorType>);
358 };
359 ForwardSliceOptions forwardSliceOptions;
360 forwardSliceOptions.filter = hasVectorSrc;
361
362 SetVector<Operation *> opToConvert;
363 op->walk([&](vector::ContractionOp contract) {
364 if (opToConvert.contains(key: contract.getOperation()))
365 return;
366 SetVector<Operation *> dependentOps =
367 getSliceContract(contract, backwardSliceOptions, forwardSliceOptions);
    // If any operation in the slice cannot use the MMA matrix type, drop the
    // whole chain. MMA matrices are stored in an opaque type so they cannot be
    // used by all operations.
371 if (llvm::any_of(Range&: dependentOps, P: [useNvGpu](Operation *op) {
372 if (!supportsMMaMatrixType(op, useNvGpu)) {
373 LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
374 return true;
375 }
376 return false;
377 }))
378 return;
379
380 opToConvert.insert_range(R&: dependentOps);
381 });
382 // Sort the operations so that we can convert them in topological order.
383 return topologicalSort(toSort: opToConvert);
384}
385
386namespace {
387// Transform contract into (m, k)x(k, n)x(m, n) form so that it can be converted
388// to MMA matmul.
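// Illustrative example: a contraction with indexing maps
// {(m, k), (n, k), (m, n)} is rewritten by transposing the RHS with
// vector.transpose so that the new contraction uses the canonical
// {(m, k), (k, n), (m, n)} maps handled by the MMA lowering.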
389struct PrepareContractToGPUMMA
390 : public OpRewritePattern<vector::ContractionOp> {
391 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
392
393 LogicalResult matchAndRewrite(vector::ContractionOp op,
394 PatternRewriter &rewriter) const override {
395 Location loc = op.getLoc();
396 Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();
397
398 // Set up the parallel/reduction structure in right form.
399 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
400 auto infer = [&](MapList m) {
401 return AffineMap::inferFromExprList(m, op.getContext());
402 };
403 AffineExpr m, n, k;
404 bindDims(ctx: rewriter.getContext(), exprs&: m, exprs&: n, exprs&: k);
405 static constexpr std::array<int64_t, 2> perm = {1, 0};
406 auto iteratorTypes = op.getIteratorTypes().getValue();
407 SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
408 if (!(vector::isParallelIterator(attr: iteratorTypes[0]) &&
409 vector::isParallelIterator(attr: iteratorTypes[1]) &&
410 vector::isReductionIterator(attr: iteratorTypes[2])))
411 return rewriter.notifyMatchFailure(op, "not a gemm contraction");
412 //
413 // Two outer parallel, one inner reduction (matmat flavor).
414 //
415 // This is the classical row-major matmul, nothing to do.
416 if (maps == infer({{m, k}, {k, n}, {m, n}}))
417 return rewriter.notifyMatchFailure(op, "contraction already prepared");
418 if (maps == infer({{m, k}, {n, k}, {m, n}})) {
419 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
420 } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
421 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
422 } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
423 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
424 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
425 } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
426 std::swap(a&: rhs, b&: lhs);
427 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
428 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
429 } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
430 std::swap(a&: rhs, b&: lhs);
431 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
432 } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
433 std::swap(a&: lhs, b&: rhs);
434 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
435 } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
436 std::swap(a&: lhs, b&: rhs);
437 } else {
438 // TODO: llvm_unreachable ?
439 return rewriter.notifyMatchFailure(op, "unexpected contraction case");
440 }
441 rewriter.replaceOpWithNewOp<vector::ContractionOp>(
442 op, lhs, rhs, res,
443 rewriter.getAffineMapArrayAttr(values: infer({{m, k}, {k, n}, {m, n}})),
444 op.getIteratorTypes());
445 return success();
446 }
447};
448
449// Fold transpose op into the transfer read op. NVGPU mma.sync op only supports
450// row-, column-, and row-major layout for matrixA, matrixB, and matrixC,
451// respectively. We can fold the transpose operation when loading the data from
452// Shared Memory to registers.
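// For example, a vector.transpose with permutation [1, 0] of a
// vector.transfer_read with a minor identity map becomes a single
// transfer_read with permutation map (d0, d1) -> (d1, d0); any arith.ext op
// sitting between the read and the transpose is re-created on top of the new
// read.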
453struct CombineTransferReadOpTranspose final
454 : public OpRewritePattern<vector::TransposeOp> {
455 using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
456
457 LogicalResult matchAndRewrite(vector::TransposeOp op,
458 PatternRewriter &rewriter) const override {
459 // Look through integer extend ops.
460 Value source = op.getVector();
461 Type resultType = op.getType();
462 Operation *extOp;
463 if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
464 (extOp = source.getDefiningOp<arith::ExtUIOp>()) ||
465 (extOp = source.getDefiningOp<arith::ExtFOp>())) {
466 source = extOp->getOperand(idx: 0);
467 resultType =
468 VectorType::get(cast<VectorType>(resultType).getShape(),
469 cast<VectorType>(source.getType()).getElementType());
470 }
471
472 auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
473 if (!transferReadOp)
474 return rewriter.notifyMatchFailure(op, "no transfer read");
475
476 // TODO: support 0-d corner case.
477 if (transferReadOp.getTransferRank() == 0)
478 return rewriter.notifyMatchFailure(op, "0-D transfer read");
479
480 if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
481 return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
482
483 AffineMap permutationMap =
484 AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
485 AffineMap newMap =
486 permutationMap.compose(transferReadOp.getPermutationMap());
487
488 auto loc = op.getLoc();
489 Value result =
490 rewriter
491 .create<vector::TransferReadOp>(
492 loc, resultType, transferReadOp.getBase(),
493 transferReadOp.getIndices(), AffineMapAttr::get(newMap),
494 transferReadOp.getPadding(), transferReadOp.getMask(),
495 transferReadOp.getInBoundsAttr())
496 .getResult();
497
498 // Fuse through the integer extend op.
499 if (extOp) {
500 if (isa<arith::ExtSIOp>(extOp))
501 result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
502 .getResult();
503 else if (isa<arith::ExtUIOp>(extOp))
504 result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
505 .getResult();
506 else
507 result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result)
508 .getResult();
509 }
510
511 rewriter.replaceOp(op, result);
512 return success();
513 }
514};
515
516} // namespace
517
518// MMA types have different layout based on how they are used in matmul ops.
519// Figure the right layout to use by looking at op uses.
// TODO: Change the GPU dialect to abstract the layout at this level and
521// only care about it during lowering to NVVM.
522static const char *inferFragType(Operation *op) {
523 // We can have arith.ext ops before reaching contract ops. See through them
524 // and other kinds of elementwise ops.
525 if (op->hasOneUse()) {
526 Operation *userOp = *op->user_begin();
527 if (userOp->hasTrait<OpTrait::Elementwise>())
528 return inferFragType(op: userOp);
529 }
530
531 for (Operation *users : op->getUsers()) {
532 auto contract = dyn_cast<vector::ContractionOp>(users);
533 if (!contract)
534 continue;
535 assert(op->getNumResults() == 1);
536 if (contract.getLhs() == op->getResult(idx: 0))
537 return "AOp";
538 if (contract.getRhs() == op->getResult(idx: 0))
539 return "BOp";
540 }
541 return "COp";
542}
543
544static LogicalResult
545convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
546 llvm::DenseMap<Value, Value> &valueMapping) {
547 OpBuilder::InsertionGuard g(rewriter);
548 rewriter.setInsertionPoint(op);
549
550 assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
551 assert(transferReadSupportsMMAMatrixType(op) &&
552 "expected convertible operation");
553
554 std::optional<int64_t> stride =
555 getStaticallyKnownRowStride(op.getShapedType());
556 if (!stride.has_value()) {
557 LLVM_DEBUG(DBGS() << "no stride\n");
558 return rewriter.notifyMatchFailure(op, "no stride");
559 }
560
561 AffineMap map = op.getPermutationMap();
562 bool isTranspose = isTransposeMatrixLoadMap(permutationMap: map);
563
564 // Handle broadcast by setting the stride to 0.
565 if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
566 assert(cstExpr.getValue() == 0);
567 stride = 0;
568 }
569
570 Value mappingResult = op.getResult();
571 auto elType = op.getVectorType().getElementType();
572 const char *fragType = inferFragType(op);
573 if (op->hasOneUse()) {
574 auto *user = *op->user_begin();
575 // Infer the signedness of the mma type from the integer extend.
576 if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) {
577 elType = IntegerType::get(
578 op.getContext(), cast<IntegerType>(elType).getWidth(),
579 isa<arith::ExtSIOp>(user) ? IntegerType::Signed
580 : IntegerType::Unsigned);
581 mappingResult = user->getResult(0);
582 }
583 }
584 gpu::MMAMatrixType type =
585 gpu::MMAMatrixType::get(shape: op.getVectorType().getShape(), elementType: elType, operand: fragType);
586 Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
587 op.getLoc(), type, op.getBase(), op.getIndices(),
588 rewriter.getIndexAttr(*stride),
589 isTranspose ? rewriter.getUnitAttr() : UnitAttr());
590 valueMapping[mappingResult] = load;
591
592 LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
593 return success();
594}
595
596static LogicalResult
597convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
598 llvm::DenseMap<Value, Value> &valueMapping) {
599 OpBuilder::InsertionGuard g(rewriter);
600 rewriter.setInsertionPoint(op);
601
602 assert(transferWriteSupportsMMAMatrixType(op));
603 std::optional<int64_t> stride =
604 getStaticallyKnownRowStride(op.getShapedType());
605 if (!stride.has_value()) {
606 LLVM_DEBUG(DBGS() << "no stride\n");
607 return rewriter.notifyMatchFailure(op, "no stride");
608 }
609
610 auto it = valueMapping.find(op.getVector());
611 if (it == valueMapping.end()) {
612 LLVM_DEBUG(DBGS() << "no mapping\n");
613 return rewriter.notifyMatchFailure(op, "no mapping");
614 }
615
616 Value matrix = it->second;
617 auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
618 op.getLoc(), matrix, op.getBase(), op.getIndices(),
619 rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
620 (void)store;
621
622 LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");
623
624 LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
625 rewriter.eraseOp(op: op);
626 return success();
627}
628
629/// Returns the vector type which represents a matrix fragment.
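/// For example (the exact fragment layout depends on the mma.sync shape), a
/// fragment of 4 registers holding 2 f16 elements each would be represented
/// as vector<4x2xf16>.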
630static VectorType
631getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
632 SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
633 regInfo.elementsPerRegister};
634 Type elType = regInfo.registerLLVMType;
635 if (auto vecType = dyn_cast<VectorType>(elType))
636 elType = vecType.getElementType();
637 return VectorType::get(shape, elType);
638}
639
/// Convert a 2D splat ConstantOp to a splat arith.constant of the per-thread
/// mma.sync fragment vector type.
641static LogicalResult
642convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
643 llvm::DenseMap<Value, Value> &valueMapping) {
644 OpBuilder::InsertionGuard g(rewriter);
645 rewriter.setInsertionPoint(op);
646
647 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
648 nvgpu::getWarpMatrixInfo(op: op);
649 if (failed(Result: warpMatrixInfo)) {
650 LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
651 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
652 }
653
654 FailureOr<nvgpu::FragmentElementInfo> regInfo =
655 nvgpu::getMmaSyncRegisterType(type: *warpMatrixInfo);
656 if (failed(Result: regInfo)) {
657 LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
658 return rewriter.notifyMatchFailure(op, "not mma sync reg info");
659 }
660
661 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
662 auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
663 if (!dense) {
664 LLVM_DEBUG(DBGS() << "not a splat\n");
665 return rewriter.notifyMatchFailure(op, "not a splat");
666 }
667
668 Value result = rewriter.create<arith::ConstantOp>(
669 op.getLoc(), vectorType,
670 DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
671 valueMapping[op.getResult()] = result;
672 return success();
673}
674
/// Check if the loaded matrix operand requires a transpose.
/// Transposed map examples:
///   Example 1: (..., d0, d1) -> (d1 * 1, d0 * 2)
///   Example 2: (d0, d1, d2, d3) -> (d3, d2)
/// The code below checks whether the 2D output is transposed, using a
/// generalized form: (d0, d1, ..., dn, ..., dm, ...) -> (dm, dn).
/// Returns true if m > n, false otherwise.
682static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
683 mlir::AffineMap map = op.getPermutationMap();
684
685 if (map.getNumResults() != 2) {
686 LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
687 "is not a 2d operand\n");
688 return failure();
689 }
690
691 // Output 2D matrix dimensions in the order of d0, d1.
692 mlir::AffineExpr dM = map.getResult(idx: 0);
693 mlir::AffineExpr dN = map.getResult(idx: 1);
694
695 // Find the position of these expressions in the input.
696 auto exprM = dyn_cast<AffineDimExpr>(Val&: dM);
697 auto exprN = dyn_cast<AffineDimExpr>(Val&: dN);
698
699 if (!exprM || !exprN) {
700 LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
701 "expressions, then transpose cannot be determined.\n");
702 return failure();
703 }
704
705 return exprM.getPosition() > exprN.getPosition();
706}
707
708static LogicalResult
709creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
710 llvm::DenseMap<Value, Value> &valueMapping) {
711 OpBuilder::InsertionGuard g(rewriter);
712 rewriter.setInsertionPoint(op);
713 Location loc = op->getLoc();
714
715 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
716 nvgpu::getWarpMatrixInfo(op: op);
717 if (failed(Result: warpMatrixInfo)) {
718 LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
719 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
720 }
721
722 FailureOr<nvgpu::FragmentElementInfo> regInfo =
723 nvgpu::getMmaSyncRegisterType(type: *warpMatrixInfo);
724 if (failed(Result: regInfo)) {
725 LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
726 return rewriter.notifyMatchFailure(op, "not mma sync reg info");
727 }
728
729 FailureOr<bool> transpose = isTransposed(op);
730 if (failed(Result: transpose)) {
731 LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
732 return rewriter.notifyMatchFailure(
733 op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
734 }
735
736 FailureOr<nvgpu::LdMatrixParams> params =
737 nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
738
739 if (failed(params)) {
740 LLVM_DEBUG(
741 DBGS()
742 << "failed to convert vector.transfer_read to ldmatrix. "
743 << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
744 return rewriter.notifyMatchFailure(
745 op, "failed to convert vector.transfer_read to ldmatrix; this op "
746 "likely should not be converted to a nvgpu.ldmatrix call.");
747 }
748
749 // Adjust the load offset.
750 auto laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
751 FailureOr<AffineMap> offsets =
752 nvgpu::getLaneIdToLdMatrixMatrixCoord(builder&: rewriter, loc, params: *params);
753 if (failed(Result: offsets)) {
754 LLVM_DEBUG(DBGS() << "no offsets\n");
755 return rewriter.notifyMatchFailure(op, "no offsets");
756 }
757
758 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
759
760 SmallVector<Value, 4> indices;
761 getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
762 indices);
763
764 nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
765 loc, vectorType, op.getBase(), indices, *transpose, params->numTiles);
766 valueMapping[op] = newOp->getResult(0);
767 return success();
768}
769
770static LogicalResult
771createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
772 llvm::DenseMap<Value, Value> &valueMapping) {
773 OpBuilder::InsertionGuard g(rewriter);
774 rewriter.setInsertionPoint(op);
775
776 Location loc = op.getLoc();
777 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
778 nvgpu::getWarpMatrixInfo(op: op);
779 if (failed(Result: warpMatrixInfo))
780 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
781 FailureOr<nvgpu::FragmentElementInfo> regInfo =
782 nvgpu::getMmaSyncRegisterType(type: *warpMatrixInfo);
783 if (failed(Result: regInfo)) {
784 return rewriter.notifyMatchFailure(
785 op, "Failed to deduce register fragment type during "
786 "conversion to distributed non-ldmatrix compatible load");
787 }
788
789 Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
790
791 // This is the individual element type.
792 Type loadedElType = regInfo->registerLLVMType;
793 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
794
795 Value fill = rewriter.create<arith::ConstantOp>(
796 op.getLoc(), vectorType.getElementType(),
797 rewriter.getZeroAttr(vectorType.getElementType()));
798 Value result =
799 rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
800
801 bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
802
803 // If we are not transposing, then we can use vectorized loads. Otherwise, we
804 // must load each element individually.
805 if (!isTransposeLoad) {
806 if (!isa<VectorType>(Val: loadedElType)) {
807 loadedElType = VectorType::get({1}, loadedElType);
808 }
809
810 for (int i = 0; i < vectorType.getShape()[0]; i++) {
811 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
812 builder&: rewriter, loc: op.getLoc(), fragmentType: *warpMatrixInfo);
813 if (failed(Result: coords))
814 return rewriter.notifyMatchFailure(op, "no coords");
815
816 Value logicalValueId = rewriter.create<arith::ConstantOp>(
817 loc, rewriter.getIndexType(),
818 rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
819 SmallVector<Value, 4> newIndices;
820 getXferIndices<vector::TransferReadOp>(
821 rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
822
823 Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
824 op.getBase(), newIndices);
825 result = rewriter.create<vector::InsertOp>(loc, el, result, i);
826 }
827 } else {
828 if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
829 loadedElType = vecType.getElementType();
830 }
831 for (int i = 0; i < vectorType.getShape()[0]; i++) {
832 for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
833 innerIdx++) {
834
835 Value logicalValueId = rewriter.create<arith::ConstantOp>(
836 loc, rewriter.getIndexType(),
837 rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
838 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
839 builder&: rewriter, loc: op.getLoc(), fragmentType: *warpMatrixInfo);
840 if (failed(Result: coords))
841 return rewriter.notifyMatchFailure(op, "no coords");
842
843 SmallVector<Value, 4> newIndices;
844 getXferIndices<vector::TransferReadOp>(
845 rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
846 Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
847 op.getBase(), newIndices);
848 result = rewriter.create<vector::InsertOp>(
849 op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
850 }
851 }
852 }
853
854 valueMapping[op.getResult()] = result;
855 return success();
856}
857
858/// Return true if this is a shared memory memref type.
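/// For example, `memref<64x64xf16, #gpu.address_space<workgroup>>` is shared
/// memory, while a memref without the workgroup address space is not.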
859static bool isSharedMemory(MemRefType type) {
860 auto addressSpace =
861 dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
862 return addressSpace &&
863 addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
864}
865
866/// Converts a `vector.transfer_read` operation directly to either a
867/// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
868/// used when converting to `nvgpu.mma.sync` operations.
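/// The `nvgpu.ldmatrix` path is only taken when the source lives in workgroup
/// (shared) memory and the inferred tile width is 128 bits; otherwise the read
/// is distributed into per-lane `vector.load`/`memref.load` ops by
/// `createNonLdMatrixLoads`.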
869static LogicalResult
870convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
871 llvm::DenseMap<Value, Value> &valueMapping) {
872 OpBuilder::InsertionGuard g(rewriter);
873 rewriter.setInsertionPoint(op);
874
875 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
876 nvgpu::getWarpMatrixInfo(op: op);
877 if (failed(Result: warpMatrixInfo))
878 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
879
880 bool isLdMatrixCompatible =
881 isSharedMemory(cast<MemRefType>(op.getBase().getType())) &&
882 nvgpu::inferTileWidthInBits(type: *warpMatrixInfo) == 128;
883
884 VectorType vecTy = op.getVectorType();
885 int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
886
887 // When we are transposing the B operand, ldmatrix will only work if we have
888 // at least 8 rows to read and the width to read for the transpose is 128
889 // bits.
890 if (!op.getPermutationMap().isMinorIdentity() &&
891 (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
892 vecTy.getDimSize(0) * bitWidth < 128))
893 isLdMatrixCompatible = false;
894
895 if (!isLdMatrixCompatible)
896 return createNonLdMatrixLoads(rewriter, op, valueMapping);
897
898 return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
899}
900
901static LogicalResult
902convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
903 llvm::DenseMap<Value, Value> &valueMapping) {
904 OpBuilder::InsertionGuard g(rewriter);
905 rewriter.setInsertionPoint(op);
906
907 Location loc = op->getLoc();
908 auto it = valueMapping.find(op.getVector());
909 if (it == valueMapping.end())
910 return rewriter.notifyMatchFailure(op, "no mapping");
911 Value matrix = it->second;
912
913 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
914 nvgpu::getWarpMatrixInfo(op: op);
915 if (failed(Result: warpMatrixInfo))
916 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
917 FailureOr<nvgpu::FragmentElementInfo> regInfo =
918 nvgpu::getMmaSyncRegisterType(type: *warpMatrixInfo);
919 if (failed(Result: regInfo))
920 return rewriter.notifyMatchFailure(op, "not mma sync reg info");
921
922 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
923 Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr);
924
925 for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
926 Value logicalValueId = rewriter.create<arith::ConstantOp>(
927 loc, rewriter.getIndexType(),
928 rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
929 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
930 builder&: rewriter, loc: op.getLoc(), fragmentType: *warpMatrixInfo);
931 if (failed(Result: coords))
932 return rewriter.notifyMatchFailure(op, "no coords");
933
934 Value el =
935 rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
936 SmallVector<Value, 4> newIndices;
937 getXferIndices<vector::TransferWriteOp>(
938 rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
939 rewriter.create<vector::StoreOp>(loc, el, op.getBase(), newIndices);
940 }
941
942 LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
943 rewriter.eraseOp(op: op);
944 return success();
945}
946
947static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
948 SmallVectorImpl<int64_t> &results) {
949 for (auto attr : arrayAttr)
950 results.push_back(cast<IntegerAttr>(attr).getInt());
951}
952
953static LogicalResult
954convertExtractStridedSlice(RewriterBase &rewriter,
955 vector::ExtractStridedSliceOp op,
956 llvm::DenseMap<Value, Value> &valueMapping) {
957 OpBuilder::InsertionGuard g(rewriter);
958 rewriter.setInsertionPoint(op);
959
960 Location loc = op->getLoc();
961
962 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
963 nvgpu::getWarpMatrixInfo(op: op);
964 if (failed(Result: warpMatrixInfo))
965 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
966
967 FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
968 nvgpu::getMmaSyncRegisterType(type: *warpMatrixInfo);
969 if (failed(Result: mmaSyncFragmentInfo))
970 return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
971
  // Find the vector.transfer_read whose result vector is being sliced.
973 auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
974 if (!transferReadOp)
975 return rewriter.notifyMatchFailure(op, "no transfer read");
976
977 warpMatrixInfo = nvgpu::getWarpMatrixInfo(op: transferReadOp);
978 if (failed(Result: warpMatrixInfo))
979 return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
980
981 FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
982 nvgpu::getMmaSyncRegisterType(type: *warpMatrixInfo);
983 if (failed(Result: ldFragmentInfo))
984 return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");
985
986 assert(
987 (mmaSyncFragmentInfo->elementsPerRegister ==
988 ldFragmentInfo->elementsPerRegister) &&
989 "Number of elements per register should be same for load and mma.sync");
990
991 // Create vector.extract_strided_slice op for thread-owned fragments.
992 std::array<int64_t, 2> strides = {1,
993 1}; // stride for extract slice is always 1.
994 std::array<int64_t, 2> sliceShape = {
995 mmaSyncFragmentInfo->numRegistersPerFragment,
996 mmaSyncFragmentInfo->elementsPerRegister};
997 auto it = valueMapping.find(transferReadOp);
998 if (it == valueMapping.end())
999 return rewriter.notifyMatchFailure(op, "no mapping");
1000 auto sourceVector = it->second;
1001
  // Offsets and sizes at the warp level of ownership.
1003 SmallVector<int64_t> offsets;
1004 populateFromInt64AttrArray(op.getOffsets(), offsets);
1005
1006 SmallVector<int64_t> sizes;
1007 populateFromInt64AttrArray(op.getSizes(), sizes);
1008 ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
1009
1010 // Compute offset in vector registers. Note that the mma.sync vector registers
  // are shaped as numberOfFragments x numberOfRegistersPerFragment. The vector
1012 // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
1013 std::array<int64_t, 2> sliceOffset = {0, 0};
1014
1015 if (offsets[0] && offsets[1])
1016 return op->emitError() << "Slicing fragments in 2D is not supported. ";
1017 if (offsets[0])
1018 sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
1019 else if (offsets[1])
1020 sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
1021
1022 Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
1023 loc, sourceVector, sliceOffset, sliceShape, strides);
1024
1025 valueMapping[op] = newOp;
1026 return success();
1027}
1028
1029static LogicalResult
1030convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
1031 llvm::DenseMap<Value, Value> &valueMapping) {
1032 OpBuilder::InsertionGuard g(rewriter);
1033 rewriter.setInsertionPoint(op);
1034
1035 auto itA = valueMapping.find(op.getLhs());
1036 auto itB = valueMapping.find(op.getRhs());
1037 auto itC = valueMapping.find(op.getAcc());
1038 if (itA == valueMapping.end() || itB == valueMapping.end() ||
1039 itC == valueMapping.end())
1040 return rewriter.notifyMatchFailure(op, "no mapping");
1041 Value opA = itA->second, opB = itB->second, opC = itC->second;
1042 Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
1043 op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
1044 /*b_transpose=*/UnitAttr());
1045 valueMapping[op.getResult()] = matmul;
1046 return success();
1047}
1048
1049static LogicalResult
1050convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
1051 llvm::DenseMap<Value, Value> &valueMapping) {
1052 OpBuilder::InsertionGuard g(rewriter);
1053 rewriter.setInsertionPoint(op);
1054
1055 auto itA = valueMapping.find(op.getLhs());
1056 auto itB = valueMapping.find(op.getRhs());
1057 auto itC = valueMapping.find(op.getAcc());
1058 if (itA == valueMapping.end() || itB == valueMapping.end() ||
1059 itC == valueMapping.end())
1060 return rewriter.notifyMatchFailure(op, "no mapping");
1061 Value opA = itA->second, opB = itB->second, opC = itC->second;
1062 int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
1063 int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
1064 int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
1065 Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
1066 op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
1067 valueMapping[op.getResult()] = matmul;
1068 return success();
1069}
1070
1071/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
1072static LogicalResult
1073convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
1074 llvm::DenseMap<Value, Value> &valueMapping) {
1075 OpBuilder::InsertionGuard g(rewriter);
1076 rewriter.setInsertionPoint(op);
1077
1078 assert(constantSupportsMMAMatrixType(op));
1079
1080 auto splat =
1081 cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
1082 auto scalarConstant =
1083 rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
1084 const char *fragType = inferFragType(op);
1085 auto vecType = cast<VectorType>(op.getType());
1086 gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1087 shape: vecType.getShape(), elementType: vecType.getElementType(), operand: llvm::StringRef(fragType));
1088 auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1089 op.getLoc(), type, scalarConstant);
1090 valueMapping[op.getResult()] = matrix;
1091 return success();
1092}
1093
1094/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
1095static LogicalResult
1096convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
1097 llvm::DenseMap<Value, Value> &valueMapping) {
1098 OpBuilder::InsertionGuard g(rewriter);
1099 rewriter.setInsertionPoint(op);
1100
1101 assert(broadcastSupportsMMAMatrixType(op));
1102
1103 const char *fragType = inferFragType(op);
1104 auto vecType = op.getResultVectorType();
1105 gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1106 shape: vecType.getShape(), elementType: vecType.getElementType(), operand: llvm::StringRef(fragType));
1107 auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1108 op.getLoc(), type, op.getSource());
1109 valueMapping[op.getResult()] = matrix;
1110 return success();
1111}
1112
1113// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
1114// updated and needs to be updated separately for the loop to be correct.
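// Illustrative example (SSA names are hypothetical): a loop such as
//   scf.for ... iter_args(%acc = %init) -> (vector<16x16xf16>)
// is recreated with an additional iter_arg carrying the mapped
// !gpu.mma_matrix value, and uses of the original results are redirected to
// the leading results of the new loop.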
1115static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
1116 scf::ForOp loop,
1117 ValueRange newInitArgs) {
1118 OpBuilder::InsertionGuard g(rewriter);
1119 rewriter.setInsertionPoint(loop);
1120
1121 // Create a new loop before the existing one, with the extra operands.
1122 rewriter.setInsertionPoint(loop);
1123 auto operands = llvm::to_vector<4>(loop.getInitArgs());
1124 llvm::append_range(operands, newInitArgs);
1125 scf::ForOp newLoop = rewriter.create<scf::ForOp>(
1126 loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
1127 operands);
1128 rewriter.eraseBlock(block: newLoop.getBody());
1129
1130 newLoop.getRegion().getBlocks().splice(
1131 newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
1132 for (Value operand : newInitArgs)
1133 newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
1134
1135 for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
1136 loop.getNumResults())))
1137 rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
1138
1139 LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
1140 LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
1141 LLVM_DEBUG(DBGS() << "erase: " << loop);
1142
1143 rewriter.eraseOp(op: loop);
1144 return newLoop;
1145}
1146
1147static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
1148 llvm::DenseMap<Value, Value> &valueMapping) {
1149 OpBuilder::InsertionGuard g(rewriter);
1150 rewriter.setInsertionPoint(op);
1151
1152 SmallVector<Value> newOperands;
1153 SmallVector<std::pair<size_t, size_t>> argMapping;
1154 for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
1155 auto it = valueMapping.find(operand.value());
1156 if (it == valueMapping.end()) {
1157 LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
1158 continue;
1159 }
1160 argMapping.push_back(std::make_pair(
1161 operand.index(), op.getInitArgs().size() + newOperands.size()));
1162 newOperands.push_back(it->second);
1163 }
1164
1165 scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
1166 Block &loopBody = *newForOp.getBody();
1167 for (auto mapping : argMapping) {
1168 valueMapping[newForOp.getResult(mapping.first)] =
1169 newForOp.getResult(mapping.second);
1170 valueMapping[loopBody.getArgument(i: mapping.first +
1171 newForOp.getNumInductionVars())] =
1172 loopBody.getArgument(i: mapping.second + newForOp.getNumInductionVars());
1173 }
1174
1175 LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
1176 return success();
1177}
1178
1179static LogicalResult
1180convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
1181 llvm::DenseMap<Value, Value> &valueMapping) {
1182 OpBuilder::InsertionGuard g(rewriter);
1183 rewriter.setInsertionPoint(op);
1184
1185 auto loop = cast<scf::ForOp>(op->getParentOp());
1186 auto yieldOperands = llvm::to_vector<4>(op.getOperands());
1187 for (const auto &operand : llvm::enumerate(op.getOperands())) {
1188 auto it = valueMapping.find(operand.value());
1189 if (it == valueMapping.end())
1190 continue;
1191 // Replace the yield of old value with the for op argument to make it easier
1192 // to remove the dead code.
1193 yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
1194 yieldOperands.push_back(it->second);
1195 }
1196 rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);
1197
1198 LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
1199 rewriter.eraseOp(op: op);
1200 return success();
1201}
1202
1203/// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
1204static LogicalResult
1205convertElementwiseOp(RewriterBase &rewriter, Operation *op,
1206 gpu::MMAElementwiseOp opType,
1207 llvm::DenseMap<Value, Value> &valueMapping) {
1208 OpBuilder::InsertionGuard g(rewriter);
1209 rewriter.setInsertionPoint(op);
1210
1211 SmallVector<Value> matrixOperands;
1212 for (Value operand : op->getOperands()) {
1213 auto it = valueMapping.find(Val: operand);
1214 if (it == valueMapping.end())
1215 return rewriter.notifyMatchFailure(arg&: op, msg: "no mapping");
1216 matrixOperands.push_back(Elt: it->second);
1217 }
1218 auto resultType = cast<gpu::MMAMatrixType>(Val: matrixOperands[0].getType());
1219 if (opType == gpu::MMAElementwiseOp::EXTF) {
1220 // The floating point extension case has a different result type.
1221 auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
1222 resultType = gpu::MMAMatrixType::get(shape: resultType.getShape(),
1223 elementType: vectorType.getElementType(),
1224 operand: resultType.getOperand());
1225 }
1226
1227 Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
1228 op->getLoc(), resultType, matrixOperands, opType);
1229 valueMapping[op->getResult(idx: 0)] = newOp;
1230 return success();
1231}
1232
1233void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
1234 bool useNvGpu) {
1235 if (!useNvGpu) {
1236 patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
1237 arg: patterns.getContext());
1238 return;
1239 }
1240 vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
1241 patterns.add<CombineTransferReadOpTranspose>(arg: patterns.getContext());
1242}
1243
1244LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
1245 Operation *rootOp) {
1246 SetVector<Operation *> ops = getOpToConvert(op: rootOp, /*useNvGpu=*/false);
1247 llvm::DenseMap<Value, Value> valueMapping;
1248
1249 auto globalRes = LogicalResult::success();
1250 for (Operation *op : ops) {
1251 LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
1252 // Apparently callers do not want to early exit on failure here.
1253 auto res = LogicalResult::success();
1254 if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
1255 res = convertTransferReadOp(rewriter, transferRead, valueMapping);
1256 } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
1257 res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
1258 } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
1259 res = convertContractOp(rewriter, contractOp, valueMapping);
1260 } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
1261 res = convertConstantOp(rewriter, constantOp, valueMapping);
1262 } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
1263 res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
1264 } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1265 res = convertForOp(rewriter, forOp, valueMapping);
1266 } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
1267 res = convertYieldOp(rewriter, yieldOp, valueMapping);
1268 } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
1269 res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
1270 }
1271 if (failed(Result: res))
1272 globalRes = failure();
1273 }
1274 return globalRes;
1275}
1276
1277LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
1278 Operation *rootOp) {
1279 SetVector<Operation *> ops = getOpToConvert(op: rootOp, /*useNvGpu=*/true);
1280 llvm::DenseMap<Value, Value> valueMapping;
1281 for (Operation *op : ops) {
1282 if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
1283 .Case(caseFn: [&](vector::TransferReadOp transferReadOp) {
1284 return convertTransferReadToLoads(rewriter, transferReadOp,
1285 valueMapping);
1286 })
1287 .Case(caseFn: [&](vector::TransferWriteOp transferWriteOp) {
1288 return convertTransferWriteToStores(rewriter, transferWriteOp,
1289 valueMapping);
1290 })
1291 .Case(caseFn: [&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
1292 return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
1293 valueMapping);
1294 })
1295 .Case(caseFn: [&](vector::ContractionOp contractionOp) {
1296 return convertContractOpToMmaSync(rewriter, contractionOp,
1297 valueMapping);
1298 })
1299 .Case(caseFn: [&](scf::ForOp forOp) {
1300 return convertForOp(rewriter, forOp, valueMapping);
1301 })
1302 .Case(caseFn: [&](scf::YieldOp yieldOp) {
1303 return convertYieldOp(rewriter, yieldOp, valueMapping);
1304 })
1305 .Case(caseFn: [&](arith::ConstantOp constOp) {
1306 return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
1307 })
1308 .Default(defaultFn: [&](Operation *op) {
1309 return op->emitError() << "unhandled vector to mma type: " << *op;
1310 })
1311 .failed()) {
1312 return op->emitOpError()
1313 << "failed to convert op during vector-to-nvgpu conversion";
1314 }
1315 }
1316 return success();
1317}
1318
1319namespace {
1320
1321struct ConvertVectorToGPUPass
1322 : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
1323
1324 explicit ConvertVectorToGPUPass(bool useNvGpu_) {
1325 useNvGpu.setValue(useNvGpu_);
1326 }
1327
1328 void runOnOperation() override {
1329 RewritePatternSet patterns(&getContext());
1330 populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
1331 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
1332 return signalPassFailure();
1333
1334 IRRewriter rewriter(&getContext());
1335 if (useNvGpu) {
1336 if (failed(
1337 convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
1338 return signalPassFailure();
1339 return;
1340 }
1341 (void)convertVectorToMMAOps(rewriter, getOperation());
1342 }
1343};
1344
1345} // namespace
1346
1347std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
1348 return std::make_unique<ConvertVectorToGPUPass>(args&: useNvGpu);
1349}