//===- ExpandOps.cpp - Expansion patterns for MemRef operations -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Arith/IR/Arith.h"
10#include "mlir/Dialect/Arith/Transforms/Passes.h"
11#include "mlir/Dialect/MemRef/IR/MemRef.h"
12#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
13#include "mlir/IR/TypeUtilities.h"
14#include "mlir/Transforms/DialectConversion.h"
15
16namespace mlir {
17namespace memref {
18#define GEN_PASS_DEF_EXPANDOPSPASS
19#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
20} // namespace memref
21} // namespace mlir
22
23using namespace mlir;
24
25namespace {
26
27/// Converts `memref.reshape` that has a target shape of a statically-known
28/// size to `memref.reinterpret_cast`.
29struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
30public:
31 using OpRewritePattern::OpRewritePattern;
32
33 LogicalResult matchAndRewrite(memref::ReshapeOp op,
34 PatternRewriter &rewriter) const final {
35 auto shapeType = cast<MemRefType>(Val: op.getShape().getType());
36 if (!shapeType.hasStaticShape())
37 return failure();
38
39 int64_t rank = cast<MemRefType>(Val&: shapeType).getDimSize(idx: 0);
40 SmallVector<OpFoldResult, 4> sizes, strides;
41 sizes.resize(N: rank);
42 strides.resize(N: rank);
43
44 Location loc = op.getLoc();
45 Value stride = nullptr;
46 int64_t staticStride = 1;
47 for (int i = rank - 1; i >= 0; --i) {
48 Value size;
49 // Load dynamic sizes from the shape input, use constants for static dims.
50 if (op.getType().isDynamicDim(idx: i)) {
51 Value index = rewriter.create<arith::ConstantIndexOp>(location: loc, args&: i);
52 size = rewriter.create<memref::LoadOp>(location: loc, args: op.getShape(), args&: index);
53 if (!isa<IndexType>(Val: size.getType()))
54 size = rewriter.create<arith::IndexCastOp>(
55 location: loc, args: rewriter.getIndexType(), args&: size);
56 sizes[i] = size;
57 } else {
58 auto sizeAttr = rewriter.getIndexAttr(value: op.getType().getDimSize(idx: i));
59 size = rewriter.create<arith::ConstantOp>(location: loc, args&: sizeAttr);
60 sizes[i] = sizeAttr;
61 }
62 if (stride)
63 strides[i] = stride;
64 else
65 strides[i] = rewriter.getIndexAttr(value: staticStride);
66
67 if (i > 0) {
68 if (stride) {
69 stride = rewriter.create<arith::MulIOp>(location: loc, args&: stride, args&: size);
70 } else if (op.getType().isDynamicDim(idx: i)) {
71 stride = rewriter.create<arith::MulIOp>(
72 location: loc, args: rewriter.create<arith::ConstantIndexOp>(location: loc, args&: staticStride),
73 args&: size);
74 } else {
75 staticStride *= op.getType().getDimSize(idx: i);
76 }
77 }
78 }
79 rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
80 op, args: op.getType(), args: op.getSource(), /*offset=*/args: rewriter.getIndexAttr(value: 0),
81 args&: sizes, args&: strides);
82 return success();
83 }
84};
85
86struct ExpandOpsPass : public memref::impl::ExpandOpsPassBase<ExpandOpsPass> {
87 void runOnOperation() override {
88 MLIRContext &ctx = getContext();
89
90 RewritePatternSet patterns(&ctx);
91 memref::populateExpandOpsPatterns(patterns);
92 ConversionTarget target(ctx);
93
94 target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
95 target.addDynamicallyLegalOp<memref::ReshapeOp>(callback: [](memref::ReshapeOp op) {
96 return !cast<MemRefType>(Val: op.getShape().getType()).hasStaticShape();
97 });
98 if (failed(Result: applyPartialConversion(op: getOperation(), target,
99 patterns: std::move(patterns))))
100 signalPassFailure();
101 }
102};
103
104} // namespace
105
106void mlir::memref::populateExpandOpsPatterns(RewritePatternSet &patterns) {
107 patterns.add<MemRefReshapeOpConverter>(arg: patterns.getContext());
108}
109

source code of mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp