1 | //===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Arith/IR/Arith.h" |
10 | #include "mlir/Dialect/Arith/Transforms/Passes.h" |
11 | #include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h" |
12 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
13 | #include "mlir/Dialect/MemRef/Transforms/Passes.h" |
14 | #include "mlir/Dialect/MemRef/Transforms/Transforms.h" |
15 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
16 | #include "mlir/Transforms/DialectConversion.h" |
17 | #include "llvm/Support/FormatVariadic.h" |
18 | #include "llvm/Support/MathExtras.h" |
19 | #include <cassert> |
20 | |
// Instantiate the generated pass definition for `MemRefEmulateWideInt`. This
// provides `memref::impl::MemRefEmulateWideIntBase`, the base class of the
// pass defined in this file.
namespace mlir::memref {
#define GEN_PASS_DEF_MEMREFEMULATEWIDEINT
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace mlir::memref
25 | |
26 | using namespace mlir; |
27 | |
28 | namespace { |
29 | |
30 | //===----------------------------------------------------------------------===// |
31 | // ConvertMemRefAlloc |
32 | //===----------------------------------------------------------------------===// |
33 | |
34 | struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> { |
35 | using OpConversionPattern::OpConversionPattern; |
36 | |
37 | LogicalResult |
38 | matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor, |
39 | ConversionPatternRewriter &rewriter) const override { |
40 | Type newTy = getTypeConverter()->convertType(op.getType()); |
41 | if (!newTy) |
42 | return rewriter.notifyMatchFailure( |
43 | op->getLoc(), |
44 | llvm::formatv("failed to convert memref type: {0}" , op.getType())); |
45 | |
46 | rewriter.replaceOpWithNewOp<memref::AllocOp>( |
47 | op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(), |
48 | adaptor.getAlignmentAttr()); |
49 | return success(); |
50 | } |
51 | }; |
52 | |
53 | //===----------------------------------------------------------------------===// |
54 | // ConvertMemRefLoad |
55 | //===----------------------------------------------------------------------===// |
56 | |
57 | struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> { |
58 | using OpConversionPattern::OpConversionPattern; |
59 | |
60 | LogicalResult |
61 | matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor, |
62 | ConversionPatternRewriter &rewriter) const override { |
63 | Type newResTy = getTypeConverter()->convertType(op.getType()); |
64 | if (!newResTy) |
65 | return rewriter.notifyMatchFailure( |
66 | op->getLoc(), llvm::formatv("failed to convert memref type: {0}" , |
67 | op.getMemRefType())); |
68 | |
69 | rewriter.replaceOpWithNewOp<memref::LoadOp>( |
70 | op, newResTy, adaptor.getMemref(), adaptor.getIndices(), |
71 | op.getNontemporal()); |
72 | return success(); |
73 | } |
74 | }; |
75 | |
76 | //===----------------------------------------------------------------------===// |
77 | // ConvertMemRefStore |
78 | //===----------------------------------------------------------------------===// |
79 | |
80 | struct ConvertMemRefStore final : OpConversionPattern<memref::StoreOp> { |
81 | using OpConversionPattern::OpConversionPattern; |
82 | |
83 | LogicalResult |
84 | matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor, |
85 | ConversionPatternRewriter &rewriter) const override { |
86 | Type newTy = getTypeConverter()->convertType(op.getMemRefType()); |
87 | if (!newTy) |
88 | return rewriter.notifyMatchFailure( |
89 | op->getLoc(), llvm::formatv("failed to convert memref type: {0}" , |
90 | op.getMemRefType())); |
91 | |
92 | rewriter.replaceOpWithNewOp<memref::StoreOp>( |
93 | op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices(), |
94 | op.getNontemporal()); |
95 | return success(); |
96 | } |
97 | }; |
98 | |
99 | //===----------------------------------------------------------------------===// |
100 | // Pass Definition |
101 | //===----------------------------------------------------------------------===// |
102 | |
103 | struct EmulateWideIntPass final |
104 | : memref::impl::MemRefEmulateWideIntBase<EmulateWideIntPass> { |
105 | using MemRefEmulateWideIntBase::MemRefEmulateWideIntBase; |
106 | |
107 | void runOnOperation() override { |
108 | if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) { |
109 | signalPassFailure(); |
110 | return; |
111 | } |
112 | |
113 | Operation *op = getOperation(); |
114 | MLIRContext *ctx = op->getContext(); |
115 | |
116 | arith::WideIntEmulationConverter typeConverter(widestIntSupported); |
117 | memref::populateMemRefWideIntEmulationConversions(typeConverter&: typeConverter); |
118 | ConversionTarget target(*ctx); |
119 | target.addDynamicallyLegalDialect< |
120 | arith::ArithDialect, memref::MemRefDialect, vector::VectorDialect>( |
121 | [&typeConverter](Operation *op) { return typeConverter.isLegal(op); }); |
122 | |
123 | RewritePatternSet patterns(ctx); |
124 | // Add common pattenrs to support contants, functions, etc. |
125 | arith::populateArithWideIntEmulationPatterns(typeConverter&: typeConverter, patterns); |
126 | |
127 | memref::populateMemRefWideIntEmulationPatterns(typeConverter&: typeConverter, patterns); |
128 | |
129 | if (failed(applyPartialConversion(op, target, std::move(patterns)))) |
130 | signalPassFailure(); |
131 | } |
132 | }; |
133 | |
134 | } // end anonymous namespace |
135 | |
136 | //===----------------------------------------------------------------------===// |
137 | // Public Interface Definition |
138 | //===----------------------------------------------------------------------===// |
139 | |
140 | void memref::populateMemRefWideIntEmulationPatterns( |
141 | arith::WideIntEmulationConverter &typeConverter, |
142 | RewritePatternSet &patterns) { |
143 | // Populate `memref.*` conversion patterns. |
144 | patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefStore>( |
145 | arg&: typeConverter, args: patterns.getContext()); |
146 | } |
147 | |
148 | void memref::populateMemRefWideIntEmulationConversions( |
149 | arith::WideIntEmulationConverter &typeConverter) { |
150 | typeConverter.addConversion( |
151 | callback: [&typeConverter](MemRefType ty) -> std::optional<Type> { |
152 | auto intTy = dyn_cast<IntegerType>(ty.getElementType()); |
153 | if (!intTy) |
154 | return ty; |
155 | |
156 | if (intTy.getIntOrFloatBitWidth() <= |
157 | typeConverter.getMaxTargetIntBitWidth()) |
158 | return ty; |
159 | |
160 | Type newElemTy = typeConverter.convertType(intTy); |
161 | if (!newElemTy) |
162 | return std::nullopt; |
163 | |
164 | return ty.cloneWith(std::nullopt, newElemTy); |
165 | }); |
166 | } |
167 | |