//===- MmaSyncTF32Transform.cpp - MLIR NVGPU pass implementation ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements transforms to enable 1xTF32 and 3xTF32 execution of
// nvgpu.mma.sync operations on f32 inputs. (3xTF32 is currently rejected.)
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
15
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
18#include "mlir/Dialect/Vector/IR/VectorOps.h"
19
20using namespace mlir;
21using namespace mlir::nvgpu;
22
23namespace {
24
25struct MmaSyncF32ToTF32Pattern : public OpRewritePattern<nvgpu::MmaSyncOp> {
26
27 using OpRewritePattern<nvgpu::MmaSyncOp>::OpRewritePattern;
28
29 MmaSyncF32ToTF32Pattern(MLIRContext *context,
30 nvgpu::MmaSyncF32Lowering precision)
31 : OpRewritePattern<nvgpu::MmaSyncOp>(context, /*benifit*/ 1),
32 precision(precision) {}
33
34 LogicalResult matchAndRewrite(nvgpu::MmaSyncOp op,
35 PatternRewriter &rewriter) const override {
36 Location location = op->getLoc();
37
38 if (op->hasAttr(name: op.getTf32EnabledAttrName()) ||
39 !cast<VectorType>(Val: op.getMatrixA().getType()).getElementType().isF32())
40 return failure();
41
42 if (precision == MmaSyncF32Lowering::Unkown)
43 return emitError(loc: location, message: "MmaSync F32-to-TF32 cannot be lowered with "
44 "unknown precision level");
45
46 if (precision == MmaSyncF32Lowering::TF32x3)
47 return emitError(loc: location, message: "TF32x3 is not supported at the moment "
48 "for nvgpu.mma.sync on f32 datatype");
49
50 if (precision == MmaSyncF32Lowering::TF32) {
51 rewriter.modifyOpInPlace(
52 root: op, callable: [&]() { op.setTf32EnabledAttr(rewriter.getUnitAttr()); });
53 }
54
55 return success();
56 }
57
58private:
59 /// Precision for F32 Tensor Cores (TF32 or TF32x3)
60 nvgpu::MmaSyncF32Lowering precision;
61};
62
63} // namespace
64
65void mlir::nvgpu::populateMmaSyncF32ToTF32Patterns(
66 RewritePatternSet &patterns, nvgpu::MmaSyncF32Lowering precision) {
67
68 patterns.add<MmaSyncF32ToTF32Pattern>(arg: patterns.getContext(), args&: precision);
69}
70

source code of mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp