//===- FoldIntoElementwise.cpp - Fold Ops into elementwise if possible ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements folding ops such as transpose and broadcast into the
// affine maps of the elementwise op.
//
//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Linalg/IR/Linalg.h"
15#include "mlir/Dialect/Linalg/Passes.h"
16#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17#include "mlir/IR/PatternMatch.h"
18#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19#include "llvm/ADT/SmallVector.h"
20
21namespace mlir {
22#define GEN_PASS_DEF_LINALGFOLDINTOELEMENTWISEPASS
23#include "mlir/Dialect/Linalg/Passes.h.inc"
24} // namespace mlir
25
26using namespace mlir;
27using namespace mlir::linalg;
28
29#define DEBUG_TYPE "linalg-fold-into-elementwise"
30
31namespace {
32struct FoldTransposePattern : public OpRewritePattern<ElementwiseOp> {
33 using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
34
35 LogicalResult matchAndRewrite(ElementwiseOp op,
36 PatternRewriter &rewriter) const override {
37 bool changed = false;
38 SmallVector<Value> newIns;
39 SmallVector<AffineMap> newMaps;
40 for (OpOperand *operand : op.getDpsInputOperands()) {
41 AffineMap map = op.getMatchingIndexingMap(opOperand: operand);
42 auto transposeOp = operand->get().getDefiningOp<TransposeOp>();
43
44 if (!map.isIdentity() || !transposeOp) {
45 // push in original operand and its map.
46 newIns.push_back(Elt: operand->get());
47 newMaps.push_back(Elt: map);
48 continue;
49 }
50 newIns.push_back(Elt: transposeOp.getInput());
51 // push in transposeOp's inverse permutation map.
52 newMaps.push_back(Elt: transposeOp.getMatchingIndexingMap(
53 opOperand: transposeOp.getDpsInputOperand(i: 0)));
54 changed = true;
55 }
56 if (!changed)
57 return failure();
58 newMaps.push_back(Elt: op.getIndexingMapsArray().back());
59
60 rewriter.replaceOpWithNewOp<ElementwiseOp>(
61 op, args&: newIns, args: op.getDpsInits()[0], args: op.getKindAttr(),
62 args: rewriter.getAffineMapArrayAttr(values: newMaps));
63 return success();
64 }
65};
66
67struct LinalgFoldIntoElementwisePass
68 : public impl::LinalgFoldIntoElementwisePassBase<
69 LinalgFoldIntoElementwisePass> {
70 using impl::LinalgFoldIntoElementwisePassBase<
71 LinalgFoldIntoElementwisePass>::LinalgFoldIntoElementwisePassBase;
72
73 void runOnOperation() override {
74 Operation *op = getOperation();
75 RewritePatternSet patterns(op->getContext());
76 populateLinalgFoldIntoElementwisePatterns(patterns);
77
78 if (failed(Result: applyPatternsGreedily(op, patterns: std::move(patterns))))
79 return signalPassFailure();
80 }
81};
82} // namespace
83
84void mlir::linalg::populateLinalgFoldIntoElementwisePatterns(
85 RewritePatternSet &patterns) {
86 patterns.add<FoldTransposePattern>(arg: patterns.getContext());
87}
88

// Source: mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp