1 | //===- FoldIntoElementwise.cpp - Fold Ops into elementwise if possible ---===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
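// This file implements folding of producer ops into the indexing (affine) maps
// of a consumer linalg.elementwise op. Currently only linalg.transpose
// producers are folded; other producers such as broadcast are not yet handled.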
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
15 | #include "mlir/Dialect/Linalg/Passes.h" |
16 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
17 | #include "mlir/IR/PatternMatch.h" |
18 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
19 | #include "llvm/ADT/SmallVector.h" |
20 | #include "llvm/ADT/TypeSwitch.h" |
21 | |
22 | namespace mlir { |
23 | #define GEN_PASS_DEF_LINALGFOLDINTOELEMENTWISEPASS |
24 | #include "mlir/Dialect/Linalg/Passes.h.inc" |
25 | } // namespace mlir |
26 | |
27 | using namespace mlir; |
28 | using namespace mlir::linalg; |
29 | |
30 | #define DEBUG_TYPE "linalg-fold-into-elementwise" |
31 | |
32 | namespace { |
33 | struct FoldTransposePattern : public OpRewritePattern<ElementwiseOp> { |
34 | using OpRewritePattern<ElementwiseOp>::OpRewritePattern; |
35 | |
36 | LogicalResult matchAndRewrite(ElementwiseOp op, |
37 | PatternRewriter &rewriter) const override { |
38 | bool changed = false; |
39 | SmallVector<Value> newIns; |
40 | SmallVector<AffineMap> newMaps; |
41 | for (OpOperand *operand : op.getDpsInputOperands()) { |
42 | AffineMap map = op.getMatchingIndexingMap(operand); |
43 | auto transposeOp = operand->get().getDefiningOp<TransposeOp>(); |
44 | |
45 | if (!map.isIdentity() || !transposeOp) { |
46 | // push in original operand and its map. |
47 | newIns.push_back(operand->get()); |
48 | newMaps.push_back(map); |
49 | continue; |
50 | } |
51 | newIns.push_back(transposeOp.getInput()); |
52 | // push in transposeOp's inverse permutation map. |
53 | newMaps.push_back(transposeOp.getMatchingIndexingMap( |
54 | transposeOp.getDpsInputOperand(0))); |
55 | changed = true; |
56 | } |
57 | if (!changed) |
58 | return failure(); |
59 | newMaps.push_back(op.getIndexingMapsArray().back()); |
60 | |
61 | rewriter.replaceOpWithNewOp<ElementwiseOp>( |
62 | op, newIns, op.getDpsInits()[0], op.getKindAttr(), |
63 | rewriter.getAffineMapArrayAttr(newMaps)); |
64 | return success(); |
65 | } |
66 | }; |
67 | |
68 | struct LinalgFoldIntoElementwisePass |
69 | : public impl::LinalgFoldIntoElementwisePassBase< |
70 | LinalgFoldIntoElementwisePass> { |
71 | using impl::LinalgFoldIntoElementwisePassBase< |
72 | LinalgFoldIntoElementwisePass>::LinalgFoldIntoElementwisePassBase; |
73 | |
74 | void runOnOperation() override { |
75 | Operation *op = getOperation(); |
76 | RewritePatternSet patterns(op->getContext()); |
77 | populateLinalgFoldIntoElementwisePatterns(patterns); |
78 | |
79 | if (failed(applyPatternsGreedily(op, std::move(patterns)))) |
80 | return signalPassFailure(); |
81 | } |
82 | }; |
83 | } // namespace |
84 | |
85 | void mlir::linalg::populateLinalgFoldIntoElementwisePatterns( |
86 | RewritePatternSet &patterns) { |
87 | patterns.add<FoldTransposePattern>(arg: patterns.getContext()); |
88 | } |
89 | |