1//===- ResolveShapedTypeResultDims.cpp - Resolve dim ops of result values -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass resolves `memref.dim` operations of result values in terms of
10// shapes of their operands using the `InferShapedTypeOpInterface`.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/MemRef/Transforms/Passes.h"
15
16#include "mlir/Dialect/Affine/IR/AffineOps.h"
17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Dialect/Arith/Utils/Utils.h"
19#include "mlir/Dialect/MemRef/IR/MemRef.h"
20#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
21#include "mlir/Dialect/SCF/IR/SCF.h"
22#include "mlir/Dialect/Tensor/IR/Tensor.h"
23#include "mlir/Interfaces/InferTypeOpInterface.h"
24#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25
26namespace mlir {
27namespace memref {
28#define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMSPASS
29#define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMSPASS
30#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
31} // namespace memref
32} // namespace mlir
33
34using namespace mlir;
35
36namespace {
37/// Fold dim of an operation that implements the InferShapedTypeOpInterface
38template <typename OpTy>
39struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> {
40 using OpRewritePattern<OpTy>::OpRewritePattern;
41
42 LogicalResult matchAndRewrite(OpTy dimOp,
43 PatternRewriter &rewriter) const override {
44 OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
45 if (!dimValue)
46 return failure();
47 auto shapedTypeOp =
48 dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
49 if (!shapedTypeOp)
50 return failure();
51
52 std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
53 if (!dimIndex)
54 return failure();
55
56 SmallVector<Value> reifiedResultShapes;
57 if (failed(shapedTypeOp.reifyReturnTypeShapes(
58 rewriter, shapedTypeOp->getOperands(), reifiedResultShapes)))
59 return failure();
60
61 if (reifiedResultShapes.size() != shapedTypeOp->getNumResults())
62 return failure();
63
64 Value resultShape = reifiedResultShapes[dimValue.getResultNumber()];
65 auto resultShapeType = dyn_cast<RankedTensorType>(resultShape.getType());
66 if (!resultShapeType || !isa<IndexType>(resultShapeType.getElementType()))
67 return failure();
68
69 Location loc = dimOp->getLoc();
70 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
71 dimOp, resultShape,
72 rewriter.create<arith::ConstantIndexOp>(loc, *dimIndex).getResult());
73 return success();
74 }
75};
76
/// Fold dim of an operation that implements the
/// ReifyRankedShapedTypeOpInterface: reify the result shapes of the defining
/// op and replace the `dim` with the reified dimension value.
template <typename OpTy>
struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  // Reified shape computations may themselves contain foldable `dim` ops, so
  // bound the rewrite recursion of this pattern.
  void initialize() { OpRewritePattern<OpTy>::setHasBoundedRewriteRecursion(); }

  LogicalResult matchAndRewrite(OpTy dimOp,
                                PatternRewriter &rewriter) const override {
    // The dim source must be produced by an op (not a block argument).
    OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
    if (!dimValue)
      return failure();
    // Only constant dimension indices can be resolved.
    std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
    if (!dimIndex)
      return failure();

    ReifiedRankedShapedTypeDims reifiedResultShapes;
    if (failed(reifyResultShapes(rewriter, dimValue.getOwner(),
                                 reifiedResultShapes)))
      return failure();
    unsigned resultNumber = dimValue.getResultNumber();
    // Do not apply pattern if the IR is invalid (dim out of bounds).
    if ((size_t)(*dimIndex) >= reifiedResultShapes[resultNumber].size())
      return rewriter.notifyMatchFailure(dimOp, "dimension is out of bounds");
    // The reified dim is an OpFoldResult: materialize a constant op for
    // attributes, or use the SSA value directly.
    Value replacement = getValueOrCreateConstantIndexOp(
        rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]);
    rewriter.replaceOp(dimOp, replacement);
    return success();
  }
};
107
108/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
109///
110/// ```
111/// %0 = ... : tensor<?x?xf32>
112/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
113/// %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
114/// ...
115/// }
116/// ```
117///
118/// is folded to:
119///
120/// ```
121/// %0 = ... : tensor<?x?xf32>
122/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
123/// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
124/// ...
125/// }
126/// ```
127struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> {
128 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
129
130 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
131 PatternRewriter &rewriter) const final {
132 auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
133 if (!blockArg)
134 return failure();
135 // TODO: Enable this for loopLikeInterface. Restricting for scf.for
136 // because the init args shape might change in the loop body.
137 // For e.g.:
138 // ```
139 // %0 = tensor.empty(%c1) : tensor<?xf32>
140 // %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %0) ->
141 // tensor<?xf32> {
142 // %1 = tensor.dim %arg0, %c0 : tensor<?xf32>
143 // %2 = arith.addi %c1, %1 : index
144 // %3 = tensor.empty(%2) : tensor<?xf32>
145 // scf.yield %3 : tensor<?xf32>
146 // }
147 //
148 // ```
149 auto forAllOp =
150 dyn_cast<scf::ForallOp>(blockArg.getParentBlock()->getParentOp());
151 if (!forAllOp)
152 return failure();
153 Value initArg = forAllOp.getTiedLoopInit(blockArg)->get();
154 rewriter.modifyOpInPlace(
155 dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
156 return success();
157 }
158};
159} // namespace
160
161//===----------------------------------------------------------------------===//
162// Pass registration
163//===----------------------------------------------------------------------===//
164
165namespace {
/// Pass that resolves `dim` ops of ranked results using the
/// `ReifyRankedShapedTypeOpInterface` patterns only.
struct ResolveRankedShapeTypeResultDimsPass final
    : public memref::impl::ResolveRankedShapeTypeResultDimsPassBase<
          ResolveRankedShapeTypeResultDimsPass> {
  void runOnOperation() override;
};
171
/// Pass that resolves `dim` ops of results using both the ranked
/// (`ReifyRankedShapedTypeOpInterface`) and unranked
/// (`InferShapedTypeOpInterface`) pattern sets.
struct ResolveShapedTypeResultDimsPass final
    : public memref::impl::ResolveShapedTypeResultDimsPassBase<
          ResolveShapedTypeResultDimsPass> {
  void runOnOperation() override;
};
177
178} // namespace
179
180void memref::populateResolveRankedShapedTypeResultDimsPatterns(
181 RewritePatternSet &patterns) {
182 patterns.add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
183 DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>,
184 IterArgsToInitArgs>(patterns.getContext());
185}
186
187void memref::populateResolveShapedTypeResultDimsPatterns(
188 RewritePatternSet &patterns) {
189 // TODO: Move tensor::DimOp pattern to the Tensor dialect.
190 patterns.add<DimOfShapedTypeOpInterface<memref::DimOp>,
191 DimOfShapedTypeOpInterface<tensor::DimOp>>(
192 patterns.getContext());
193}
194
195void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
196 RewritePatternSet patterns(&getContext());
197 memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
198 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
199 return signalPassFailure();
200}
201
202void ResolveShapedTypeResultDimsPass::runOnOperation() {
203 RewritePatternSet patterns(&getContext());
204 memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
205 memref::populateResolveShapedTypeResultDimsPatterns(patterns);
206 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
207 return signalPassFailure();
208}
209
