| 1 | //===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 10 | #include "mlir/Dialect/Arith/Transforms/Passes.h" |
| 11 | #include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h" |
| 12 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 13 | #include "mlir/Dialect/MemRef/Transforms/Passes.h" |
| 14 | #include "mlir/Dialect/MemRef/Transforms/Transforms.h" |
| 15 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 16 | #include "mlir/Transforms/DialectConversion.h" |
| 17 | #include "llvm/Support/FormatVariadic.h" |
| 18 | #include "llvm/Support/MathExtras.h" |
| 19 | #include <cassert> |
| 20 | |
| 21 | namespace mlir::memref { |
| 22 | #define GEN_PASS_DEF_MEMREFEMULATEWIDEINT |
| 23 | #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" |
| 24 | } // namespace mlir::memref |
| 25 | |
| 26 | using namespace mlir; |
| 27 | |
| 28 | namespace { |
| 29 | |
| 30 | //===----------------------------------------------------------------------===// |
| 31 | // ConvertMemRefAlloc |
| 32 | //===----------------------------------------------------------------------===// |
| 33 | |
| 34 | struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> { |
| 35 | using OpConversionPattern::OpConversionPattern; |
| 36 | |
| 37 | LogicalResult |
| 38 | matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor, |
| 39 | ConversionPatternRewriter &rewriter) const override { |
| 40 | Type newTy = getTypeConverter()->convertType(op.getType()); |
| 41 | if (!newTy) |
| 42 | return rewriter.notifyMatchFailure( |
| 43 | op->getLoc(), |
| 44 | llvm::formatv("failed to convert memref type: {0}" , op.getType())); |
| 45 | |
| 46 | rewriter.replaceOpWithNewOp<memref::AllocOp>( |
| 47 | op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(), |
| 48 | adaptor.getAlignmentAttr()); |
| 49 | return success(); |
| 50 | } |
| 51 | }; |
| 52 | |
| 53 | //===----------------------------------------------------------------------===// |
| 54 | // ConvertMemRefLoad |
| 55 | //===----------------------------------------------------------------------===// |
| 56 | |
| 57 | struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> { |
| 58 | using OpConversionPattern::OpConversionPattern; |
| 59 | |
| 60 | LogicalResult |
| 61 | matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor, |
| 62 | ConversionPatternRewriter &rewriter) const override { |
| 63 | Type newResTy = getTypeConverter()->convertType(op.getType()); |
| 64 | if (!newResTy) |
| 65 | return rewriter.notifyMatchFailure( |
| 66 | op->getLoc(), llvm::formatv("failed to convert memref type: {0}" , |
| 67 | op.getMemRefType())); |
| 68 | |
| 69 | rewriter.replaceOpWithNewOp<memref::LoadOp>( |
| 70 | op, newResTy, adaptor.getMemref(), adaptor.getIndices(), |
| 71 | op.getNontemporal()); |
| 72 | return success(); |
| 73 | } |
| 74 | }; |
| 75 | |
| 76 | //===----------------------------------------------------------------------===// |
| 77 | // ConvertMemRefStore |
| 78 | //===----------------------------------------------------------------------===// |
| 79 | |
| 80 | struct ConvertMemRefStore final : OpConversionPattern<memref::StoreOp> { |
| 81 | using OpConversionPattern::OpConversionPattern; |
| 82 | |
| 83 | LogicalResult |
| 84 | matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor, |
| 85 | ConversionPatternRewriter &rewriter) const override { |
| 86 | Type newTy = getTypeConverter()->convertType(op.getMemRefType()); |
| 87 | if (!newTy) |
| 88 | return rewriter.notifyMatchFailure( |
| 89 | op->getLoc(), llvm::formatv("failed to convert memref type: {0}" , |
| 90 | op.getMemRefType())); |
| 91 | |
| 92 | rewriter.replaceOpWithNewOp<memref::StoreOp>( |
| 93 | op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices(), |
| 94 | op.getNontemporal()); |
| 95 | return success(); |
| 96 | } |
| 97 | }; |
| 98 | |
| 99 | //===----------------------------------------------------------------------===// |
| 100 | // Pass Definition |
| 101 | //===----------------------------------------------------------------------===// |
| 102 | |
| 103 | struct EmulateWideIntPass final |
| 104 | : memref::impl::MemRefEmulateWideIntBase<EmulateWideIntPass> { |
| 105 | using MemRefEmulateWideIntBase::MemRefEmulateWideIntBase; |
| 106 | |
| 107 | void runOnOperation() override { |
| 108 | if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) { |
| 109 | signalPassFailure(); |
| 110 | return; |
| 111 | } |
| 112 | |
| 113 | Operation *op = getOperation(); |
| 114 | MLIRContext *ctx = op->getContext(); |
| 115 | |
| 116 | arith::WideIntEmulationConverter typeConverter(widestIntSupported); |
| 117 | memref::populateMemRefWideIntEmulationConversions(typeConverter&: typeConverter); |
| 118 | ConversionTarget target(*ctx); |
| 119 | target.addDynamicallyLegalDialect< |
| 120 | arith::ArithDialect, memref::MemRefDialect, vector::VectorDialect>( |
| 121 | [&typeConverter](Operation *op) { return typeConverter.isLegal(op); }); |
| 122 | |
| 123 | RewritePatternSet patterns(ctx); |
| 124 | // Add common pattenrs to support contants, functions, etc. |
| 125 | arith::populateArithWideIntEmulationPatterns(typeConverter: typeConverter, patterns); |
| 126 | |
| 127 | memref::populateMemRefWideIntEmulationPatterns(typeConverter: typeConverter, patterns); |
| 128 | |
| 129 | if (failed(applyPartialConversion(op, target, std::move(patterns)))) |
| 130 | signalPassFailure(); |
| 131 | } |
| 132 | }; |
| 133 | |
| 134 | } // end anonymous namespace |
| 135 | |
| 136 | //===----------------------------------------------------------------------===// |
| 137 | // Public Interface Definition |
| 138 | //===----------------------------------------------------------------------===// |
| 139 | |
| 140 | void memref::populateMemRefWideIntEmulationPatterns( |
| 141 | const arith::WideIntEmulationConverter &typeConverter, |
| 142 | RewritePatternSet &patterns) { |
| 143 | // Populate `memref.*` conversion patterns. |
| 144 | patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefStore>( |
| 145 | arg: typeConverter, args: patterns.getContext()); |
| 146 | } |
| 147 | |
| 148 | void memref::populateMemRefWideIntEmulationConversions( |
| 149 | arith::WideIntEmulationConverter &typeConverter) { |
| 150 | typeConverter.addConversion( |
| 151 | callback: [&typeConverter](MemRefType ty) -> std::optional<Type> { |
| 152 | auto intTy = dyn_cast<IntegerType>(ty.getElementType()); |
| 153 | if (!intTy) |
| 154 | return ty; |
| 155 | |
| 156 | if (intTy.getIntOrFloatBitWidth() <= |
| 157 | typeConverter.getMaxTargetIntBitWidth()) |
| 158 | return ty; |
| 159 | |
| 160 | Type newElemTy = typeConverter.convertType(intTy); |
| 161 | if (!newElemTy) |
| 162 | return nullptr; |
| 163 | |
| 164 | return ty.cloneWith(std::nullopt, newElemTy); |
| 165 | }); |
| 166 | } |
| 167 | |