1 | //===- TensorTransformOps.cpp - Implementation of tensor transform ops ----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" |
10 | |
11 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
12 | #include "mlir/Dialect/SCF/IR/SCF.h" |
13 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
14 | #include "mlir/Dialect/Tensor/Transforms/Transforms.h" |
15 | #include "mlir/Dialect/Tensor/Utils/Utils.h" |
16 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
17 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
18 | #include "mlir/IR/Builders.h" |
19 | #include "mlir/Transforms/DialectConversion.h" |
20 | |
21 | using namespace mlir; |
22 | using namespace tensor; |
23 | |
24 | //===----------------------------------------------------------------------===// |
25 | // FindPayloadReplacementOpInterface implementations |
26 | //===----------------------------------------------------------------------===// |
27 | |
28 | namespace { |
29 | struct |
30 | : public transform::FindPayloadReplacementOpInterface::ExternalModel< |
31 | ExtractSliceOpReplacementInterface, tensor::ExtractSliceOp> { |
32 | SmallVector<Value> getNextOperands(Operation *op) const { |
33 | auto = cast<tensor::ExtractSliceOp>(op); |
34 | if (!isCastLikeExtractSliceOp(extractSliceOp)) |
35 | return {}; |
36 | return {extractSliceOp.getSource()}; |
37 | } |
38 | }; |
39 | |
40 | struct InsertSliceOpReplacementInterface |
41 | : public transform::FindPayloadReplacementOpInterface::ExternalModel< |
42 | InsertSliceOpReplacementInterface, tensor::InsertSliceOp> { |
43 | SmallVector<Value> getNextOperands(Operation *op) const { |
44 | auto insertSliceOp = cast<tensor::InsertSliceOp>(op); |
45 | if (!isCastLikeInsertSliceOp(insertSliceOp)) |
46 | return {}; |
47 | return {insertSliceOp.getSource()}; |
48 | } |
49 | }; |
50 | |
51 | struct ReshapeOpReplacementInterface |
52 | : public transform::FindPayloadReplacementOpInterface::ExternalModel< |
53 | ReshapeOpReplacementInterface, tensor::ReshapeOp> { |
54 | SmallVector<Value> getNextOperands(Operation *op) const { |
55 | auto reshapeOp = cast<tensor::ReshapeOp>(op); |
56 | return {reshapeOp.getSource()}; |
57 | } |
58 | }; |
59 | |
60 | template <typename ConcreteOp> |
61 | struct ReassociativeReshapeOpReplacementInterface |
62 | : public transform::FindPayloadReplacementOpInterface::ExternalModel< |
63 | ReassociativeReshapeOpReplacementInterface<ConcreteOp>, ConcreteOp> { |
64 | SmallVector<Value> getNextOperands(Operation *op) const { |
65 | auto reshapeOp = cast<ConcreteOp>(op); |
66 | return {reshapeOp.getSrc()}; |
67 | } |
68 | }; |
69 | } // namespace |
70 | |
71 | void tensor::registerFindPayloadReplacementOpInterfaceExternalModels( |
72 | DialectRegistry ®istry) { |
73 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, tensor::TensorDialect *dialect) { |
74 | CollapseShapeOp::attachInterface< |
75 | ReassociativeReshapeOpReplacementInterface<CollapseShapeOp>>(*ctx); |
76 | ExpandShapeOp::attachInterface< |
77 | ReassociativeReshapeOpReplacementInterface<ExpandShapeOp>>(*ctx); |
78 | ExtractSliceOp::attachInterface<ExtractSliceOpReplacementInterface>(*ctx); |
79 | InsertSliceOp::attachInterface<InsertSliceOpReplacementInterface>(*ctx); |
80 | ReshapeOp::attachInterface<ReshapeOpReplacementInterface>(*ctx); |
81 | }); |
82 | } |
83 | |
84 | //===----------------------------------------------------------------------===// |
85 | // Apply...PatternsOp |
86 | //===----------------------------------------------------------------------===// |
87 | |
88 | void transform::ApplyDecomposeTensorConcatPatternsOp::populatePatterns( |
89 | RewritePatternSet &patterns) { |
90 | tensor::populateDecomposeTensorConcatPatterns(patterns); |
91 | } |
92 | |
93 | void transform::ApplyDropRedundantInsertSliceRankExpansionPatternsOp:: |
94 | populatePatterns(RewritePatternSet &patterns) { |
95 | tensor::populateDropRedundantInsertSliceRankExpansionPatterns(patterns); |
96 | } |
97 | |
98 | void transform::ApplyFoldTensorEmptyPatternsOp::populatePatterns( |
99 | RewritePatternSet &patterns) { |
100 | tensor::populateFoldTensorEmptyPatterns(patterns, getFoldSingleUseOnly()); |
101 | } |
102 | |
103 | void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns( |
104 | RewritePatternSet &patterns) { |
105 | tensor::populateFoldIntoPackAndUnpackPatterns(patterns); |
106 | } |
107 | |
108 | void transform::ApplyFoldTensorSubsetOpsPatternsOp::populatePatterns( |
109 | RewritePatternSet &patterns) { |
110 | tensor::populateFoldTensorSubsetOpPatterns(patterns); |
111 | } |
112 | |
113 | void transform::ApplyFoldTensorSubsetOpsIntoVectorTransfersPatternsOp:: |
114 | populatePatterns(RewritePatternSet &patterns) { |
115 | tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns); |
116 | } |
117 | |
118 | void transform::ApplyMergeConsecutiveInsertExtractSlicePatternsOp:: |
119 | populatePatterns(RewritePatternSet &patterns) { |
120 | tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); |
121 | } |
122 | |
123 | void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns( |
124 | RewritePatternSet &patterns) { |
125 | tensor::populateReassociativeReshapeFoldingPatterns(patterns); |
126 | } |
127 | |
128 | void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns( |
129 | RewritePatternSet &patterns) { |
130 | tensor::populateRewriteAsConstantPatterns(patterns); |
131 | } |
132 | |
133 | //===----------------------------------------------------------------------===// |
// TypeConversionCastShapeDynamicDimsOp
135 | //===----------------------------------------------------------------------===// |
136 | |
137 | void transform::TypeConversionCastShapeDynamicDimsOp:: |
138 | populateTypeMaterializations(TypeConverter &converter) { |
139 | bool ignoreDynamicInfo = getIgnoreDynamicInfo(); |
140 | converter.addSourceMaterialization([ignoreDynamicInfo]( |
141 | OpBuilder &builder, Type resultType, |
142 | ValueRange inputs, |
143 | Location loc) -> std::optional<Value> { |
144 | if (inputs.size() != 1) { |
145 | return std::nullopt; |
146 | } |
147 | Value input = inputs[0]; |
148 | if (!ignoreDynamicInfo && |
149 | !tensor::preservesStaticInformation(resultType, input.getType())) { |
150 | return std::nullopt; |
151 | } |
152 | if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) { |
153 | return std::nullopt; |
154 | } |
155 | return builder.create<tensor::CastOp>(loc, resultType, input).getResult(); |
156 | }); |
157 | converter.addTargetMaterialization([](OpBuilder &builder, Type resultType, |
158 | ValueRange inputs, |
159 | Location loc) -> std::optional<Value> { |
160 | if (inputs.size() != 1) { |
161 | return std::nullopt; |
162 | } |
163 | Value input = inputs[0]; |
164 | if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) { |
165 | return std::nullopt; |
166 | } |
167 | return builder.create<tensor::CastOp>(loc, resultType, input).getResult(); |
168 | }); |
169 | } |
170 | |
171 | //===----------------------------------------------------------------------===// |
172 | // MakeLoopIndependentOp |
173 | //===----------------------------------------------------------------------===// |
174 | |
175 | DiagnosedSilenceableFailure transform::MakeLoopIndependentOp::applyToOne( |
176 | transform::TransformRewriter &rewriter, Operation *target, |
177 | transform::ApplyToEachResultList &results, |
178 | transform::TransformState &state) { |
179 | // Gather IVs. |
180 | SmallVector<Value> ivs; |
181 | Operation *nextOp = target; |
182 | for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) { |
183 | nextOp = nextOp->getParentOfType<scf::ForOp>(); |
184 | if (!nextOp) { |
185 | DiagnosedSilenceableFailure diag = emitSilenceableError() |
186 | << "could not find " << i |
187 | << "-th enclosing loop" ; |
188 | diag.attachNote(target->getLoc()) << "target op" ; |
189 | return diag; |
190 | } |
191 | ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar()); |
192 | } |
193 | |
194 | // Rewrite IR. |
195 | FailureOr<Value> replacement = failure(); |
196 | if (auto padOp = dyn_cast<tensor::PadOp>(target)) { |
197 | replacement = tensor::buildIndependentOp(rewriter, padOp, ivs); |
198 | } else if (auto emptyOp = dyn_cast<tensor::EmptyOp>(target)) { |
199 | replacement = tensor::buildIndependentOp(rewriter, emptyOp, ivs); |
200 | } else { |
201 | DiagnosedSilenceableFailure diag = emitSilenceableError() |
202 | << "unsupported target op" ; |
203 | diag.attachNote(target->getLoc()) << "target op" ; |
204 | return diag; |
205 | } |
206 | if (failed(replacement)) { |
207 | DiagnosedSilenceableFailure diag = |
208 | emitSilenceableError() << "could not make target op loop-independent" ; |
209 | diag.attachNote(target->getLoc()) << "target op" ; |
210 | return diag; |
211 | } |
212 | rewriter.replaceOp(target, *replacement); |
213 | results.push_back(replacement->getDefiningOp()); |
214 | return DiagnosedSilenceableFailure::success(); |
215 | } |
216 | |
217 | //===----------------------------------------------------------------------===// |
218 | // Transform op registration |
219 | //===----------------------------------------------------------------------===// |
220 | |
221 | namespace { |
222 | class TensorTransformDialectExtension |
223 | : public transform::TransformDialectExtension< |
224 | TensorTransformDialectExtension> { |
225 | public: |
226 | using Base::Base; |
227 | |
228 | void init() { |
229 | declareGeneratedDialect<affine::AffineDialect>(); |
230 | declareGeneratedDialect<tensor::TensorDialect>(); |
231 | |
232 | registerTransformOps< |
233 | #define GET_OP_LIST |
234 | #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc" |
235 | >(); |
236 | } |
237 | }; |
238 | } // namespace |
239 | |
240 | #define GET_OP_CLASSES |
241 | #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc" |
242 | |
243 | void mlir::tensor::registerTransformDialectExtension( |
244 | DialectRegistry ®istry) { |
245 | registry.addExtensions<TensorTransformDialectExtension>(); |
246 | } |
247 | |