//===- AffineExpandIndexOps.cpp - Affine expand index ops pass ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to expand affine index ops into one or more
// fundamental operations.
//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Affine/Passes.h"
14
15#include "mlir/Dialect/Affine/IR/AffineOps.h"
16#include "mlir/Dialect/Affine/Transforms/Transforms.h"
17#include "mlir/Dialect/Affine/Utils.h"
18#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19
20namespace mlir {
21namespace affine {
22#define GEN_PASS_DEF_AFFINEEXPANDINDEXOPS
23#include "mlir/Dialect/Affine/Passes.h.inc"
24} // namespace affine
25} // namespace mlir
26
27using namespace mlir;
28using namespace mlir::affine;
29
30namespace {
31/// Lowers `affine.delinearize_index` into a sequence of division and remainder
32/// operations.
33struct LowerDelinearizeIndexOps
34 : public OpRewritePattern<AffineDelinearizeIndexOp> {
35 using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
36 LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
37 PatternRewriter &rewriter) const override {
38 FailureOr<SmallVector<Value>> multiIndex =
39 delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
40 llvm::to_vector(op.getBasis()));
41 if (failed(multiIndex))
42 return failure();
43 rewriter.replaceOp(op, *multiIndex);
44 return success();
45 }
46};
47
48class ExpandAffineIndexOpsPass
49 : public affine::impl::AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> {
50public:
51 ExpandAffineIndexOpsPass() = default;
52
53 void runOnOperation() override {
54 MLIRContext *context = &getContext();
55 RewritePatternSet patterns(context);
56 populateAffineExpandIndexOpsPatterns(patterns);
57 if (failed(
58 applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
59 return signalPassFailure();
60 }
61};
62
63} // namespace
64
65void mlir::affine::populateAffineExpandIndexOpsPatterns(
66 RewritePatternSet &patterns) {
67 patterns.insert<LowerDelinearizeIndexOps>(arg: patterns.getContext());
68}
69
70std::unique_ptr<Pass> mlir::affine::createAffineExpandIndexOpsPass() {
71 return std::make_unique<ExpandAffineIndexOpsPass>();
72}
73

// Source: mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp