1//===- FoldIntoElementwise.cpp - Fold Ops into elementwise if possible ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements folding ops such as transpose and broadcast into the
10// affine maps of the elementwise op.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Linalg/IR/Linalg.h"
15#include "mlir/Dialect/Linalg/Passes.h"
16#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17#include "mlir/IR/PatternMatch.h"
18#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19#include "llvm/ADT/SmallVector.h"
20#include "llvm/ADT/TypeSwitch.h"
21
22namespace mlir {
23#define GEN_PASS_DEF_LINALGFOLDINTOELEMENTWISEPASS
24#include "mlir/Dialect/Linalg/Passes.h.inc"
25} // namespace mlir
26
27using namespace mlir;
28using namespace mlir::linalg;
29
30#define DEBUG_TYPE "linalg-fold-into-elementwise"
31
32namespace {
33struct FoldTransposePattern : public OpRewritePattern<ElementwiseOp> {
34 using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
35
36 LogicalResult matchAndRewrite(ElementwiseOp op,
37 PatternRewriter &rewriter) const override {
38 bool changed = false;
39 SmallVector<Value> newIns;
40 SmallVector<AffineMap> newMaps;
41 for (OpOperand *operand : op.getDpsInputOperands()) {
42 AffineMap map = op.getMatchingIndexingMap(operand);
43 auto transposeOp = operand->get().getDefiningOp<TransposeOp>();
44
45 if (!map.isIdentity() || !transposeOp) {
46 // push in original operand and its map.
47 newIns.push_back(operand->get());
48 newMaps.push_back(map);
49 continue;
50 }
51 newIns.push_back(transposeOp.getInput());
52 // push in transposeOp's inverse permutation map.
53 newMaps.push_back(transposeOp.getMatchingIndexingMap(
54 transposeOp.getDpsInputOperand(0)));
55 changed = true;
56 }
57 if (!changed)
58 return failure();
59 newMaps.push_back(op.getIndexingMapsArray().back());
60
61 rewriter.replaceOpWithNewOp<ElementwiseOp>(
62 op, newIns, op.getDpsInits()[0], op.getKindAttr(),
63 rewriter.getAffineMapArrayAttr(newMaps));
64 return success();
65 }
66};
67
68struct LinalgFoldIntoElementwisePass
69 : public impl::LinalgFoldIntoElementwisePassBase<
70 LinalgFoldIntoElementwisePass> {
71 using impl::LinalgFoldIntoElementwisePassBase<
72 LinalgFoldIntoElementwisePass>::LinalgFoldIntoElementwisePassBase;
73
74 void runOnOperation() override {
75 Operation *op = getOperation();
76 RewritePatternSet patterns(op->getContext());
77 populateLinalgFoldIntoElementwisePatterns(patterns);
78
79 if (failed(applyPatternsGreedily(op, std::move(patterns))))
80 return signalPassFailure();
81 }
82};
83} // namespace
84
85void mlir::linalg::populateLinalgFoldIntoElementwisePatterns(
86 RewritePatternSet &patterns) {
87 patterns.add<FoldTransposePattern>(arg: patterns.getContext());
88}
89

// Source: mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp