1 | //===- LegalizeToF32.cpp - Legalize functions on small floats ----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements legalizing math operations on small floating-point |
10 | // types through arith.extf and arith.truncf. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
15 | #include "mlir/Dialect/Math/IR/Math.h" |
16 | #include "mlir/Dialect/Math/Transforms/Passes.h" |
17 | #include "mlir/IR/Diagnostics.h" |
18 | #include "mlir/IR/PatternMatch.h" |
19 | #include "mlir/IR/TypeUtilities.h" |
20 | #include "mlir/Transforms/DialectConversion.h" |
21 | #include "llvm/ADT/STLExtras.h" |
22 | |
23 | namespace mlir::math { |
24 | #define GEN_PASS_DEF_MATHLEGALIZETOF32 |
25 | #include "mlir/Dialect/Math/Transforms/Passes.h.inc" |
26 | } // namespace mlir::math |
27 | |
28 | using namespace mlir; |
29 | namespace { |
30 | struct LegalizeToF32RewritePattern final : ConversionPattern { |
31 | LegalizeToF32RewritePattern(TypeConverter &converter, MLIRContext *context) |
32 | : ConversionPattern(converter, MatchAnyOpTypeTag{}, 1, context) {} |
33 | LogicalResult |
34 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
35 | ConversionPatternRewriter &rewriter) const override; |
36 | }; |
37 | |
38 | struct LegalizeToF32Pass final |
39 | : mlir::math::impl::MathLegalizeToF32Base<LegalizeToF32Pass> { |
40 | void runOnOperation() override; |
41 | }; |
42 | } // namespace |
43 | |
44 | void mlir::math::populateLegalizeToF32TypeConverter( |
45 | TypeConverter &typeConverter) { |
46 | typeConverter.addConversion( |
47 | callback: [](Type type) -> std::optional<Type> { return type; }); |
48 | typeConverter.addConversion(callback: [](FloatType type) -> std::optional<Type> { |
49 | if (type.getWidth() < 32) |
50 | return Float32Type::get(type.getContext()); |
51 | return std::nullopt; |
52 | }); |
53 | typeConverter.addConversion(callback: [](ShapedType type) -> std::optional<Type> { |
54 | if (auto elemTy = dyn_cast<FloatType>(type.getElementType())) |
55 | return type.clone(Float32Type::get(type.getContext())); |
56 | return std::nullopt; |
57 | }); |
58 | typeConverter.addTargetMaterialization( |
59 | callback: [](OpBuilder &b, Type target, ValueRange input, Location loc) { |
60 | return b.create<arith::ExtFOp>(loc, target, input); |
61 | }); |
62 | } |
63 | |
64 | void mlir::math::populateLegalizeToF32ConversionTarget( |
65 | ConversionTarget &target, TypeConverter &typeConverter) { |
66 | target.addDynamicallyLegalDialect<MathDialect>( |
67 | [&typeConverter](Operation *op) -> bool { |
68 | return typeConverter.isLegal(op); |
69 | }); |
70 | target.addLegalOp<FmaOp>(); |
71 | target.addLegalOp<arith::ExtFOp, arith::TruncFOp>(); |
72 | } |
73 | |
74 | LogicalResult LegalizeToF32RewritePattern::matchAndRewrite( |
75 | Operation *op, ArrayRef<Value> operands, |
76 | ConversionPatternRewriter &rewriter) const { |
77 | Location loc = op->getLoc(); |
78 | const TypeConverter *converter = getTypeConverter(); |
79 | FailureOr<Operation *> legalized = |
80 | convertOpResultTypes(op, operands, converter: *converter, rewriter); |
81 | if (failed(result: legalized)) |
82 | return failure(); |
83 | |
84 | SmallVector<Value> results = (*legalized)->getResults(); |
85 | for (auto [result, newType, origType] : llvm::zip_equal( |
86 | t&: results, u: (*legalized)->getResultTypes(), args: op->getResultTypes())) { |
87 | if (newType != origType) |
88 | result = rewriter.create<arith::TruncFOp>(loc, origType, result); |
89 | } |
90 | rewriter.replaceOp(op, newValues: results); |
91 | return success(); |
92 | } |
93 | |
94 | void mlir::math::populateLegalizeToF32Patterns(RewritePatternSet &patterns, |
95 | TypeConverter &typeConverter) { |
96 | patterns.add<LegalizeToF32RewritePattern>(arg&: typeConverter, |
97 | args: patterns.getContext()); |
98 | } |
99 | |
100 | void LegalizeToF32Pass::runOnOperation() { |
101 | Operation *op = getOperation(); |
102 | MLIRContext &ctx = getContext(); |
103 | |
104 | TypeConverter typeConverter; |
105 | math::populateLegalizeToF32TypeConverter(typeConverter); |
106 | ConversionTarget target(ctx); |
107 | math::populateLegalizeToF32ConversionTarget(target, typeConverter); |
108 | RewritePatternSet patterns(&ctx); |
109 | math::populateLegalizeToF32Patterns(patterns, typeConverter); |
110 | if (failed(applyPartialConversion(op, target, std::move(patterns)))) |
111 | return signalPassFailure(); |
112 | } |
113 | |