| 1 | //===- EmulateUnsupportedFloats.cpp - Promote small floats --*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // This pass promotes small floats (of some unsupported types T) to a supported |
| 9 | // type U by wrapping all float operations on Ts with expansion to and |
| 10 | // truncation from U, then operating on U. |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Dialect/Arith/Transforms/Passes.h" |
| 14 | |
| 15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 16 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
| 17 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 18 | #include "mlir/IR/BuiltinTypes.h" |
| 19 | #include "mlir/IR/Location.h" |
| 20 | #include "mlir/IR/PatternMatch.h" |
| 21 | #include "mlir/Transforms/DialectConversion.h" |
| 22 | #include "llvm/ADT/STLExtras.h" |
| 23 | #include "llvm/Support/ErrorHandling.h" |
| 24 | #include <optional> |
| 25 | |
| 26 | namespace mlir::arith { |
| 27 | #define GEN_PASS_DEF_ARITHEMULATEUNSUPPORTEDFLOATS |
| 28 | #include "mlir/Dialect/Arith/Transforms/Passes.h.inc" |
| 29 | } // namespace mlir::arith |
| 30 | |
| 31 | using namespace mlir; |
| 32 | |
| 33 | namespace { |
| 34 | struct EmulateUnsupportedFloatsPass |
| 35 | : arith::impl::ArithEmulateUnsupportedFloatsBase< |
| 36 | EmulateUnsupportedFloatsPass> { |
| 37 | using arith::impl::ArithEmulateUnsupportedFloatsBase< |
| 38 | EmulateUnsupportedFloatsPass>::ArithEmulateUnsupportedFloatsBase; |
| 39 | |
| 40 | void runOnOperation() override; |
| 41 | }; |
| 42 | |
| 43 | struct EmulateFloatPattern final : ConversionPattern { |
| 44 | EmulateFloatPattern(const TypeConverter &converter, MLIRContext *ctx) |
| 45 | : ConversionPattern::ConversionPattern( |
| 46 | converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {} |
| 47 | |
| 48 | LogicalResult |
| 49 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
| 50 | ConversionPatternRewriter &rewriter) const override; |
| 51 | }; |
| 52 | } // end namespace |
| 53 | |
| 54 | LogicalResult EmulateFloatPattern::matchAndRewrite( |
| 55 | Operation *op, ArrayRef<Value> operands, |
| 56 | ConversionPatternRewriter &rewriter) const { |
| 57 | if (getTypeConverter()->isLegal(op)) |
| 58 | return failure(); |
| 59 | // The rewrite doesn't handle cloning regions. |
| 60 | if (op->getNumRegions() != 0) |
| 61 | return failure(); |
| 62 | |
| 63 | Location loc = op->getLoc(); |
| 64 | const TypeConverter *converter = getTypeConverter(); |
| 65 | SmallVector<Type> resultTypes; |
| 66 | if (failed(Result: converter->convertTypes(types: op->getResultTypes(), results&: resultTypes))) { |
| 67 | // Note to anyone looking for this error message: this is a "can't happen". |
| 68 | // If you're seeing it, there's a bug. |
| 69 | return op->emitOpError(message: "type conversion failed in float emulation" ); |
| 70 | } |
| 71 | Operation *expandedOp = |
| 72 | rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes, |
| 73 | op->getAttrs(), op->getSuccessors(), /*regions=*/{}); |
| 74 | SmallVector<Value> newResults(expandedOp->getResults()); |
| 75 | for (auto [res, oldType, newType] : llvm::zip_equal( |
| 76 | MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) { |
| 77 | if (oldType != newType) { |
| 78 | auto truncFOp = rewriter.create<arith::TruncFOp>(loc, oldType, res); |
| 79 | truncFOp.setFastmath(arith::FastMathFlags::contract); |
| 80 | res = truncFOp.getResult(); |
| 81 | } |
| 82 | } |
| 83 | rewriter.replaceOp(op, newValues: newResults); |
| 84 | return success(); |
| 85 | } |
| 86 | |
| 87 | void mlir::arith::populateEmulateUnsupportedFloatsConversions( |
| 88 | TypeConverter &converter, ArrayRef<Type> sourceTypes, Type targetType) { |
| 89 | converter.addConversion(callback: [sourceTypes = SmallVector<Type>(sourceTypes), |
| 90 | targetType](Type type) -> std::optional<Type> { |
| 91 | if (llvm::is_contained(Range: sourceTypes, Element: type)) |
| 92 | return targetType; |
| 93 | if (auto shaped = dyn_cast<ShapedType>(type)) |
| 94 | if (llvm::is_contained(sourceTypes, shaped.getElementType())) |
| 95 | return shaped.clone(targetType); |
| 96 | // All other types legal |
| 97 | return type; |
| 98 | }); |
| 99 | converter.addTargetMaterialization( |
| 100 | [](OpBuilder &b, Type target, ValueRange input, Location loc) { |
| 101 | auto extFOp = b.create<arith::ExtFOp>(loc, target, input); |
| 102 | extFOp.setFastmath(arith::FastMathFlags::contract); |
| 103 | return extFOp; |
| 104 | }); |
| 105 | } |
| 106 | |
| 107 | void mlir::arith::populateEmulateUnsupportedFloatsPatterns( |
| 108 | RewritePatternSet &patterns, const TypeConverter &converter) { |
| 109 | patterns.add<EmulateFloatPattern>(arg: converter, args: patterns.getContext()); |
| 110 | } |
| 111 | |
| 112 | void mlir::arith::populateEmulateUnsupportedFloatsLegality( |
| 113 | ConversionTarget &target, const TypeConverter &converter) { |
| 114 | // Don't try to legalize functions and other ops that don't need expansion. |
| 115 | target.markUnknownOpDynamicallyLegal(fn: [](Operation *op) { return true; }); |
| 116 | target.addDynamicallyLegalDialect<arith::ArithDialect>( |
| 117 | [&](Operation *op) -> std::optional<bool> { |
| 118 | return converter.isLegal(op); |
| 119 | }); |
| 120 | // Manually mark arithmetic-performing vector instructions. |
| 121 | target.addDynamicallyLegalOp< |
| 122 | vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp, |
| 123 | vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>( |
| 124 | [&](Operation *op) { return converter.isLegal(op); }); |
| 125 | target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp, |
| 126 | arith::ConstantOp, vector::SplatOp>(); |
| 127 | } |
| 128 | |
| 129 | void EmulateUnsupportedFloatsPass::runOnOperation() { |
| 130 | MLIRContext *ctx = &getContext(); |
| 131 | Operation *op = getOperation(); |
| 132 | SmallVector<Type> sourceTypes; |
| 133 | Type targetType; |
| 134 | |
| 135 | std::optional<FloatType> maybeTargetType = |
| 136 | arith::parseFloatType(ctx, targetTypeStr); |
| 137 | if (!maybeTargetType) { |
| 138 | emitError(UnknownLoc::get(ctx), "could not map target type '" + |
| 139 | targetTypeStr + |
| 140 | "' to a known floating-point type" ); |
| 141 | return signalPassFailure(); |
| 142 | } |
| 143 | targetType = *maybeTargetType; |
| 144 | for (StringRef sourceTypeStr : sourceTypeStrs) { |
| 145 | std::optional<FloatType> maybeSourceType = |
| 146 | arith::parseFloatType(ctx, sourceTypeStr); |
| 147 | if (!maybeSourceType) { |
| 148 | emitError(UnknownLoc::get(ctx), "could not map source type '" + |
| 149 | sourceTypeStr + |
| 150 | "' to a known floating-point type" ); |
| 151 | return signalPassFailure(); |
| 152 | } |
| 153 | sourceTypes.push_back(*maybeSourceType); |
| 154 | } |
| 155 | if (sourceTypes.empty()) |
| 156 | (void)emitOptionalWarning( |
| 157 | loc: std::nullopt, |
| 158 | args: "no source types specified, float emulation will do nothing" ); |
| 159 | |
| 160 | if (llvm::is_contained(Range&: sourceTypes, Element: targetType)) { |
| 161 | emitError(UnknownLoc::get(ctx), |
| 162 | "target type cannot be an unsupported source type" ); |
| 163 | return signalPassFailure(); |
| 164 | } |
| 165 | TypeConverter converter; |
| 166 | arith::populateEmulateUnsupportedFloatsConversions(converter, sourceTypes, |
| 167 | targetType); |
| 168 | RewritePatternSet patterns(ctx); |
| 169 | arith::populateEmulateUnsupportedFloatsPatterns(patterns, converter); |
| 170 | ConversionTarget target(getContext()); |
| 171 | arith::populateEmulateUnsupportedFloatsLegality(target, converter); |
| 172 | |
| 173 | if (failed(applyPartialConversion(op, target, std::move(patterns)))) |
| 174 | signalPassFailure(); |
| 175 | } |
| 176 | |