1//===- EmulateUnsupportedFloats.cpp - Promote small floats --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This pass promotes small floats (of some unsupported types T) to a supported
9// type U by wrapping all float operations on Ts with expansion to and
10// truncation from U, then operating on U.
11//===----------------------------------------------------------------------===//
12
#include "mlir/Dialect/Arith/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/ErrorHandling.h"
#include <optional>
24
25namespace mlir::arith {
26#define GEN_PASS_DEF_ARITHEMULATEUNSUPPORTEDFLOATS
27#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
28} // namespace mlir::arith
29
30using namespace mlir;
31
32namespace {
33struct EmulateUnsupportedFloatsPass
34 : arith::impl::ArithEmulateUnsupportedFloatsBase<
35 EmulateUnsupportedFloatsPass> {
36 using arith::impl::ArithEmulateUnsupportedFloatsBase<
37 EmulateUnsupportedFloatsPass>::ArithEmulateUnsupportedFloatsBase;
38
39 void runOnOperation() override;
40};
41
42struct EmulateFloatPattern final : ConversionPattern {
43 EmulateFloatPattern(TypeConverter &converter, MLIRContext *ctx)
44 : ConversionPattern(converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {}
45
46 LogicalResult match(Operation *op) const override;
47 void rewrite(Operation *op, ArrayRef<Value> operands,
48 ConversionPatternRewriter &rewriter) const override;
49};
50} // end namespace
51
52/// Map strings to float types. This function is here because no one else needs
53/// it yet, feel free to abstract it out.
54static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
55 StringRef name) {
56 Builder b(ctx);
57 return llvm::StringSwitch<std::optional<FloatType>>(name)
58 .Case(S: "f8E5M2", Value: b.getFloat8E5M2Type())
59 .Case(S: "f8E4M3FN", Value: b.getFloat8E4M3FNType())
60 .Case(S: "f8E5M2FNUZ", Value: b.getFloat8E5M2FNUZType())
61 .Case(S: "f8E4M3FNUZ", Value: b.getFloat8E4M3FNUZType())
62 .Case(S: "bf16", Value: b.getBF16Type())
63 .Case(S: "f16", Value: b.getF16Type())
64 .Case(S: "f32", Value: b.getF32Type())
65 .Case(S: "f64", Value: b.getF64Type())
66 .Case(S: "f80", Value: b.getF80Type())
67 .Case(S: "f128", Value: b.getF128Type())
68 .Default(Value: std::nullopt);
69}
70
71LogicalResult EmulateFloatPattern::match(Operation *op) const {
72 if (getTypeConverter()->isLegal(op))
73 return failure();
74 // The rewrite doesn't handle cloning regions.
75 if (op->getNumRegions() != 0)
76 return failure();
77 return success();
78}
79
80void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands,
81 ConversionPatternRewriter &rewriter) const {
82 Location loc = op->getLoc();
83 const TypeConverter *converter = getTypeConverter();
84 SmallVector<Type> resultTypes;
85 if (failed(result: converter->convertTypes(types: op->getResultTypes(), results&: resultTypes))) {
86 // Note to anyone looking for this error message: this is a "can't happen".
87 // If you're seeing it, there's a bug.
88 op->emitOpError(message: "type conversion failed in float emulation");
89 return;
90 }
91 Operation *expandedOp =
92 rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes,
93 op->getAttrs(), op->getSuccessors(), /*regions=*/{});
94 SmallVector<Value> newResults(expandedOp->getResults());
95 for (auto [res, oldType, newType] : llvm::zip_equal(
96 MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
97 if (oldType != newType)
98 res = rewriter.create<arith::TruncFOp>(loc, oldType, res);
99 }
100 rewriter.replaceOp(op, newValues: newResults);
101}
102
103void mlir::arith::populateEmulateUnsupportedFloatsConversions(
104 TypeConverter &converter, ArrayRef<Type> sourceTypes, Type targetType) {
105 converter.addConversion(callback: [sourceTypes = SmallVector<Type>(sourceTypes),
106 targetType](Type type) -> std::optional<Type> {
107 if (llvm::is_contained(Range: sourceTypes, Element: type))
108 return targetType;
109 if (auto shaped = dyn_cast<ShapedType>(type))
110 if (llvm::is_contained(sourceTypes, shaped.getElementType()))
111 return shaped.clone(targetType);
112 // All other types legal
113 return type;
114 });
115 converter.addTargetMaterialization(
116 callback: [](OpBuilder &b, Type target, ValueRange input, Location loc) {
117 return b.create<arith::ExtFOp>(loc, target, input);
118 });
119}
120
121void mlir::arith::populateEmulateUnsupportedFloatsPatterns(
122 RewritePatternSet &patterns, TypeConverter &converter) {
123 patterns.add<EmulateFloatPattern>(arg&: converter, args: patterns.getContext());
124}
125
126void mlir::arith::populateEmulateUnsupportedFloatsLegality(
127 ConversionTarget &target, TypeConverter &converter) {
128 // Don't try to legalize functions and other ops that don't need expansion.
129 target.markUnknownOpDynamicallyLegal(fn: [](Operation *op) { return true; });
130 target.addDynamicallyLegalDialect<arith::ArithDialect>(
131 [&](Operation *op) -> std::optional<bool> {
132 return converter.isLegal(op);
133 });
134 // Manually mark arithmetic-performing vector instructions.
135 target.addDynamicallyLegalOp<
136 vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp,
137 vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
138 [&](Operation *op) { return converter.isLegal(op); });
139 target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
140 arith::ConstantOp, vector::SplatOp>();
141}
142
143void EmulateUnsupportedFloatsPass::runOnOperation() {
144 MLIRContext *ctx = &getContext();
145 Operation *op = getOperation();
146 SmallVector<Type> sourceTypes;
147 Type targetType;
148
149 std::optional<FloatType> maybeTargetType = parseFloatType(ctx, targetTypeStr);
150 if (!maybeTargetType) {
151 emitError(UnknownLoc::get(ctx), "could not map target type '" +
152 targetTypeStr +
153 "' to a known floating-point type");
154 return signalPassFailure();
155 }
156 targetType = *maybeTargetType;
157 for (StringRef sourceTypeStr : sourceTypeStrs) {
158 std::optional<FloatType> maybeSourceType =
159 parseFloatType(ctx, sourceTypeStr);
160 if (!maybeSourceType) {
161 emitError(UnknownLoc::get(ctx), "could not map source type '" +
162 sourceTypeStr +
163 "' to a known floating-point type");
164 return signalPassFailure();
165 }
166 sourceTypes.push_back(*maybeSourceType);
167 }
168 if (sourceTypes.empty())
169 (void)emitOptionalWarning(
170 loc: std::nullopt,
171 args: "no source types specified, float emulation will do nothing");
172
173 if (llvm::is_contained(Range&: sourceTypes, Element: targetType)) {
174 emitError(UnknownLoc::get(ctx),
175 "target type cannot be an unsupported source type");
176 return signalPassFailure();
177 }
178 TypeConverter converter;
179 arith::populateEmulateUnsupportedFloatsConversions(converter, sourceTypes,
180 targetType);
181 RewritePatternSet patterns(ctx);
182 arith::populateEmulateUnsupportedFloatsPatterns(patterns, converter);
183 ConversionTarget target(getContext());
184 arith::populateEmulateUnsupportedFloatsLegality(target, converter);
185
186 if (failed(applyPartialConversion(op, target, std::move(patterns))))
187 signalPassFailure();
188}
189

source code of mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp