1 | //===- ResolveShapedTypeResultDims.cpp - Resolve dim ops of result values -===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file implements passes that resolve `memref.dim` and `tensor.dim`
// operations of result values in terms of the shapes of the operands of the
// defining operation, using the `InferShapedTypeOpInterface` and the
// `reifyResultShapes` helper.
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/MemRef/Transforms/Passes.h" |
15 | |
16 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
17 | #include "mlir/Dialect/Arith/IR/Arith.h" |
18 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
19 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
20 | #include "mlir/Dialect/MemRef/Transforms/Transforms.h" |
21 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
22 | #include "mlir/Interfaces/InferTypeOpInterface.h" |
23 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
24 | |
25 | namespace mlir { |
26 | namespace memref { |
27 | #define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMS |
28 | #define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMS |
29 | #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" |
30 | } // namespace memref |
31 | } // namespace mlir |
32 | |
33 | using namespace mlir; |
34 | |
35 | namespace { |
36 | /// Fold dim of an operation that implements the InferShapedTypeOpInterface |
37 | template <typename OpTy> |
38 | struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> { |
39 | using OpRewritePattern<OpTy>::OpRewritePattern; |
40 | |
41 | LogicalResult matchAndRewrite(OpTy dimOp, |
42 | PatternRewriter &rewriter) const override { |
43 | OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource()); |
44 | if (!dimValue) |
45 | return failure(); |
46 | auto shapedTypeOp = |
47 | dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner()); |
48 | if (!shapedTypeOp) |
49 | return failure(); |
50 | |
51 | std::optional<int64_t> dimIndex = dimOp.getConstantIndex(); |
52 | if (!dimIndex) |
53 | return failure(); |
54 | |
55 | SmallVector<Value> reifiedResultShapes; |
56 | if (failed(shapedTypeOp.reifyReturnTypeShapes( |
57 | rewriter, shapedTypeOp->getOperands(), reifiedResultShapes))) |
58 | return failure(); |
59 | |
60 | if (reifiedResultShapes.size() != shapedTypeOp->getNumResults()) |
61 | return failure(); |
62 | |
63 | Value resultShape = reifiedResultShapes[dimValue.getResultNumber()]; |
64 | auto resultShapeType = dyn_cast<RankedTensorType>(resultShape.getType()); |
65 | if (!resultShapeType || !isa<IndexType>(resultShapeType.getElementType())) |
66 | return failure(); |
67 | |
68 | Location loc = dimOp->getLoc(); |
69 | rewriter.replaceOpWithNewOp<tensor::ExtractOp>( |
70 | dimOp, resultShape, |
71 | rewriter.create<arith::ConstantIndexOp>(loc, *dimIndex).getResult()); |
72 | return success(); |
73 | } |
74 | }; |
75 | |
76 | /// Fold dim of an operation that implements the InferShapedTypeOpInterface |
77 | template <typename OpTy> |
78 | struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> { |
79 | using OpRewritePattern<OpTy>::OpRewritePattern; |
80 | |
81 | void initialize() { OpRewritePattern<OpTy>::setHasBoundedRewriteRecursion(); } |
82 | |
83 | LogicalResult matchAndRewrite(OpTy dimOp, |
84 | PatternRewriter &rewriter) const override { |
85 | OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource()); |
86 | if (!dimValue) |
87 | return failure(); |
88 | std::optional<int64_t> dimIndex = dimOp.getConstantIndex(); |
89 | if (!dimIndex) |
90 | return failure(); |
91 | |
92 | ReifiedRankedShapedTypeDims reifiedResultShapes; |
93 | if (failed(reifyResultShapes(rewriter, dimValue.getOwner(), |
94 | reifiedResultShapes))) |
95 | return failure(); |
96 | unsigned resultNumber = dimValue.getResultNumber(); |
97 | // Do not apply pattern if the IR is invalid (dim out of bounds). |
98 | if ((size_t)(*dimIndex) >= reifiedResultShapes[resultNumber].size()) |
99 | return rewriter.notifyMatchFailure(dimOp, "dimension is out of bounds" ); |
100 | Value replacement = getValueOrCreateConstantIndexOp( |
101 | rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]); |
102 | rewriter.replaceOp(dimOp, replacement); |
103 | return success(); |
104 | } |
105 | }; |
106 | } // namespace |
107 | |
108 | //===----------------------------------------------------------------------===// |
109 | // Pass registration |
110 | //===----------------------------------------------------------------------===// |
111 | |
112 | namespace { |
113 | struct ResolveRankedShapeTypeResultDimsPass final |
114 | : public memref::impl::ResolveRankedShapeTypeResultDimsBase< |
115 | ResolveRankedShapeTypeResultDimsPass> { |
116 | void runOnOperation() override; |
117 | }; |
118 | |
119 | struct ResolveShapedTypeResultDimsPass final |
120 | : public memref::impl::ResolveShapedTypeResultDimsBase< |
121 | ResolveShapedTypeResultDimsPass> { |
122 | void runOnOperation() override; |
123 | }; |
124 | |
125 | } // namespace |
126 | |
127 | void memref::populateResolveRankedShapedTypeResultDimsPatterns( |
128 | RewritePatternSet &patterns) { |
129 | patterns.add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>, |
130 | DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>>( |
131 | patterns.getContext()); |
132 | } |
133 | |
134 | void memref::populateResolveShapedTypeResultDimsPatterns( |
135 | RewritePatternSet &patterns) { |
136 | // TODO: Move tensor::DimOp pattern to the Tensor dialect. |
137 | patterns.add<DimOfShapedTypeOpInterface<memref::DimOp>, |
138 | DimOfShapedTypeOpInterface<tensor::DimOp>>( |
139 | patterns.getContext()); |
140 | } |
141 | |
142 | void ResolveRankedShapeTypeResultDimsPass::runOnOperation() { |
143 | RewritePatternSet patterns(&getContext()); |
144 | memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); |
145 | if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) |
146 | return signalPassFailure(); |
147 | } |
148 | |
149 | void ResolveShapedTypeResultDimsPass::runOnOperation() { |
150 | RewritePatternSet patterns(&getContext()); |
151 | memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); |
152 | memref::populateResolveShapedTypeResultDimsPatterns(patterns); |
153 | if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) |
154 | return signalPassFailure(); |
155 | } |
156 | |
157 | std::unique_ptr<Pass> memref::createResolveShapedTypeResultDimsPass() { |
158 | return std::make_unique<ResolveShapedTypeResultDimsPass>(); |
159 | } |
160 | |
161 | std::unique_ptr<Pass> memref::createResolveRankedShapeTypeResultDimsPass() { |
162 | return std::make_unique<ResolveRankedShapeTypeResultDimsPass>(); |
163 | } |
164 | |