1//===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Arith/IR/Arith.h"
10#include "mlir/Dialect/Arith/Transforms/Passes.h"
11#include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h"
12#include "mlir/Dialect/MemRef/IR/MemRef.h"
13#include "mlir/Dialect/MemRef/Transforms/Passes.h"
14#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
15#include "mlir/Dialect/Vector/IR/VectorOps.h"
16#include "mlir/Transforms/DialectConversion.h"
17#include "llvm/Support/FormatVariadic.h"
18#include "llvm/Support/MathExtras.h"
19#include <cassert>
20
21namespace mlir::memref {
22#define GEN_PASS_DEF_MEMREFEMULATEWIDEINT
23#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
24} // namespace mlir::memref
25
26using namespace mlir;
27
28namespace {
29
30//===----------------------------------------------------------------------===//
31// ConvertMemRefAlloc
32//===----------------------------------------------------------------------===//
33
34struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
35 using OpConversionPattern::OpConversionPattern;
36
37 LogicalResult
38 matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
39 ConversionPatternRewriter &rewriter) const override {
40 Type newTy = getTypeConverter()->convertType(op.getType());
41 if (!newTy)
42 return rewriter.notifyMatchFailure(
43 op->getLoc(),
44 llvm::formatv("failed to convert memref type: {0}", op.getType()));
45
46 rewriter.replaceOpWithNewOp<memref::AllocOp>(
47 op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(),
48 adaptor.getAlignmentAttr());
49 return success();
50 }
51};
52
53//===----------------------------------------------------------------------===//
54// ConvertMemRefLoad
55//===----------------------------------------------------------------------===//
56
57struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
58 using OpConversionPattern::OpConversionPattern;
59
60 LogicalResult
61 matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
62 ConversionPatternRewriter &rewriter) const override {
63 Type newResTy = getTypeConverter()->convertType(op.getType());
64 if (!newResTy)
65 return rewriter.notifyMatchFailure(
66 op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
67 op.getMemRefType()));
68
69 rewriter.replaceOpWithNewOp<memref::LoadOp>(
70 op, newResTy, adaptor.getMemref(), adaptor.getIndices(),
71 op.getNontemporal());
72 return success();
73 }
74};
75
76//===----------------------------------------------------------------------===//
77// ConvertMemRefStore
78//===----------------------------------------------------------------------===//
79
80struct ConvertMemRefStore final : OpConversionPattern<memref::StoreOp> {
81 using OpConversionPattern::OpConversionPattern;
82
83 LogicalResult
84 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
85 ConversionPatternRewriter &rewriter) const override {
86 Type newTy = getTypeConverter()->convertType(op.getMemRefType());
87 if (!newTy)
88 return rewriter.notifyMatchFailure(
89 op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
90 op.getMemRefType()));
91
92 rewriter.replaceOpWithNewOp<memref::StoreOp>(
93 op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices(),
94 op.getNontemporal());
95 return success();
96 }
97};
98
99//===----------------------------------------------------------------------===//
100// Pass Definition
101//===----------------------------------------------------------------------===//
102
103struct EmulateWideIntPass final
104 : memref::impl::MemRefEmulateWideIntBase<EmulateWideIntPass> {
105 using MemRefEmulateWideIntBase::MemRefEmulateWideIntBase;
106
107 void runOnOperation() override {
108 if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
109 signalPassFailure();
110 return;
111 }
112
113 Operation *op = getOperation();
114 MLIRContext *ctx = op->getContext();
115
116 arith::WideIntEmulationConverter typeConverter(widestIntSupported);
117 memref::populateMemRefWideIntEmulationConversions(typeConverter&: typeConverter);
118 ConversionTarget target(*ctx);
119 target.addDynamicallyLegalDialect<
120 arith::ArithDialect, memref::MemRefDialect, vector::VectorDialect>(
121 [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
122
123 RewritePatternSet patterns(ctx);
124 // Add common pattenrs to support contants, functions, etc.
125 arith::populateArithWideIntEmulationPatterns(typeConverter&: typeConverter, patterns);
126
127 memref::populateMemRefWideIntEmulationPatterns(typeConverter&: typeConverter, patterns);
128
129 if (failed(applyPartialConversion(op, target, std::move(patterns))))
130 signalPassFailure();
131 }
132};
133
134} // end anonymous namespace
135
136//===----------------------------------------------------------------------===//
137// Public Interface Definition
138//===----------------------------------------------------------------------===//
139
140void memref::populateMemRefWideIntEmulationPatterns(
141 arith::WideIntEmulationConverter &typeConverter,
142 RewritePatternSet &patterns) {
143 // Populate `memref.*` conversion patterns.
144 patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefStore>(
145 arg&: typeConverter, args: patterns.getContext());
146}
147
148void memref::populateMemRefWideIntEmulationConversions(
149 arith::WideIntEmulationConverter &typeConverter) {
150 typeConverter.addConversion(
151 callback: [&typeConverter](MemRefType ty) -> std::optional<Type> {
152 auto intTy = dyn_cast<IntegerType>(ty.getElementType());
153 if (!intTy)
154 return ty;
155
156 if (intTy.getIntOrFloatBitWidth() <=
157 typeConverter.getMaxTargetIntBitWidth())
158 return ty;
159
160 Type newElemTy = typeConverter.convertType(intTy);
161 if (!newElemTy)
162 return std::nullopt;
163
164 return ty.cloneWith(std::nullopt, newElemTy);
165 });
166}
167

source code of mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp