1 | //===- EmulateUnsupportedFloats.cpp - Promote small floats --*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // This pass promotes small floats (of some unsupported types T) to a supported |
9 | // type U by wrapping all float operations on Ts with expansion to and |
10 | // truncation from U, then operating on U. |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Arith/Transforms/Passes.h" |
14 | |
15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
16 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
17 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
18 | #include "mlir/IR/BuiltinTypes.h" |
19 | #include "mlir/IR/Location.h" |
20 | #include "mlir/IR/PatternMatch.h" |
21 | #include "mlir/Transforms/DialectConversion.h" |
22 | #include "llvm/ADT/STLExtras.h" |
23 | #include "llvm/Support/ErrorHandling.h" |
24 | #include <optional> |
25 | |
26 | namespace mlir::arith { |
27 | #define GEN_PASS_DEF_ARITHEMULATEUNSUPPORTEDFLOATS |
28 | #include "mlir/Dialect/Arith/Transforms/Passes.h.inc" |
29 | } // namespace mlir::arith |
30 | |
31 | using namespace mlir; |
32 | |
33 | namespace { |
34 | struct EmulateUnsupportedFloatsPass |
35 | : arith::impl::ArithEmulateUnsupportedFloatsBase< |
36 | EmulateUnsupportedFloatsPass> { |
37 | using arith::impl::ArithEmulateUnsupportedFloatsBase< |
38 | EmulateUnsupportedFloatsPass>::ArithEmulateUnsupportedFloatsBase; |
39 | |
40 | void runOnOperation() override; |
41 | }; |
42 | |
43 | struct EmulateFloatPattern final : ConversionPattern { |
44 | EmulateFloatPattern(const TypeConverter &converter, MLIRContext *ctx) |
45 | : ConversionPattern::ConversionPattern( |
46 | converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {} |
47 | |
48 | LogicalResult |
49 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
50 | ConversionPatternRewriter &rewriter) const override; |
51 | }; |
52 | } // end namespace |
53 | |
54 | LogicalResult EmulateFloatPattern::matchAndRewrite( |
55 | Operation *op, ArrayRef<Value> operands, |
56 | ConversionPatternRewriter &rewriter) const { |
57 | if (getTypeConverter()->isLegal(op)) |
58 | return failure(); |
59 | // The rewrite doesn't handle cloning regions. |
60 | if (op->getNumRegions() != 0) |
61 | return failure(); |
62 | |
63 | Location loc = op->getLoc(); |
64 | const TypeConverter *converter = getTypeConverter(); |
65 | SmallVector<Type> resultTypes; |
66 | if (failed(Result: converter->convertTypes(types: op->getResultTypes(), results&: resultTypes))) { |
67 | // Note to anyone looking for this error message: this is a "can't happen". |
68 | // If you're seeing it, there's a bug. |
69 | return op->emitOpError(message: "type conversion failed in float emulation" ); |
70 | } |
71 | Operation *expandedOp = |
72 | rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes, |
73 | op->getAttrs(), op->getSuccessors(), /*regions=*/{}); |
74 | SmallVector<Value> newResults(expandedOp->getResults()); |
75 | for (auto [res, oldType, newType] : llvm::zip_equal( |
76 | MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) { |
77 | if (oldType != newType) { |
78 | auto truncFOp = rewriter.create<arith::TruncFOp>(loc, oldType, res); |
79 | truncFOp.setFastmath(arith::FastMathFlags::contract); |
80 | res = truncFOp.getResult(); |
81 | } |
82 | } |
83 | rewriter.replaceOp(op, newValues: newResults); |
84 | return success(); |
85 | } |
86 | |
87 | void mlir::arith::populateEmulateUnsupportedFloatsConversions( |
88 | TypeConverter &converter, ArrayRef<Type> sourceTypes, Type targetType) { |
89 | converter.addConversion(callback: [sourceTypes = SmallVector<Type>(sourceTypes), |
90 | targetType](Type type) -> std::optional<Type> { |
91 | if (llvm::is_contained(Range: sourceTypes, Element: type)) |
92 | return targetType; |
93 | if (auto shaped = dyn_cast<ShapedType>(type)) |
94 | if (llvm::is_contained(sourceTypes, shaped.getElementType())) |
95 | return shaped.clone(targetType); |
96 | // All other types legal |
97 | return type; |
98 | }); |
99 | converter.addTargetMaterialization( |
100 | [](OpBuilder &b, Type target, ValueRange input, Location loc) { |
101 | auto extFOp = b.create<arith::ExtFOp>(loc, target, input); |
102 | extFOp.setFastmath(arith::FastMathFlags::contract); |
103 | return extFOp; |
104 | }); |
105 | } |
106 | |
107 | void mlir::arith::populateEmulateUnsupportedFloatsPatterns( |
108 | RewritePatternSet &patterns, const TypeConverter &converter) { |
109 | patterns.add<EmulateFloatPattern>(arg: converter, args: patterns.getContext()); |
110 | } |
111 | |
112 | void mlir::arith::populateEmulateUnsupportedFloatsLegality( |
113 | ConversionTarget &target, const TypeConverter &converter) { |
114 | // Don't try to legalize functions and other ops that don't need expansion. |
115 | target.markUnknownOpDynamicallyLegal(fn: [](Operation *op) { return true; }); |
116 | target.addDynamicallyLegalDialect<arith::ArithDialect>( |
117 | [&](Operation *op) -> std::optional<bool> { |
118 | return converter.isLegal(op); |
119 | }); |
120 | // Manually mark arithmetic-performing vector instructions. |
121 | target.addDynamicallyLegalOp< |
122 | vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp, |
123 | vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>( |
124 | [&](Operation *op) { return converter.isLegal(op); }); |
125 | target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp, |
126 | arith::ConstantOp, vector::SplatOp>(); |
127 | } |
128 | |
129 | void EmulateUnsupportedFloatsPass::runOnOperation() { |
130 | MLIRContext *ctx = &getContext(); |
131 | Operation *op = getOperation(); |
132 | SmallVector<Type> sourceTypes; |
133 | Type targetType; |
134 | |
135 | std::optional<FloatType> maybeTargetType = |
136 | arith::parseFloatType(ctx, targetTypeStr); |
137 | if (!maybeTargetType) { |
138 | emitError(UnknownLoc::get(ctx), "could not map target type '" + |
139 | targetTypeStr + |
140 | "' to a known floating-point type" ); |
141 | return signalPassFailure(); |
142 | } |
143 | targetType = *maybeTargetType; |
144 | for (StringRef sourceTypeStr : sourceTypeStrs) { |
145 | std::optional<FloatType> maybeSourceType = |
146 | arith::parseFloatType(ctx, sourceTypeStr); |
147 | if (!maybeSourceType) { |
148 | emitError(UnknownLoc::get(ctx), "could not map source type '" + |
149 | sourceTypeStr + |
150 | "' to a known floating-point type" ); |
151 | return signalPassFailure(); |
152 | } |
153 | sourceTypes.push_back(*maybeSourceType); |
154 | } |
155 | if (sourceTypes.empty()) |
156 | (void)emitOptionalWarning( |
157 | loc: std::nullopt, |
158 | args: "no source types specified, float emulation will do nothing" ); |
159 | |
160 | if (llvm::is_contained(Range&: sourceTypes, Element: targetType)) { |
161 | emitError(UnknownLoc::get(ctx), |
162 | "target type cannot be an unsupported source type" ); |
163 | return signalPassFailure(); |
164 | } |
165 | TypeConverter converter; |
166 | arith::populateEmulateUnsupportedFloatsConversions(converter, sourceTypes, |
167 | targetType); |
168 | RewritePatternSet patterns(ctx); |
169 | arith::populateEmulateUnsupportedFloatsPatterns(patterns, converter); |
170 | ConversionTarget target(getContext()); |
171 | arith::populateEmulateUnsupportedFloatsLegality(target, converter); |
172 | |
173 | if (failed(applyPartialConversion(op, target, std::move(patterns)))) |
174 | signalPassFailure(); |
175 | } |
176 | |