//===- EmulateUnsupportedFloats.cpp - Promote small floats --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This pass promotes small floats (of some unsupported types T) to a supported
// type U: every float operation on Ts is rewritten to extend its operands to U,
// perform the operation on U, and truncate the results back to T.
//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Arith/Transforms/Passes.h"
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/Arith/Utils/Utils.h"
17#include "mlir/Dialect/Vector/IR/VectorOps.h"
18#include "mlir/IR/BuiltinTypes.h"
19#include "mlir/IR/Location.h"
20#include "mlir/IR/PatternMatch.h"
21#include "mlir/Transforms/DialectConversion.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/Support/ErrorHandling.h"
24#include <optional>
25
26namespace mlir::arith {
27#define GEN_PASS_DEF_ARITHEMULATEUNSUPPORTEDFLOATS
28#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
29} // namespace mlir::arith
30
31using namespace mlir;
32
33namespace {
34struct EmulateUnsupportedFloatsPass
35 : arith::impl::ArithEmulateUnsupportedFloatsBase<
36 EmulateUnsupportedFloatsPass> {
37 using arith::impl::ArithEmulateUnsupportedFloatsBase<
38 EmulateUnsupportedFloatsPass>::ArithEmulateUnsupportedFloatsBase;
39
40 void runOnOperation() override;
41};
42
43struct EmulateFloatPattern final : ConversionPattern {
44 EmulateFloatPattern(const TypeConverter &converter, MLIRContext *ctx)
45 : ConversionPattern::ConversionPattern(
46 converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {}
47
48 LogicalResult
49 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
50 ConversionPatternRewriter &rewriter) const override;
51};
52} // end namespace
53
54LogicalResult EmulateFloatPattern::matchAndRewrite(
55 Operation *op, ArrayRef<Value> operands,
56 ConversionPatternRewriter &rewriter) const {
57 if (getTypeConverter()->isLegal(op))
58 return failure();
59 // The rewrite doesn't handle cloning regions.
60 if (op->getNumRegions() != 0)
61 return failure();
62
63 Location loc = op->getLoc();
64 const TypeConverter *converter = getTypeConverter();
65 SmallVector<Type> resultTypes;
66 if (failed(Result: converter->convertTypes(types: op->getResultTypes(), results&: resultTypes))) {
67 // Note to anyone looking for this error message: this is a "can't happen".
68 // If you're seeing it, there's a bug.
69 return op->emitOpError(message: "type conversion failed in float emulation");
70 }
71 Operation *expandedOp =
72 rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes,
73 op->getAttrs(), op->getSuccessors(), /*regions=*/{});
74 SmallVector<Value> newResults(expandedOp->getResults());
75 for (auto [res, oldType, newType] : llvm::zip_equal(
76 MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
77 if (oldType != newType) {
78 auto truncFOp = rewriter.create<arith::TruncFOp>(loc, oldType, res);
79 truncFOp.setFastmath(arith::FastMathFlags::contract);
80 res = truncFOp.getResult();
81 }
82 }
83 rewriter.replaceOp(op, newValues: newResults);
84 return success();
85}
86
87void mlir::arith::populateEmulateUnsupportedFloatsConversions(
88 TypeConverter &converter, ArrayRef<Type> sourceTypes, Type targetType) {
89 converter.addConversion(callback: [sourceTypes = SmallVector<Type>(sourceTypes),
90 targetType](Type type) -> std::optional<Type> {
91 if (llvm::is_contained(Range: sourceTypes, Element: type))
92 return targetType;
93 if (auto shaped = dyn_cast<ShapedType>(type))
94 if (llvm::is_contained(sourceTypes, shaped.getElementType()))
95 return shaped.clone(targetType);
96 // All other types legal
97 return type;
98 });
99 converter.addTargetMaterialization(
100 [](OpBuilder &b, Type target, ValueRange input, Location loc) {
101 auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
102 extFOp.setFastmath(arith::FastMathFlags::contract);
103 return extFOp;
104 });
105}
106
107void mlir::arith::populateEmulateUnsupportedFloatsPatterns(
108 RewritePatternSet &patterns, const TypeConverter &converter) {
109 patterns.add<EmulateFloatPattern>(arg: converter, args: patterns.getContext());
110}
111
112void mlir::arith::populateEmulateUnsupportedFloatsLegality(
113 ConversionTarget &target, const TypeConverter &converter) {
114 // Don't try to legalize functions and other ops that don't need expansion.
115 target.markUnknownOpDynamicallyLegal(fn: [](Operation *op) { return true; });
116 target.addDynamicallyLegalDialect<arith::ArithDialect>(
117 [&](Operation *op) -> std::optional<bool> {
118 return converter.isLegal(op);
119 });
120 // Manually mark arithmetic-performing vector instructions.
121 target.addDynamicallyLegalOp<
122 vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp,
123 vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
124 [&](Operation *op) { return converter.isLegal(op); });
125 target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
126 arith::ConstantOp, vector::SplatOp>();
127}
128
129void EmulateUnsupportedFloatsPass::runOnOperation() {
130 MLIRContext *ctx = &getContext();
131 Operation *op = getOperation();
132 SmallVector<Type> sourceTypes;
133 Type targetType;
134
135 std::optional<FloatType> maybeTargetType =
136 arith::parseFloatType(ctx, targetTypeStr);
137 if (!maybeTargetType) {
138 emitError(UnknownLoc::get(ctx), "could not map target type '" +
139 targetTypeStr +
140 "' to a known floating-point type");
141 return signalPassFailure();
142 }
143 targetType = *maybeTargetType;
144 for (StringRef sourceTypeStr : sourceTypeStrs) {
145 std::optional<FloatType> maybeSourceType =
146 arith::parseFloatType(ctx, sourceTypeStr);
147 if (!maybeSourceType) {
148 emitError(UnknownLoc::get(ctx), "could not map source type '" +
149 sourceTypeStr +
150 "' to a known floating-point type");
151 return signalPassFailure();
152 }
153 sourceTypes.push_back(*maybeSourceType);
154 }
155 if (sourceTypes.empty())
156 (void)emitOptionalWarning(
157 loc: std::nullopt,
158 args: "no source types specified, float emulation will do nothing");
159
160 if (llvm::is_contained(Range&: sourceTypes, Element: targetType)) {
161 emitError(UnknownLoc::get(ctx),
162 "target type cannot be an unsupported source type");
163 return signalPassFailure();
164 }
165 TypeConverter converter;
166 arith::populateEmulateUnsupportedFloatsConversions(converter, sourceTypes,
167 targetType);
168 RewritePatternSet patterns(ctx);
169 arith::populateEmulateUnsupportedFloatsPatterns(patterns, converter);
170 ConversionTarget target(getContext());
171 arith::populateEmulateUnsupportedFloatsLegality(target, converter);
172
173 if (failed(applyPartialConversion(op, target, std::move(patterns))))
174 signalPassFailure();
175}
176

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp