1//===- TensorTransformOps.cpp - Implementation of tensor transform ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
10
11#include "mlir/Dialect/Affine/IR/AffineOps.h"
12#include "mlir/Dialect/SCF/IR/SCF.h"
13#include "mlir/Dialect/Tensor/IR/Tensor.h"
14#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
15#include "mlir/Dialect/Tensor/Utils/Utils.h"
16#include "mlir/Dialect/Transform/IR/TransformDialect.h"
17#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
18#include "mlir/IR/Builders.h"
19#include "mlir/Transforms/DialectConversion.h"
20
21using namespace mlir;
22using namespace tensor;
23
24//===----------------------------------------------------------------------===//
25// FindPayloadReplacementOpInterface implementations
26//===----------------------------------------------------------------------===//
27
28namespace {
29struct ExtractSliceOpReplacementInterface
30 : public transform::FindPayloadReplacementOpInterface::ExternalModel<
31 ExtractSliceOpReplacementInterface, tensor::ExtractSliceOp> {
32 SmallVector<Value> getNextOperands(Operation *op) const {
33 auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
34 if (!isCastLikeExtractSliceOp(extractSliceOp))
35 return {};
36 return {extractSliceOp.getSource()};
37 }
38};
39
40struct InsertSliceOpReplacementInterface
41 : public transform::FindPayloadReplacementOpInterface::ExternalModel<
42 InsertSliceOpReplacementInterface, tensor::InsertSliceOp> {
43 SmallVector<Value> getNextOperands(Operation *op) const {
44 auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
45 if (!isCastLikeInsertSliceOp(insertSliceOp))
46 return {};
47 return {insertSliceOp.getSource()};
48 }
49};
50
51struct ReshapeOpReplacementInterface
52 : public transform::FindPayloadReplacementOpInterface::ExternalModel<
53 ReshapeOpReplacementInterface, tensor::ReshapeOp> {
54 SmallVector<Value> getNextOperands(Operation *op) const {
55 auto reshapeOp = cast<tensor::ReshapeOp>(op);
56 return {reshapeOp.getSource()};
57 }
58};
59
60template <typename ConcreteOp>
61struct ReassociativeReshapeOpReplacementInterface
62 : public transform::FindPayloadReplacementOpInterface::ExternalModel<
63 ReassociativeReshapeOpReplacementInterface<ConcreteOp>, ConcreteOp> {
64 SmallVector<Value> getNextOperands(Operation *op) const {
65 auto reshapeOp = cast<ConcreteOp>(op);
66 return {reshapeOp.getSrc()};
67 }
68};
69} // namespace
70
71void tensor::registerFindPayloadReplacementOpInterfaceExternalModels(
72 DialectRegistry &registry) {
73 registry.addExtension(extensionFn: +[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
74 CollapseShapeOp::attachInterface<
75 ReassociativeReshapeOpReplacementInterface<CollapseShapeOp>>(*ctx);
76 ExpandShapeOp::attachInterface<
77 ReassociativeReshapeOpReplacementInterface<ExpandShapeOp>>(*ctx);
78 ExtractSliceOp::attachInterface<ExtractSliceOpReplacementInterface>(*ctx);
79 InsertSliceOp::attachInterface<InsertSliceOpReplacementInterface>(*ctx);
80 ReshapeOp::attachInterface<ReshapeOpReplacementInterface>(*ctx);
81 });
82}
83
84//===----------------------------------------------------------------------===//
85// Apply...PatternsOp
86//===----------------------------------------------------------------------===//
87
88void transform::ApplyDecomposeTensorConcatPatternsOp::populatePatterns(
89 RewritePatternSet &patterns) {
90 tensor::populateDecomposeTensorConcatPatterns(patterns);
91}
92
93void transform::ApplyDropRedundantInsertSliceRankExpansionPatternsOp::
94 populatePatterns(RewritePatternSet &patterns) {
95 tensor::populateDropRedundantInsertSliceRankExpansionPatterns(patterns);
96}
97
98void transform::ApplyFoldTensorEmptyPatternsOp::populatePatterns(
99 RewritePatternSet &patterns) {
100 tensor::populateFoldTensorEmptyPatterns(patterns, getFoldSingleUseOnly());
101}
102
103void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
104 RewritePatternSet &patterns) {
105 tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
106}
107
108void transform::ApplyFoldTensorSubsetOpsPatternsOp::populatePatterns(
109 RewritePatternSet &patterns) {
110 tensor::populateFoldTensorSubsetOpPatterns(patterns);
111}
112
113void transform::ApplyFoldTensorSubsetOpsIntoVectorTransfersPatternsOp::
114 populatePatterns(RewritePatternSet &patterns) {
115 tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
116}
117
118void transform::ApplyMergeConsecutiveInsertExtractSlicePatternsOp::
119 populatePatterns(RewritePatternSet &patterns) {
120 tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
121}
122
123void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns(
124 RewritePatternSet &patterns) {
125 tensor::populateReassociativeReshapeFoldingPatterns(patterns);
126}
127
128void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
129 RewritePatternSet &patterns) {
130 tensor::populateRewriteAsConstantPatterns(patterns);
131}
132
133//===----------------------------------------------------------------------===//
134// TypeConversionCastTensorShapeOp
135//===----------------------------------------------------------------------===//
136
137void transform::TypeConversionCastShapeDynamicDimsOp::
138 populateTypeMaterializations(TypeConverter &converter) {
139 bool ignoreDynamicInfo = getIgnoreDynamicInfo();
140 converter.addSourceMaterialization([ignoreDynamicInfo](
141 OpBuilder &builder, Type resultType,
142 ValueRange inputs,
143 Location loc) -> std::optional<Value> {
144 if (inputs.size() != 1) {
145 return std::nullopt;
146 }
147 Value input = inputs[0];
148 if (!ignoreDynamicInfo &&
149 !tensor::preservesStaticInformation(resultType, input.getType())) {
150 return std::nullopt;
151 }
152 if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
153 return std::nullopt;
154 }
155 return builder.create<tensor::CastOp>(loc, resultType, input).getResult();
156 });
157 converter.addTargetMaterialization([](OpBuilder &builder, Type resultType,
158 ValueRange inputs,
159 Location loc) -> std::optional<Value> {
160 if (inputs.size() != 1) {
161 return std::nullopt;
162 }
163 Value input = inputs[0];
164 if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
165 return std::nullopt;
166 }
167 return builder.create<tensor::CastOp>(loc, resultType, input).getResult();
168 });
169}
170
171//===----------------------------------------------------------------------===//
172// MakeLoopIndependentOp
173//===----------------------------------------------------------------------===//
174
175DiagnosedSilenceableFailure transform::MakeLoopIndependentOp::applyToOne(
176 transform::TransformRewriter &rewriter, Operation *target,
177 transform::ApplyToEachResultList &results,
178 transform::TransformState &state) {
179 // Gather IVs.
180 SmallVector<Value> ivs;
181 Operation *nextOp = target;
182 for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) {
183 nextOp = nextOp->getParentOfType<scf::ForOp>();
184 if (!nextOp) {
185 DiagnosedSilenceableFailure diag = emitSilenceableError()
186 << "could not find " << i
187 << "-th enclosing loop";
188 diag.attachNote(target->getLoc()) << "target op";
189 return diag;
190 }
191 ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar());
192 }
193
194 // Rewrite IR.
195 FailureOr<Value> replacement = failure();
196 if (auto padOp = dyn_cast<tensor::PadOp>(target)) {
197 replacement = tensor::buildIndependentOp(rewriter, padOp, ivs);
198 } else if (auto emptyOp = dyn_cast<tensor::EmptyOp>(target)) {
199 replacement = tensor::buildIndependentOp(rewriter, emptyOp, ivs);
200 } else {
201 DiagnosedSilenceableFailure diag = emitSilenceableError()
202 << "unsupported target op";
203 diag.attachNote(target->getLoc()) << "target op";
204 return diag;
205 }
206 if (failed(replacement)) {
207 DiagnosedSilenceableFailure diag =
208 emitSilenceableError() << "could not make target op loop-independent";
209 diag.attachNote(target->getLoc()) << "target op";
210 return diag;
211 }
212 rewriter.replaceOp(target, *replacement);
213 results.push_back(replacement->getDefiningOp());
214 return DiagnosedSilenceableFailure::success();
215}
216
217//===----------------------------------------------------------------------===//
218// Transform op registration
219//===----------------------------------------------------------------------===//
220
221namespace {
222class TensorTransformDialectExtension
223 : public transform::TransformDialectExtension<
224 TensorTransformDialectExtension> {
225public:
226 using Base::Base;
227
228 void init() {
229 declareGeneratedDialect<affine::AffineDialect>();
230 declareGeneratedDialect<tensor::TensorDialect>();
231
232 registerTransformOps<
233#define GET_OP_LIST
234#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc"
235 >();
236 }
237};
238} // namespace
239
240#define GET_OP_CLASSES
241#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc"
242
243void mlir::tensor::registerTransformDialectExtension(
244 DialectRegistry &registry) {
245 registry.addExtensions<TensorTransformDialectExtension>();
246}
247

// source code of mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp