//===- TensorTransformOps.cpp - Implementation of tensor transform ops ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace tensor;

//===----------------------------------------------------------------------===//
// FindPayloadReplacementOpInterface implementations
//===----------------------------------------------------------------------===//

28namespace {
29struct ExtractSliceOpReplacementInterface
30 : public transform::FindPayloadReplacementOpInterface::ExternalModel<
31 ExtractSliceOpReplacementInterface, tensor::ExtractSliceOp> {
32 SmallVector<Value> getNextOperands(Operation *op) const {
33 auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
34 if (!isCastLikeExtractSliceOp(extractSliceOp))
35 return {};
36 return {extractSliceOp.getSource()};
37 }
38};
39
40struct InsertSliceOpReplacementInterface
41 : public transform::FindPayloadReplacementOpInterface::ExternalModel<
42 InsertSliceOpReplacementInterface, tensor::InsertSliceOp> {
43 SmallVector<Value> getNextOperands(Operation *op) const {
44 auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
45 if (!isCastLikeInsertSliceOp(insertSliceOp))
46 return {};
47 return {insertSliceOp.getSource()};
48 }
49};
50
51struct ReshapeOpReplacementInterface
52 : public transform::FindPayloadReplacementOpInterface::ExternalModel<
53 ReshapeOpReplacementInterface, tensor::ReshapeOp> {
54 SmallVector<Value> getNextOperands(Operation *op) const {
55 auto reshapeOp = cast<tensor::ReshapeOp>(op);
56 return {reshapeOp.getSource()};
57 }
58};
59
60template <typename ConcreteOp>
61struct ReassociativeReshapeOpReplacementInterface
62 : public transform::FindPayloadReplacementOpInterface::ExternalModel<
63 ReassociativeReshapeOpReplacementInterface<ConcreteOp>, ConcreteOp> {
64 SmallVector<Value> getNextOperands(Operation *op) const {
65 auto reshapeOp = cast<ConcreteOp>(op);
66 return {reshapeOp.getSrc()};
67 }
68};
69} // namespace
70
71void tensor::registerFindPayloadReplacementOpInterfaceExternalModels(
72 DialectRegistry &registry) {
73 registry.addExtension(extensionFn: +[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
74 CollapseShapeOp::attachInterface<
75 ReassociativeReshapeOpReplacementInterface<CollapseShapeOp>>(*ctx);
76 ExpandShapeOp::attachInterface<
77 ReassociativeReshapeOpReplacementInterface<ExpandShapeOp>>(*ctx);
78 ExtractSliceOp::attachInterface<ExtractSliceOpReplacementInterface>(*ctx);
79 InsertSliceOp::attachInterface<InsertSliceOpReplacementInterface>(*ctx);
80 ReshapeOp::attachInterface<ReshapeOpReplacementInterface>(*ctx);
81 });
82}
83
//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//

88void transform::ApplyDecomposeTensorConcatPatternsOp::populatePatterns(
89 RewritePatternSet &patterns) {
90 tensor::populateDecomposeTensorConcatPatterns(patterns);
91}
92
93void transform::ApplyDropRedundantInsertSliceRankExpansionPatternsOp::
94 populatePatterns(RewritePatternSet &patterns) {
95 tensor::populateDropRedundantInsertSliceRankExpansionPatterns(patterns);
96}
97
98void transform::ApplyFoldTensorEmptyPatternsOp::populatePatterns(
99 RewritePatternSet &patterns) {
100 tensor::populateFoldTensorEmptyPatterns(patterns, getFoldSingleUseOnly());
101}
102
103void transform::ApplyFoldTensorSubsetOpsPatternsOp::populatePatterns(
104 RewritePatternSet &patterns) {
105 tensor::populateFoldTensorSubsetOpPatterns(patterns);
106}
107
108void transform::ApplyFoldTensorSubsetOpsIntoVectorTransfersPatternsOp::
109 populatePatterns(RewritePatternSet &patterns) {
110 tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
111}
112
113void transform::ApplyMergeConsecutiveInsertExtractSlicePatternsOp::
114 populatePatterns(RewritePatternSet &patterns) {
115 tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
116}
117
118void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns(
119 RewritePatternSet &patterns) {
120 tensor::populateReassociativeReshapeFoldingPatterns(patterns);
121}
122
123void transform::ApplyBubbleUpExtractSlicePatternsOp::populatePatterns(
124 RewritePatternSet &patterns) {
125 tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
126}
127
128void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
129 RewritePatternSet &patterns) {
130 ControlFoldFn defaultControlFn = [](OpOperand *fusedOperand) {
131 Operation *producer = fusedOperand->get().getDefiningOp();
132 return producer && producer->hasOneUse();
133 };
134
135 ControlFoldFn aggressiveControlFn = [](OpOperand *fusedOperand) {
136 return true;
137 };
138
139 // Add folding with reshape by expansion patterns.
140 if (getAggressive())
141 tensor::populateRewriteAsConstantPatterns(patterns, aggressiveControlFn);
142 else
143 tensor::populateRewriteAsConstantPatterns(patterns, defaultControlFn);
144}
145
//===----------------------------------------------------------------------===//
// TypeConversionCastTensorShapeOp
//===----------------------------------------------------------------------===//

150void transform::TypeConversionCastShapeDynamicDimsOp::
151 populateTypeMaterializations(TypeConverter &converter) {
152 bool ignoreDynamicInfo = getIgnoreDynamicInfo();
153 converter.addSourceMaterialization([ignoreDynamicInfo](
154 OpBuilder &builder, Type resultType,
155 ValueRange inputs,
156 Location loc) -> Value {
157 if (inputs.size() != 1) {
158 return Value();
159 }
160 Value input = inputs[0];
161 if (!ignoreDynamicInfo &&
162 !tensor::preservesStaticInformation(resultType, input.getType())) {
163 return Value();
164 }
165 if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
166 return Value();
167 }
168 return builder.create<tensor::CastOp>(loc, resultType, input).getResult();
169 });
170 converter.addTargetMaterialization([](OpBuilder &builder, Type resultType,
171 ValueRange inputs,
172 Location loc) -> Value {
173 if (inputs.size() != 1) {
174 return Value();
175 }
176 Value input = inputs[0];
177 if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
178 return Value();
179 }
180 return builder.create<tensor::CastOp>(loc, resultType, input).getResult();
181 });
182}
183
//===----------------------------------------------------------------------===//
// MakeLoopIndependentOp
//===----------------------------------------------------------------------===//

188DiagnosedSilenceableFailure transform::MakeLoopIndependentOp::applyToOne(
189 transform::TransformRewriter &rewriter, Operation *target,
190 transform::ApplyToEachResultList &results,
191 transform::TransformState &state) {
192 // Gather IVs.
193 SmallVector<Value> ivs;
194 Operation *nextOp = target;
195 for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) {
196 nextOp = nextOp->getParentOfType<scf::ForOp>();
197 if (!nextOp) {
198 DiagnosedSilenceableFailure diag = emitSilenceableError()
199 << "could not find " << i
200 << "-th enclosing loop";
201 diag.attachNote(target->getLoc()) << "target op";
202 return diag;
203 }
204 ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar());
205 }
206
207 // Rewrite IR.
208 FailureOr<Value> replacement = failure();
209 if (auto padOp = dyn_cast<tensor::PadOp>(target)) {
210 replacement = tensor::buildIndependentOp(rewriter, padOp, ivs);
211 } else if (auto emptyOp = dyn_cast<tensor::EmptyOp>(target)) {
212 replacement = tensor::buildIndependentOp(rewriter, emptyOp, ivs);
213 } else {
214 DiagnosedSilenceableFailure diag = emitSilenceableError()
215 << "unsupported target op";
216 diag.attachNote(target->getLoc()) << "target op";
217 return diag;
218 }
219 if (failed(replacement)) {
220 DiagnosedSilenceableFailure diag =
221 emitSilenceableError() << "could not make target op loop-independent";
222 diag.attachNote(target->getLoc()) << "target op";
223 return diag;
224 }
225 rewriter.replaceOp(target, *replacement);
226 results.push_back(replacement->getDefiningOp());
227 return DiagnosedSilenceableFailure::success();
228}
229
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

234namespace {
235class TensorTransformDialectExtension
236 : public transform::TransformDialectExtension<
237 TensorTransformDialectExtension> {
238public:
239 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TensorTransformDialectExtension)
240
241 using Base::Base;
242
243 void init() {
244 declareGeneratedDialect<affine::AffineDialect>();
245 declareGeneratedDialect<tensor::TensorDialect>();
246
247 registerTransformOps<
248#define GET_OP_LIST
249#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc"
250 >();
251 }
252};
253} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc"

258void mlir::tensor::registerTransformDialectExtension(
259 DialectRegistry &registry) {
260 registry.addExtensions<TensorTransformDialectExtension>();
261}