1//===- LegalizeToF32.cpp - Legalize functions on small floats ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements legalizing math operations on small floating-point
10// types through arith.extf and arith.truncf.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/Math/IR/Math.h"
16#include "mlir/Dialect/Math/Transforms/Passes.h"
17#include "mlir/IR/Diagnostics.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/IR/TypeUtilities.h"
20#include "mlir/Transforms/DialectConversion.h"
21#include "llvm/ADT/STLExtras.h"
22
23namespace mlir::math {
24#define GEN_PASS_DEF_MATHLEGALIZETOF32
25#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
26} // namespace mlir::math
27
28using namespace mlir;
29namespace {
30struct LegalizeToF32RewritePattern final : ConversionPattern {
31 LegalizeToF32RewritePattern(TypeConverter &converter, MLIRContext *context)
32 : ConversionPattern(converter, MatchAnyOpTypeTag{}, 1, context) {}
33 LogicalResult
34 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
35 ConversionPatternRewriter &rewriter) const override;
36};
37
38struct LegalizeToF32Pass final
39 : mlir::math::impl::MathLegalizeToF32Base<LegalizeToF32Pass> {
40 void runOnOperation() override;
41};
42} // namespace
43
44void mlir::math::populateLegalizeToF32TypeConverter(
45 TypeConverter &typeConverter) {
46 typeConverter.addConversion(
47 callback: [](Type type) -> std::optional<Type> { return type; });
48 typeConverter.addConversion(callback: [](FloatType type) -> std::optional<Type> {
49 if (type.getWidth() < 32)
50 return Float32Type::get(type.getContext());
51 return std::nullopt;
52 });
53 typeConverter.addConversion(callback: [](ShapedType type) -> std::optional<Type> {
54 if (auto elemTy = dyn_cast<FloatType>(type.getElementType()))
55 return type.clone(Float32Type::get(type.getContext()));
56 return std::nullopt;
57 });
58 typeConverter.addTargetMaterialization(
59 callback: [](OpBuilder &b, Type target, ValueRange input, Location loc) {
60 return b.create<arith::ExtFOp>(loc, target, input);
61 });
62}
63
64void mlir::math::populateLegalizeToF32ConversionTarget(
65 ConversionTarget &target, TypeConverter &typeConverter) {
66 target.addDynamicallyLegalDialect<MathDialect>(
67 [&typeConverter](Operation *op) -> bool {
68 return typeConverter.isLegal(op);
69 });
70 target.addLegalOp<FmaOp>();
71 target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
72}
73
74LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
75 Operation *op, ArrayRef<Value> operands,
76 ConversionPatternRewriter &rewriter) const {
77 Location loc = op->getLoc();
78 const TypeConverter *converter = getTypeConverter();
79 FailureOr<Operation *> legalized =
80 convertOpResultTypes(op, operands, converter: *converter, rewriter);
81 if (failed(result: legalized))
82 return failure();
83
84 SmallVector<Value> results = (*legalized)->getResults();
85 for (auto [result, newType, origType] : llvm::zip_equal(
86 t&: results, u: (*legalized)->getResultTypes(), args: op->getResultTypes())) {
87 if (newType != origType)
88 result = rewriter.create<arith::TruncFOp>(loc, origType, result);
89 }
90 rewriter.replaceOp(op, newValues: results);
91 return success();
92}
93
94void mlir::math::populateLegalizeToF32Patterns(RewritePatternSet &patterns,
95 TypeConverter &typeConverter) {
96 patterns.add<LegalizeToF32RewritePattern>(arg&: typeConverter,
97 args: patterns.getContext());
98}
99
100void LegalizeToF32Pass::runOnOperation() {
101 Operation *op = getOperation();
102 MLIRContext &ctx = getContext();
103
104 TypeConverter typeConverter;
105 math::populateLegalizeToF32TypeConverter(typeConverter);
106 ConversionTarget target(ctx);
107 math::populateLegalizeToF32ConversionTarget(target, typeConverter);
108 RewritePatternSet patterns(&ctx);
109 math::populateLegalizeToF32Patterns(patterns, typeConverter);
110 if (failed(applyPartialConversion(op, target, std::move(patterns))))
111 return signalPassFailure();
112}
113

// Source: mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp