1//===- AffineExpandIndexOpsAsAffine.cpp - Expand index ops to apply pass --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements a pass to expand affine index ops into one or more
// fundamental operations.
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Affine/Passes.h"
14
15#include "mlir/Dialect/Affine/IR/AffineOps.h"
16#include "mlir/Dialect/Affine/Transforms/Transforms.h"
17#include "mlir/Dialect/Affine/Utils.h"
18#include "mlir/Dialect/Arith/Utils/Utils.h"
19#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20
21namespace mlir {
22namespace affine {
23#define GEN_PASS_DEF_AFFINEEXPANDINDEXOPSASAFFINE
24#include "mlir/Dialect/Affine/Passes.h.inc"
25} // namespace affine
26} // namespace mlir
27
28using namespace mlir;
29using namespace mlir::affine;
30
31namespace {
32/// Lowers `affine.delinearize_index` into a sequence of division and remainder
33/// operations.
34struct LowerDelinearizeIndexOps
35 : public OpRewritePattern<AffineDelinearizeIndexOp> {
36 using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
37 LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
38 PatternRewriter &rewriter) const override {
39 FailureOr<SmallVector<Value>> multiIndex =
40 delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
41 op.getEffectiveBasis(), /*hasOuterBound=*/false);
42 if (failed(multiIndex))
43 return failure();
44 rewriter.replaceOp(op, *multiIndex);
45 return success();
46 }
47};
48
49/// Lowers `affine.linearize_index` into a sequence of multiplications and
50/// additions.
51struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
52 using OpRewritePattern::OpRewritePattern;
53 LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
54 PatternRewriter &rewriter) const override {
55 // Should be folded away, included here for safety.
56 if (op.getMultiIndex().empty()) {
57 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
58 return success();
59 }
60
61 SmallVector<OpFoldResult> multiIndex =
62 getAsOpFoldResult(op.getMultiIndex());
63 OpFoldResult linearIndex =
64 linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis());
65 Value linearIndexValue =
66 getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex);
67 rewriter.replaceOp(op, linearIndexValue);
68 return success();
69 }
70};
71
72class ExpandAffineIndexOpsAsAffinePass
73 : public affine::impl::AffineExpandIndexOpsAsAffineBase<
74 ExpandAffineIndexOpsAsAffinePass> {
75public:
76 ExpandAffineIndexOpsAsAffinePass() = default;
77
78 void runOnOperation() override {
79 MLIRContext *context = &getContext();
80 RewritePatternSet patterns(context);
81 populateAffineExpandIndexOpsAsAffinePatterns(patterns);
82 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
83 return signalPassFailure();
84 }
85};
86
87} // namespace
88
89void mlir::affine::populateAffineExpandIndexOpsAsAffinePatterns(
90 RewritePatternSet &patterns) {
91 patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
92 arg: patterns.getContext());
93}
94
95std::unique_ptr<Pass> mlir::affine::createAffineExpandIndexOpsAsAffinePass() {
96 return std::make_unique<ExpandAffineIndexOpsAsAffinePass>();
97}
98

source code of mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp