1//===- ResolveShapedTypeResultDims.cpp - Resolve dim ops of result values -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This pass resolves `memref.dim` and `tensor.dim` operations of result values
// in terms of the shapes of their operands, using the
// `ReifyRankedShapedTypeOpInterface` and the `InferShapedTypeOpInterface`.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/MemRef/Transforms/Passes.h"
15
16#include "mlir/Dialect/Affine/IR/AffineOps.h"
17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Dialect/Arith/Utils/Utils.h"
19#include "mlir/Dialect/MemRef/IR/MemRef.h"
20#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
21#include "mlir/Dialect/Tensor/IR/Tensor.h"
22#include "mlir/Interfaces/InferTypeOpInterface.h"
23#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24
25namespace mlir {
26namespace memref {
27#define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMS
28#define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMS
29#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
30} // namespace memref
31} // namespace mlir
32
33using namespace mlir;
34
35namespace {
36/// Fold dim of an operation that implements the InferShapedTypeOpInterface
37template <typename OpTy>
38struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> {
39 using OpRewritePattern<OpTy>::OpRewritePattern;
40
41 LogicalResult matchAndRewrite(OpTy dimOp,
42 PatternRewriter &rewriter) const override {
43 OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
44 if (!dimValue)
45 return failure();
46 auto shapedTypeOp =
47 dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
48 if (!shapedTypeOp)
49 return failure();
50
51 std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
52 if (!dimIndex)
53 return failure();
54
55 SmallVector<Value> reifiedResultShapes;
56 if (failed(shapedTypeOp.reifyReturnTypeShapes(
57 rewriter, shapedTypeOp->getOperands(), reifiedResultShapes)))
58 return failure();
59
60 if (reifiedResultShapes.size() != shapedTypeOp->getNumResults())
61 return failure();
62
63 Value resultShape = reifiedResultShapes[dimValue.getResultNumber()];
64 auto resultShapeType = dyn_cast<RankedTensorType>(resultShape.getType());
65 if (!resultShapeType || !isa<IndexType>(resultShapeType.getElementType()))
66 return failure();
67
68 Location loc = dimOp->getLoc();
69 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
70 dimOp, resultShape,
71 rewriter.create<arith::ConstantIndexOp>(loc, *dimIndex).getResult());
72 return success();
73 }
74};
75
76/// Fold dim of an operation that implements the InferShapedTypeOpInterface
77template <typename OpTy>
78struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
79 using OpRewritePattern<OpTy>::OpRewritePattern;
80
81 void initialize() { OpRewritePattern<OpTy>::setHasBoundedRewriteRecursion(); }
82
83 LogicalResult matchAndRewrite(OpTy dimOp,
84 PatternRewriter &rewriter) const override {
85 OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
86 if (!dimValue)
87 return failure();
88 std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
89 if (!dimIndex)
90 return failure();
91
92 ReifiedRankedShapedTypeDims reifiedResultShapes;
93 if (failed(reifyResultShapes(rewriter, dimValue.getOwner(),
94 reifiedResultShapes)))
95 return failure();
96 unsigned resultNumber = dimValue.getResultNumber();
97 // Do not apply pattern if the IR is invalid (dim out of bounds).
98 if ((size_t)(*dimIndex) >= reifiedResultShapes[resultNumber].size())
99 return rewriter.notifyMatchFailure(dimOp, "dimension is out of bounds");
100 Value replacement = getValueOrCreateConstantIndexOp(
101 rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]);
102 rewriter.replaceOp(dimOp, replacement);
103 return success();
104 }
105};
106} // namespace
107
108//===----------------------------------------------------------------------===//
109// Pass registration
110//===----------------------------------------------------------------------===//
111
112namespace {
113struct ResolveRankedShapeTypeResultDimsPass final
114 : public memref::impl::ResolveRankedShapeTypeResultDimsBase<
115 ResolveRankedShapeTypeResultDimsPass> {
116 void runOnOperation() override;
117};
118
119struct ResolveShapedTypeResultDimsPass final
120 : public memref::impl::ResolveShapedTypeResultDimsBase<
121 ResolveShapedTypeResultDimsPass> {
122 void runOnOperation() override;
123};
124
125} // namespace
126
127void memref::populateResolveRankedShapedTypeResultDimsPatterns(
128 RewritePatternSet &patterns) {
129 patterns.add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
130 DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>>(
131 patterns.getContext());
132}
133
134void memref::populateResolveShapedTypeResultDimsPatterns(
135 RewritePatternSet &patterns) {
136 // TODO: Move tensor::DimOp pattern to the Tensor dialect.
137 patterns.add<DimOfShapedTypeOpInterface<memref::DimOp>,
138 DimOfShapedTypeOpInterface<tensor::DimOp>>(
139 patterns.getContext());
140}
141
142void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
143 RewritePatternSet patterns(&getContext());
144 memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
145 if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
146 return signalPassFailure();
147}
148
149void ResolveShapedTypeResultDimsPass::runOnOperation() {
150 RewritePatternSet patterns(&getContext());
151 memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
152 memref::populateResolveShapedTypeResultDimsPatterns(patterns);
153 if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
154 return signalPassFailure();
155}
156
157std::unique_ptr<Pass> memref::createResolveShapedTypeResultDimsPass() {
158 return std::make_unique<ResolveShapedTypeResultDimsPass>();
159}
160
161std::unique_ptr<Pass> memref::createResolveRankedShapeTypeResultDimsPass() {
162 return std::make_unique<ResolveRankedShapeTypeResultDimsPass>();
163}
164

source code of mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp