1 | //===- EmulateUnsupportedFloats.cpp - Promote small floats --*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // This pass promotes small floats (of some unsupported types T) to a supported |
9 | // type U by wrapping all float operations on Ts with expansion to and |
10 | // truncation from U, then operating on U. |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Arith/Transforms/Passes.h" |
14 | |
15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
16 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
17 | #include "mlir/IR/BuiltinTypes.h" |
18 | #include "mlir/IR/Location.h" |
19 | #include "mlir/IR/PatternMatch.h" |
20 | #include "mlir/Transforms/DialectConversion.h" |
21 | #include "llvm/ADT/STLExtras.h" |
22 | #include "llvm/Support/ErrorHandling.h" |
23 | #include <optional> |
24 | |
25 | namespace mlir::arith { |
26 | #define GEN_PASS_DEF_ARITHEMULATEUNSUPPORTEDFLOATS |
27 | #include "mlir/Dialect/Arith/Transforms/Passes.h.inc" |
28 | } // namespace mlir::arith |
29 | |
30 | using namespace mlir; |
31 | |
32 | namespace { |
33 | struct EmulateUnsupportedFloatsPass |
34 | : arith::impl::ArithEmulateUnsupportedFloatsBase< |
35 | EmulateUnsupportedFloatsPass> { |
36 | using arith::impl::ArithEmulateUnsupportedFloatsBase< |
37 | EmulateUnsupportedFloatsPass>::ArithEmulateUnsupportedFloatsBase; |
38 | |
39 | void runOnOperation() override; |
40 | }; |
41 | |
42 | struct EmulateFloatPattern final : ConversionPattern { |
43 | EmulateFloatPattern(TypeConverter &converter, MLIRContext *ctx) |
44 | : ConversionPattern(converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {} |
45 | |
46 | LogicalResult match(Operation *op) const override; |
47 | void rewrite(Operation *op, ArrayRef<Value> operands, |
48 | ConversionPatternRewriter &rewriter) const override; |
49 | }; |
50 | } // end namespace |
51 | |
52 | /// Map strings to float types. This function is here because no one else needs |
53 | /// it yet, feel free to abstract it out. |
54 | static std::optional<FloatType> parseFloatType(MLIRContext *ctx, |
55 | StringRef name) { |
56 | Builder b(ctx); |
57 | return llvm::StringSwitch<std::optional<FloatType>>(name) |
58 | .Case(S: "f8E5M2" , Value: b.getFloat8E5M2Type()) |
59 | .Case(S: "f8E4M3FN" , Value: b.getFloat8E4M3FNType()) |
60 | .Case(S: "f8E5M2FNUZ" , Value: b.getFloat8E5M2FNUZType()) |
61 | .Case(S: "f8E4M3FNUZ" , Value: b.getFloat8E4M3FNUZType()) |
62 | .Case(S: "bf16" , Value: b.getBF16Type()) |
63 | .Case(S: "f16" , Value: b.getF16Type()) |
64 | .Case(S: "f32" , Value: b.getF32Type()) |
65 | .Case(S: "f64" , Value: b.getF64Type()) |
66 | .Case(S: "f80" , Value: b.getF80Type()) |
67 | .Case(S: "f128" , Value: b.getF128Type()) |
68 | .Default(Value: std::nullopt); |
69 | } |
70 | |
71 | LogicalResult EmulateFloatPattern::match(Operation *op) const { |
72 | if (getTypeConverter()->isLegal(op)) |
73 | return failure(); |
74 | // The rewrite doesn't handle cloning regions. |
75 | if (op->getNumRegions() != 0) |
76 | return failure(); |
77 | return success(); |
78 | } |
79 | |
80 | void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands, |
81 | ConversionPatternRewriter &rewriter) const { |
82 | Location loc = op->getLoc(); |
83 | const TypeConverter *converter = getTypeConverter(); |
84 | SmallVector<Type> resultTypes; |
85 | if (failed(result: converter->convertTypes(types: op->getResultTypes(), results&: resultTypes))) { |
86 | // Note to anyone looking for this error message: this is a "can't happen". |
87 | // If you're seeing it, there's a bug. |
88 | op->emitOpError(message: "type conversion failed in float emulation" ); |
89 | return; |
90 | } |
91 | Operation *expandedOp = |
92 | rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes, |
93 | op->getAttrs(), op->getSuccessors(), /*regions=*/{}); |
94 | SmallVector<Value> newResults(expandedOp->getResults()); |
95 | for (auto [res, oldType, newType] : llvm::zip_equal( |
96 | MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) { |
97 | if (oldType != newType) |
98 | res = rewriter.create<arith::TruncFOp>(loc, oldType, res); |
99 | } |
100 | rewriter.replaceOp(op, newValues: newResults); |
101 | } |
102 | |
103 | void mlir::arith::populateEmulateUnsupportedFloatsConversions( |
104 | TypeConverter &converter, ArrayRef<Type> sourceTypes, Type targetType) { |
105 | converter.addConversion(callback: [sourceTypes = SmallVector<Type>(sourceTypes), |
106 | targetType](Type type) -> std::optional<Type> { |
107 | if (llvm::is_contained(Range: sourceTypes, Element: type)) |
108 | return targetType; |
109 | if (auto shaped = dyn_cast<ShapedType>(type)) |
110 | if (llvm::is_contained(sourceTypes, shaped.getElementType())) |
111 | return shaped.clone(targetType); |
112 | // All other types legal |
113 | return type; |
114 | }); |
115 | converter.addTargetMaterialization( |
116 | callback: [](OpBuilder &b, Type target, ValueRange input, Location loc) { |
117 | return b.create<arith::ExtFOp>(loc, target, input); |
118 | }); |
119 | } |
120 | |
121 | void mlir::arith::populateEmulateUnsupportedFloatsPatterns( |
122 | RewritePatternSet &patterns, TypeConverter &converter) { |
123 | patterns.add<EmulateFloatPattern>(arg&: converter, args: patterns.getContext()); |
124 | } |
125 | |
126 | void mlir::arith::populateEmulateUnsupportedFloatsLegality( |
127 | ConversionTarget &target, TypeConverter &converter) { |
128 | // Don't try to legalize functions and other ops that don't need expansion. |
129 | target.markUnknownOpDynamicallyLegal(fn: [](Operation *op) { return true; }); |
130 | target.addDynamicallyLegalDialect<arith::ArithDialect>( |
131 | [&](Operation *op) -> std::optional<bool> { |
132 | return converter.isLegal(op); |
133 | }); |
134 | // Manually mark arithmetic-performing vector instructions. |
135 | target.addDynamicallyLegalOp< |
136 | vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp, |
137 | vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>( |
138 | [&](Operation *op) { return converter.isLegal(op); }); |
139 | target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp, |
140 | arith::ConstantOp, vector::SplatOp>(); |
141 | } |
142 | |
143 | void EmulateUnsupportedFloatsPass::runOnOperation() { |
144 | MLIRContext *ctx = &getContext(); |
145 | Operation *op = getOperation(); |
146 | SmallVector<Type> sourceTypes; |
147 | Type targetType; |
148 | |
149 | std::optional<FloatType> maybeTargetType = parseFloatType(ctx, targetTypeStr); |
150 | if (!maybeTargetType) { |
151 | emitError(UnknownLoc::get(ctx), "could not map target type '" + |
152 | targetTypeStr + |
153 | "' to a known floating-point type" ); |
154 | return signalPassFailure(); |
155 | } |
156 | targetType = *maybeTargetType; |
157 | for (StringRef sourceTypeStr : sourceTypeStrs) { |
158 | std::optional<FloatType> maybeSourceType = |
159 | parseFloatType(ctx, sourceTypeStr); |
160 | if (!maybeSourceType) { |
161 | emitError(UnknownLoc::get(ctx), "could not map source type '" + |
162 | sourceTypeStr + |
163 | "' to a known floating-point type" ); |
164 | return signalPassFailure(); |
165 | } |
166 | sourceTypes.push_back(*maybeSourceType); |
167 | } |
168 | if (sourceTypes.empty()) |
169 | (void)emitOptionalWarning( |
170 | loc: std::nullopt, |
171 | args: "no source types specified, float emulation will do nothing" ); |
172 | |
173 | if (llvm::is_contained(Range&: sourceTypes, Element: targetType)) { |
174 | emitError(UnknownLoc::get(ctx), |
175 | "target type cannot be an unsupported source type" ); |
176 | return signalPassFailure(); |
177 | } |
178 | TypeConverter converter; |
179 | arith::populateEmulateUnsupportedFloatsConversions(converter, sourceTypes, |
180 | targetType); |
181 | RewritePatternSet patterns(ctx); |
182 | arith::populateEmulateUnsupportedFloatsPatterns(patterns, converter); |
183 | ConversionTarget target(getContext()); |
184 | arith::populateEmulateUnsupportedFloatsLegality(target, converter); |
185 | |
186 | if (failed(applyPartialConversion(op, target, std::move(patterns)))) |
187 | signalPassFailure(); |
188 | } |
189 | |