1//===- LoopCanonicalization.cpp - Cross-dialect canonicalization patterns -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains cross-dialect canonicalization patterns that cannot be
10// actual canonicalization patterns due to undesired additional dependencies.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/SCF/Transforms/Passes.h"
15
16#include "mlir/Dialect/Affine/IR/AffineOps.h"
17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Dialect/SCF/IR/SCF.h"
19#include "mlir/Dialect/SCF/Transforms/Patterns.h"
20#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
21#include "mlir/Dialect/Tensor/IR/Tensor.h"
22#include "mlir/IR/PatternMatch.h"
23#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24#include "llvm/ADT/TypeSwitch.h"
25
26namespace mlir {
27#define GEN_PASS_DEF_SCFFORLOOPCANONICALIZATION
28#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
29} // namespace mlir
30
31using namespace mlir;
32using namespace mlir::scf;
33
34/// A simple, conservative analysis to determine if the loop is shape
35/// conserving. I.e., the type of the arg-th yielded value is the same as the
36/// type of the corresponding basic block argument of the loop.
37/// Note: This function handles only simple cases. Expand as needed.
38static bool isShapePreserving(ForOp forOp, int64_t arg) {
39 assert(arg < static_cast<int64_t>(forOp.getNumResults()) &&
40 "arg is out of bounds");
41 Value value = forOp.getYieldedValues()[arg];
42 while (value) {
43 if (value == forOp.getRegionIterArgs()[arg])
44 return true;
45 OpResult opResult = dyn_cast<OpResult>(value);
46 if (!opResult)
47 return false;
48
49 using tensor::InsertSliceOp;
50 value = llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
51 .template Case<InsertSliceOp>(
52 [&](InsertSliceOp op) { return op.getDest(); })
53 .template Case<ForOp>([&](ForOp forOp) {
54 return isShapePreserving(forOp, opResult.getResultNumber())
55 ? forOp.getInitArgs()[opResult.getResultNumber()]
56 : Value();
57 })
58 .Default([&](auto op) { return Value(); });
59 }
60 return false;
61}
62
63namespace {
64/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
65///
66/// ```
67/// %0 = ... : tensor<?x?xf32>
68/// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
69/// %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
70/// ...
71/// }
72/// ```
73///
74/// is folded to:
75///
76/// ```
77/// %0 = ... : tensor<?x?xf32>
78/// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
79/// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
80/// ...
81/// }
82/// ```
83///
84/// Note: Dim ops are folded only if it can be proven that the runtime type of
85/// the iter arg does not change with loop iterations.
86template <typename OpTy>
87struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
88 using OpRewritePattern<OpTy>::OpRewritePattern;
89
90 LogicalResult matchAndRewrite(OpTy dimOp,
91 PatternRewriter &rewriter) const override {
92 auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
93 if (!blockArg)
94 return failure();
95 auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp());
96 if (!forOp)
97 return failure();
98 if (!isShapePreserving(forOp, blockArg.getArgNumber() - 1))
99 return failure();
100
101 Value initArg = forOp.getTiedLoopInit(blockArg)->get();
102 rewriter.modifyOpInPlace(
103 dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
104
105 return success();
106 };
107};
108
109/// Fold dim ops of loop results to dim ops of their respective init args. E.g.:
110///
111/// ```
112/// %0 = ... : tensor<?x?xf32>
113/// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
114/// ...
115/// }
116/// %1 = tensor.dim %r, %c0 : tensor<?x?xf32>
117/// ```
118///
119/// is folded to:
120///
121/// ```
122/// %0 = ... : tensor<?x?xf32>
123/// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
124/// ...
125/// }
126/// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
127/// ```
128///
129/// Note: Dim ops are folded only if it can be proven that the runtime type of
130/// the iter arg does not change with loop iterations.
131template <typename OpTy>
132struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
133 using OpRewritePattern<OpTy>::OpRewritePattern;
134
135 LogicalResult matchAndRewrite(OpTy dimOp,
136 PatternRewriter &rewriter) const override {
137 auto forOp = dimOp.getSource().template getDefiningOp<scf::ForOp>();
138 if (!forOp)
139 return failure();
140 auto opResult = cast<OpResult>(dimOp.getSource());
141 unsigned resultNumber = opResult.getResultNumber();
142 if (!isShapePreserving(forOp, resultNumber))
143 return failure();
144 rewriter.modifyOpInPlace(dimOp, [&]() {
145 dimOp.getSourceMutable().assign(forOp.getInitArgs()[resultNumber]);
146 });
147 return success();
148 }
149};
150
151/// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for
152/// and scf.parallel loops with a known range.
153template <typename OpTy>
154struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
155 using OpRewritePattern<OpTy>::OpRewritePattern;
156
157 LogicalResult matchAndRewrite(OpTy op,
158 PatternRewriter &rewriter) const override {
159 return scf::canonicalizeMinMaxOpInLoop(rewriter, op, loopMatcher: scf::matchForLikeLoop);
160 }
161};
162
163struct SCFForLoopCanonicalization
164 : public impl::SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> {
165 void runOnOperation() override {
166 auto *parentOp = getOperation();
167 MLIRContext *ctx = parentOp->getContext();
168 RewritePatternSet patterns(ctx);
169 scf::populateSCFForLoopCanonicalizationPatterns(patterns);
170 if (failed(applyPatternsAndFoldGreedily(parentOp, std::move(patterns))))
171 signalPassFailure();
172 }
173};
174} // namespace
175
176void mlir::scf::populateSCFForLoopCanonicalizationPatterns(
177 RewritePatternSet &patterns) {
178 MLIRContext *ctx = patterns.getContext();
179 patterns
180 .add<AffineOpSCFCanonicalizationPattern<affine::AffineMinOp>,
181 AffineOpSCFCanonicalizationPattern<affine::AffineMaxOp>,
182 DimOfIterArgFolder<tensor::DimOp>, DimOfIterArgFolder<memref::DimOp>,
183 DimOfLoopResultFolder<tensor::DimOp>,
184 DimOfLoopResultFolder<memref::DimOp>>(ctx);
185}
186
187std::unique_ptr<Pass> mlir::createSCFForLoopCanonicalizationPass() {
188 return std::make_unique<SCFForLoopCanonicalization>();
189}
190

// Source: mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp