//===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTELEMENTWISETOLINALGPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
22
23static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
24 if (!OpTrait::hasElementwiseMappableTraits(op))
25 return false;
26
27 // TODO: The conversion pattern can be made to work for `any_of` here, but
28 // it's more complex as it requires tracking which operands are scalars.
29 return llvm::all_of(Range: op->getOperandTypes(), P: llvm::IsaPred<RankedTensorType>);
30}
31
32/// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
33/// the result types and return a list of values such that, for each result type
34/// `t` and value `v` at the same index `idx`:
35/// 1. `v.getType() == t`
36/// 2. If an operand of `op` has type `t`, let `operand_first` be the first
37/// such operand. Then`v == operand_first`.
38/// 3. Otherwise, v is a newly created `tensor::EmptyOp` with:
39/// a. Static and dynamic dims extracted from the first operand of `op`.
40/// b. Elemental type equal to the elemental type of `t`.
41///
42/// This is sufficient because ElementwiseMappable guarantees that "The static
43/// types of all vector (resp. tensor) operands and results must have the same
44/// shape".
45static SmallVector<Value, 4>
46getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) {
47 assert(isElementwiseMappableOpOnRankedTensors(op));
48 Location loc = op->getLoc();
49 ValueRange operands = op->getOperands();
50 TypeRange rankedTensorTypes = op->getResultTypes();
51 SmallVector<Value, 4> res;
52 res.reserve(N: rankedTensorTypes.size());
53 for (Type t : rankedTensorTypes) {
54 // Try to find an operand with type matching the result tensor.
55 bool found = false;
56 for (Value v : operands) {
57 if (v.getType() == t) {
58 found = true;
59 res.push_back(Elt: v);
60 break;
61 }
62 }
63 if (found)
64 continue;
65
66 // Extract static / dynamic shape mix from the first operand.
67 res.push_back(Elt: b.create<tensor::EmptyOp>(
68 location: loc, args: tensor::getMixedSizes(builder&: b, loc, value: operands.front()),
69 args: cast<RankedTensorType>(Val&: t).getElementType()));
70 }
71 return res;
72}
73
74namespace {
75struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
76 ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context)
77 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
78 LogicalResult matchAndRewrite(Operation *op,
79 PatternRewriter &rewriter) const final {
80 if (!isElementwiseMappableOpOnRankedTensors(op))
81 return rewriter.notifyMatchFailure(
82 arg&: op, msg: "requires elementwise op on ranked tensors");
83
84 auto rank = cast<RankedTensorType>(Val: op->getResult(idx: 0).getType()).getRank();
85 SmallVector<AffineMap, 3> indexingMaps(
86 op->getNumResults() + op->getNumOperands(),
87 rewriter.getMultiDimIdentityMap(rank));
88 SmallVector<utils::IteratorType, 6> iteratorTypes(
89 rank, utils::IteratorType::parallel);
90 auto outputs = getOrCreateOperandsMatchingResultTypes(b&: rewriter, op);
91 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
92 op, /*resultTensorTypes=*/args: op->getResultTypes(),
93 /*inputs=*/args: op->getOperands(),
94 /*outputs=*/args&: outputs,
95 /*indexingMaps=*/args&: indexingMaps,
96 /*iteratorTypes=*/args&: iteratorTypes,
97 /*bodyBuilder=*/
98 args: [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
99 auto resultTypes = llvm::to_vector<6>(
100 Range: llvm::map_range(C: op->getResultTypes(), F: [](Type type) {
101 return cast<TensorType>(Val&: type).getElementType();
102 }));
103 auto *scalarOp =
104 builder.create(loc, opName: op->getName().getIdentifier(),
105 operands: regionArgs.take_front(n: op->getNumOperands()),
106 types: resultTypes, attributes: op->getAttrs());
107 builder.create<linalg::YieldOp>(location: loc, args: scalarOp->getResults());
108 });
109 return success();
110 }
111};
112} // namespace
113
114void mlir::linalg::populateElementwiseToLinalgConversionPatterns(
115 RewritePatternSet &patterns) {
116 patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
117 arg: patterns.getContext());
118}
119
120namespace {
121class ConvertElementwiseToLinalgPass
122 : public impl::ConvertElementwiseToLinalgPassBase<
123 ConvertElementwiseToLinalgPass> {
124 using impl::ConvertElementwiseToLinalgPassBase<
125 ConvertElementwiseToLinalgPass>::ConvertElementwiseToLinalgPassBase;
126
127 void runOnOperation() final {
128 auto *func = getOperation();
129 auto *context = &getContext();
130 ConversionTarget target(*context);
131 RewritePatternSet patterns(context);
132
133 mlir::linalg::populateElementwiseToLinalgConversionPatterns(patterns);
134 target.markUnknownOpDynamicallyLegal(fn: [](Operation *op) {
135 return !isElementwiseMappableOpOnRankedTensors(op);
136 });
137
138 if (failed(Result: applyPartialConversion(op: func, target, patterns: std::move(patterns))))
139 signalPassFailure();
140 }
141};
142} // namespace
143

source code of mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp