//===- InlineScalarOperands.cpp - Pass to inline scalar operands =============//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns/pass to inline scalar operands into a generic
// operation. A scalar operand is an input operand whose indexing map has only
// constant results (for example `(d0, d1) -> ()` or `(d0, d1) -> (0, 0)`), so
// every iteration of the loop nest reads the same element.
//
//===----------------------------------------------------------------------===//
14
15#include "mlir/Dialect/Linalg/Passes.h"
16
17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Dialect/Func/IR/FuncOps.h"
19#include "mlir/Dialect/Linalg/IR/Linalg.h"
20#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
21#include "mlir/IR/AffineExpr.h"
22#include "mlir/IR/AffineMap.h"
23#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24
25namespace mlir {
26#define GEN_PASS_DEF_LINALGINLINESCALAROPERANDSPASS
27#include "mlir/Dialect/Linalg/Passes.h.inc"
28} // namespace mlir
29
30using namespace mlir;
31using namespace mlir::linalg;
32
33namespace {
34struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
35 using OpRewritePattern<GenericOp>::OpRewritePattern;
36 LogicalResult matchAndRewrite(GenericOp genericOp,
37 PatternRewriter &rewriter) const override {
38 if (!genericOp.hasPureTensorSemantics())
39 return failure();
40
41 SmallVector<size_t> scalarOperands;
42 SmallVector<AffineMap> newIndexingMaps;
43 SmallVector<Value> newOperands;
44 for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
45 AffineMap map = genericOp.getMatchingIndexingMap(opOperand);
46 if (genericOp.isDpsInput(opOperand) && map.isConstant()) {
47 scalarOperands.emplace_back(opOperand->getOperandNumber());
48 } else {
49 newIndexingMaps.emplace_back(map);
50 newOperands.emplace_back(opOperand->get());
51 }
52 }
53
54 if (scalarOperands.empty())
55 return failure();
56
57 for (OpOperand &opOperand : genericOp.getDpsInitsMutable())
58 newIndexingMaps.emplace_back(
59 genericOp.getMatchingIndexingMap(&opOperand));
60
61 Location loc = genericOp->getLoc();
62 SmallVector<Value> outputOperands = genericOp.getOutputs();
63 auto newOp = rewriter.create<GenericOp>(
64 loc, genericOp->getResultTypes(), newOperands, outputOperands,
65 newIndexingMaps, genericOp.getIteratorTypesArray());
66 rewriter.cloneRegionBefore(genericOp.getRegion(), newOp.getRegion(),
67 newOp.getRegion().begin());
68
69 Block *body = newOp.getBody();
70 PatternRewriter::InsertionGuard guard(rewriter);
71 rewriter.setInsertionPointToStart(body);
72
73 for (auto idx : llvm::reverse(scalarOperands)) {
74 OpOperand *opOperand = genericOp.getDpsInputOperand(idx);
75 AffineMap map = genericOp.getMatchingIndexingMap(opOperand);
76 SmallVector<int64_t> indices = map.getConstantResults();
77 SmallVector<Value> indicesValues;
78 for (auto idx : indices)
79 indicesValues.emplace_back(
80 rewriter.create<arith::ConstantIndexOp>(loc, idx));
81 Value extractedValue = rewriter.create<tensor::ExtractOp>(
82 loc, opOperand->get(), indicesValues);
83 body->getArgument(idx).replaceAllUsesWith(extractedValue);
84 body->eraseArgument(idx);
85 }
86
87 rewriter.replaceOp(genericOp, newOp->getResults());
88 return success();
89 }
90};
91} // namespace
92
93/// Patterns that are used to inline constant operands into linalg generic
94/// ops.
95void mlir::linalg::populateInlineConstantOperandsPatterns(
96 RewritePatternSet &patterns) {
97 auto *context = patterns.getContext();
98 patterns.add<InlineScalarOperands>(arg&: context);
99}
100
101namespace {
102/// Pass that removes unit-extent dims within generic ops.
103struct LinalgInlineScalarOperandsPass
104 : public impl::LinalgInlineScalarOperandsPassBase<
105 LinalgInlineScalarOperandsPass> {
106 using impl::LinalgInlineScalarOperandsPassBase<
107 LinalgInlineScalarOperandsPass>::LinalgInlineScalarOperandsPassBase;
108 void runOnOperation() override {
109 Operation *op = getOperation();
110 MLIRContext &ctx = getContext();
111 RewritePatternSet patterns(&ctx);
112 populateInlineConstantOperandsPatterns(patterns);
113 (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
114 }
115};
116} // namespace
117

source code of mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp