1 | //===- TestTilingInterfaceTransformOps.cpp - Test `TilingInterface` ------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines transform dialect operations used for testing |
10 | // TilingInterface |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
15 | #include "mlir/Dialect/Index/IR/IndexDialect.h" |
16 | #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" |
17 | #include "mlir/Dialect/Transform/IR/TransformAttrs.h" |
18 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
19 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
20 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
21 | #include "mlir/IR/Dominance.h" |
22 | #include "mlir/IR/OpImplementation.h" |
23 | #include "mlir/Interfaces/TilingInterface.h" |
24 | |
25 | #define GET_OP_CLASSES |
26 | #include "TestTilingInterfaceTransformOps.h.inc" |
27 | |
28 | using namespace mlir; |
29 | using namespace mlir::transform; |
30 | |
31 | //===----------------------------------------------------------------------===// |
32 | // TestFuseAndYieldOp |
33 | //===----------------------------------------------------------------------===// |
34 | |
35 | static llvm::SmallDenseSet<Operation *> collectTiledAndFusedOps(Operation *op) { |
36 | SmallVector<Operation *> worklist; |
37 | llvm::SmallDenseSet<Operation *> producers; |
38 | worklist.push_back(op); |
39 | producers.insert(op); |
40 | while (!worklist.empty()) { |
41 | Operation *current = worklist.pop_back_val(); |
42 | for (OpOperand &operand : current->getOpOperands()) { |
43 | Operation *producer = operand.get().getDefiningOp(); |
44 | if (!producer || !isa<TilingInterface>(producer) || |
45 | producers.contains(producer)) |
46 | continue; |
47 | worklist.push_back(producer); |
48 | producers.insert(producer); |
49 | } |
50 | } |
51 | return producers; |
52 | } |
53 | |
54 | /// Apply a tile and fuse transformation to all payload ops and store both the |
55 | /// tiled operation as well as the created tile loops. |
56 | template <typename Range> |
57 | static LogicalResult |
58 | applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp, |
59 | Range &&payloadOps, unsigned numLoops, |
60 | ArrayRef<OpFoldResult> tileSizes, |
61 | ArrayRef<int64_t> interchange, bool useForall, |
62 | TransformResults &transformResults) { |
63 | SmallVector<Operation *> tiledOps; |
64 | SmallVector<SmallVector<Operation *>> loopOps(numLoops); |
65 | |
66 | for (Operation *target : payloadOps) { |
67 | auto tilingInterfaceOp = dyn_cast<TilingInterface>(target); |
68 | if (!tilingInterfaceOp) |
69 | return transformOp->emitError(message: "only TilingInterface ops are supported" ); |
70 | DominanceInfo dominanceInfo(tilingInterfaceOp); |
71 | |
72 | llvm::SmallDenseSet<Operation *> tiledAndFusedOps = |
73 | collectTiledAndFusedOps(tilingInterfaceOp); |
74 | llvm::DenseSet<Operation *> yieldReplacementsFor; |
75 | for (auto op : tiledAndFusedOps) { |
76 | if (llvm::any_of(op->getUsers(), [&](Operation *user) { |
77 | return dominanceInfo.properlyDominates(tilingInterfaceOp, user); |
78 | })) { |
79 | yieldReplacementsFor.insert(op); |
80 | } |
81 | } |
82 | |
83 | scf::SCFTilingOptions tilingOptions; |
84 | tilingOptions.setTileSizes(tileSizes).setInterchange(interchange); |
85 | if (useForall) { |
86 | tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); |
87 | } |
88 | |
89 | scf::SCFTileAndFuseOptions tileAndFuseOptions; |
90 | tileAndFuseOptions.setTilingOptions(tilingOptions); |
91 | |
92 | scf::SCFTileAndFuseOptions::ControlFnTy controlFn = |
93 | [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer, |
94 | bool isDestinationOperand) |
95 | -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> { |
96 | Operation *owner = originalProducer.getOwner(); |
97 | bool yieldProducerReplacement = yieldReplacementsFor.contains(owner); |
98 | return scf::SCFTileAndFuseOptions::ControlFnResult{ |
99 | yieldProducerReplacement}; |
100 | }; |
101 | tileAndFuseOptions.setFusionControlFn(controlFn); |
102 | |
103 | rewriter.setInsertionPoint(target); |
104 | FailureOr<scf::SCFTileAndFuseResult> tiledResults = |
105 | scf::tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp, |
106 | tileAndFuseOptions); |
107 | if (failed(tiledResults)) |
108 | return failure(); |
109 | |
110 | // Perform the replacement of tiled and fused values. |
111 | SmallVector<Operation *> opsToReplace{target}; |
112 | llvm::append_range(opsToReplace, tiledResults->fusedProducers); |
113 | for (Operation *toReplace : opsToReplace) { |
114 | for (OpResult res : toReplace->getResults()) |
115 | if (auto replacement = tiledResults->replacements.lookup(res)) { |
116 | Operation *replacementOp = replacement.getDefiningOp(); |
117 | rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) { |
118 | Operation *user = use.getOwner(); |
119 | return dominanceInfo.properlyDominates(a: replacementOp, b: user) && |
120 | user->getParentOp() == replacementOp->getParentOp(); |
121 | }); |
122 | } |
123 | |
124 | if (toReplace->use_empty()) { |
125 | rewriter.eraseOp(op: toReplace); |
126 | } |
127 | } |
128 | |
129 | // Report back the relevant handles to the transform op. |
130 | tiledOps.push_back(Elt: tiledResults->tiledAndFusedOps.front()); |
131 | assert(tiledResults->loops.size() == numLoops && |
132 | "Mismatched number of loops, tile and fuse transform should have " |
133 | "failed" ); |
134 | for (unsigned int i = 0; i < numLoops; ++i) |
135 | loopOps[i].push_back(Elt: tiledResults->loops[i]); |
136 | } |
137 | |
138 | transformResults.set(value: transformOp->getOpResult(idx: 0), ops&: tiledOps); |
139 | for (unsigned int i = 0; i < numLoops; ++i) |
140 | transformResults.set(value: transformOp->getOpResult(idx: i + 1), ops&: loopOps[i]); |
141 | |
142 | return success(); |
143 | } |
144 | |
145 | DiagnosedSilenceableFailure |
146 | transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter, |
147 | TransformResults &transformResults, |
148 | TransformState &state) { |
149 | SmallVector<int64_t> tileSizes = |
150 | extractFromIntegerArrayAttr<int64_t>(getTileSizes()); |
151 | SmallVector<int64_t> tileInterchange = |
152 | extractFromIntegerArrayAttr<int64_t>(getTileInterchange()); |
153 | |
154 | SmallVector<OpFoldResult> tileSizesOfr = |
155 | getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); |
156 | |
157 | LogicalResult result = applyTileAndFuseToAll( |
158 | rewriter, getOperation(), state.getPayloadOps(getTarget()), |
159 | tileSizes.size() - llvm::count(tileSizes, 0), tileSizesOfr, |
160 | tileInterchange, getUseForall(), transformResults); |
161 | return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() |
162 | : DiagnosedSilenceableFailure::success(); |
163 | } |
164 | |
165 | //===----------------------------------------------------------------------===// |
166 | // TestFuseConsumerOp |
167 | //===----------------------------------------------------------------------===// |
168 | |
169 | /// Apply fusing of consumer transformation to all payload ops and store both |
170 | /// the original consumer operation as well as the fused consumer operation. |
171 | template <typename Range> |
172 | static LogicalResult applyFuseConsumer( |
173 | RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, |
174 | MutableArrayRef<LoopLikeOpInterface> loops, uint32_t numConsumerToFuse, |
175 | TransformResults &transformResults) { |
176 | SmallVector<Operation *> originalConsumerOps; |
177 | SmallVector<Operation *> fusedConsumerOps; |
178 | |
179 | for (Operation *target : payloadOps) { |
180 | rewriter.setInsertionPoint(target); |
181 | |
182 | while (numConsumerToFuse--) { |
183 | FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults = |
184 | scf::tileAndFuseConsumerOfSlice(rewriter, target, loops); |
185 | |
186 | if (failed(Result: fuseConsumerResults)) |
187 | return failure(); |
188 | |
189 | // Report back the relevant handles to the transform op. |
190 | originalConsumerOps.push_back( |
191 | Elt: fuseConsumerResults->origConsumerOperand->getOwner()); |
192 | fusedConsumerOps.push_back( |
193 | Elt: fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner()); |
194 | } |
195 | } |
196 | |
197 | transformResults.set(value: transformOp->getOpResult(idx: 0), ops&: originalConsumerOps); |
198 | transformResults.set(value: transformOp->getOpResult(idx: 1), ops&: fusedConsumerOps); |
199 | return success(); |
200 | } |
201 | |
202 | DiagnosedSilenceableFailure |
203 | transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter, |
204 | TransformResults &transformResults, |
205 | TransformState &state) { |
206 | SmallVector<LoopLikeOpInterface> loops; |
207 | for (auto op : llvm::reverse(getLoops())) { |
208 | auto loopLikeOp = |
209 | dyn_cast<LoopLikeOpInterface>(*state.getPayloadOps(op).begin()); |
210 | if (!loopLikeOp) { |
211 | return DiagnosedSilenceableFailure::definiteFailure(); |
212 | } |
213 | loops.push_back(loopLikeOp); |
214 | } |
215 | LogicalResult result = applyFuseConsumer( |
216 | rewriter, getOperation(), state.getPayloadOps(getTarget()), loops, |
217 | getNumConsumerToFuse(), transformResults); |
218 | return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() |
219 | : DiagnosedSilenceableFailure::success(); |
220 | } |
221 | |
222 | void transform::TestFuseConsumerOp::getEffects( |
223 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
224 | consumesHandle(getTargetMutable(), effects); |
225 | consumesHandle(getLoopsMutable(), effects); |
226 | producesHandle(getOperation()->getOpResults(), effects); |
227 | modifiesPayload(effects); |
228 | } |
229 | |
230 | //===----------------------------------------------------------------------===// |
231 | // TestTileUsingForallOp |
232 | //===----------------------------------------------------------------------===// |
233 | |
234 | /// Apply a tiling transformation to all payload ops and store both the |
235 | /// tiled operation as well as the created tile loops. |
236 | template <typename Range> |
237 | static LogicalResult |
238 | applyTileToAll(RewriterBase &rewriter, Operation *transformOp, |
239 | Range &&payloadOps, ArrayRef<OpFoldResult> tileSizes, |
240 | ArrayRef<int64_t> interchange, std::optional<ArrayAttr> mapping, |
241 | TransformResults &transformResults) { |
242 | SmallVector<Operation *> tiledOps; |
243 | SmallVector<Operation *> loopOps; |
244 | |
245 | for (Operation *target : payloadOps) { |
246 | auto tilingInterfaceOp = dyn_cast<TilingInterface>(target); |
247 | if (!tilingInterfaceOp) |
248 | return transformOp->emitError(message: "only TilingInterface ops are supported" ); |
249 | scf::SCFTilingOptions tilingOptions; |
250 | tilingOptions.setTileSizes(tileSizes).setInterchange(interchange); |
251 | if (mapping) { |
252 | tilingOptions.setMapping(mapping.value().getValue()); |
253 | } |
254 | tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); |
255 | |
256 | rewriter.setInsertionPoint(target); |
257 | FailureOr<scf::SCFTilingResult> tiledResults = |
258 | scf::tileUsingSCF(rewriter, tilingInterfaceOp, tilingOptions); |
259 | if (failed(tiledResults)) |
260 | return failure(); |
261 | |
262 | // Perform the replacement of tiled and fused values. |
263 | rewriter.replaceOp(tilingInterfaceOp, tiledResults->replacements); |
264 | |
265 | // Report back the relevant handles to the transform op. |
266 | tiledOps.push_back(Elt: tiledResults->tiledOps.front()); |
267 | for (Operation *loop : tiledResults->loops) |
268 | loopOps.push_back(loop); |
269 | } |
270 | |
271 | transformResults.set(value: transformOp->getOpResult(idx: 0), ops&: tiledOps); |
272 | for (auto [index, loop] : llvm::enumerate(First&: loopOps)) |
273 | transformResults.set(value: transformOp->getOpResult(idx: index + 1), ops: {loop}); |
274 | |
275 | return success(); |
276 | } |
277 | |
278 | DiagnosedSilenceableFailure |
279 | transform::TestTileUsingForallOp::apply(TransformRewriter &rewriter, |
280 | TransformResults &transformResults, |
281 | TransformState &state) { |
282 | SmallVector<int64_t> tileSizes = |
283 | extractFromIntegerArrayAttr<int64_t>(getTileSizes()); |
284 | SmallVector<int64_t> interchange = |
285 | extractFromIntegerArrayAttr<int64_t>(getInterchange()); |
286 | SmallVector<OpFoldResult> tileSizesOfr = |
287 | getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); |
288 | |
289 | LogicalResult result = |
290 | applyTileToAll(rewriter, getOperation(), state.getPayloadOps(getTarget()), |
291 | tileSizesOfr, interchange, getMapping(), transformResults); |
292 | return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() |
293 | : DiagnosedSilenceableFailure::success(); |
294 | } |
295 | |
296 | void transform::TestTileUsingForallOp::getEffects( |
297 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
298 | consumesHandle(getTargetMutable(), effects); |
299 | producesHandle(getOperation()->getOpResults(), effects); |
300 | modifiesPayload(effects); |
301 | } |
302 | |
303 | //===----------------------------------------------------------------------===// |
304 | // TestFuseUsingForallOp |
305 | //===----------------------------------------------------------------------===// |
306 | |
307 | /// Apply a tiling transformation to all payload ops and store both the |
308 | /// tiled operation as well as the created tile loops. |
309 | template <typename Range> |
310 | static LogicalResult applyTilingToAll( |
311 | RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, |
312 | unsigned numLoops, TransformResults &transformResults, |
313 | function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)> |
314 | applyFn) { |
315 | SmallVector<Operation *> tiledLinalgOps; |
316 | SmallVector<SmallVector<Operation *>> loopOps(1); |
317 | |
318 | for (Operation *target : payloadOps) { |
319 | auto tilingInterfaceOp = dyn_cast<TilingInterface>(target); |
320 | if (!tilingInterfaceOp) |
321 | return transformOp->emitError(message: "only TilingInterface ops are supported" ); |
322 | |
323 | rewriter.setInsertionPoint(target); |
324 | FailureOr<scf::SCFTileAndFuseResult> tiledResults = |
325 | applyFn(tilingInterfaceOp); |
326 | if (failed(Result: tiledResults)) |
327 | return failure(); |
328 | |
329 | // Perform the replacement of tiled and fused values. |
330 | SmallVector<Operation *> opsToReplace{target}; |
331 | llvm::append_range(opsToReplace, tiledResults->fusedProducers); |
332 | for (Operation *toReplace : opsToReplace) { |
333 | for (OpResult res : toReplace->getResults()) |
334 | if (auto replacement = tiledResults->replacements.lookup(res)) |
335 | rewriter.replaceAllUsesWith(res, replacement); |
336 | if (toReplace->use_empty()) |
337 | rewriter.eraseOp(op: toReplace); |
338 | } |
339 | |
340 | // Report back the relevant handles to the transform op. |
341 | tiledLinalgOps.push_back(Elt: tiledResults->tiledAndFusedOps.front()); |
342 | assert(tiledResults->loops.size() == 1 && |
343 | cast<scf::ForallOp>(tiledResults->loops[0]).getRank() == numLoops && |
344 | "Mismatched number of loops, tile and fuse transform should have " |
345 | "failed" ); |
346 | loopOps[0] = {tiledResults->loops[0]}; |
347 | } |
348 | |
349 | transformResults.set(value: transformOp->getOpResult(idx: 0), ops&: tiledLinalgOps); |
350 | if (!loopOps.empty()) |
351 | transformResults.set(value: transformOp->getOpResult(idx: 1), ops&: loopOps[0]); |
352 | |
353 | return success(); |
354 | } |
355 | |
356 | DiagnosedSilenceableFailure |
357 | transform::TestFuseUsingForallOp::apply(TransformRewriter &rewriter, |
358 | TransformResults &transformResults, |
359 | TransformState &state) { |
360 | SmallVector<int64_t> tileSizes = |
361 | extractFromIntegerArrayAttr<int64_t>(getTileSizes()); |
362 | SmallVector<int64_t> tileInterchange = |
363 | extractFromIntegerArrayAttr<int64_t>(getInterchange()); |
364 | |
365 | scf::SCFTilingOptions tilingOptions; |
366 | tilingOptions.interchangeVector = tileInterchange; |
367 | SmallVector<OpFoldResult> tileSizesOfr = |
368 | getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); |
369 | tilingOptions = tilingOptions.setTileSizes(tileSizesOfr); |
370 | tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); |
371 | scf::SCFTileAndFuseOptions tileAndFuseOptions; |
372 | tileAndFuseOptions.tilingOptions = tilingOptions; |
373 | LogicalResult result = applyTilingToAll( |
374 | rewriter, getOperation(), state.getPayloadOps(getRootOp()), |
375 | tileSizes.size() - llvm::count(tileSizes, 0), transformResults, |
376 | [&](TilingInterface tilingInterfaceOp) |
377 | -> FailureOr<scf::SCFTileAndFuseResult> { |
378 | return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp, |
379 | tileAndFuseOptions); |
380 | }); |
381 | return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() |
382 | : DiagnosedSilenceableFailure::success(); |
383 | } |
384 | |
385 | void transform::TestFuseUsingForallOp::getEffects( |
386 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
387 | consumesHandle(getRootOpMutable(), effects); |
388 | producesHandle(getOperation()->getOpResults(), effects); |
389 | modifiesPayload(effects); |
390 | } |
391 | |
392 | #define GET_OP_CLASSES |
393 | #include "TestTilingInterfaceTransformOps.cpp.inc" |
394 | |
395 | namespace { |
396 | class TestTilingInterfaceDialectExtension |
397 | : public transform::TransformDialectExtension< |
398 | TestTilingInterfaceDialectExtension> { |
399 | public: |
400 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
401 | TestTilingInterfaceDialectExtension) |
402 | |
403 | using Base::Base; |
404 | |
405 | void init() { |
406 | declareDependentDialect<affine::AffineDialect>(); |
407 | declareDependentDialect<index::IndexDialect>(); |
408 | declareDependentDialect<scf::SCFDialect>(); |
409 | declareDependentDialect<tensor::TensorDialect>(); |
410 | |
411 | registerTransformOps< |
412 | #define GET_OP_LIST |
413 | #include "TestTilingInterfaceTransformOps.cpp.inc" |
414 | >(); |
415 | } |
416 | }; |
417 | } // namespace |
418 | |
419 | namespace test { |
420 | void registerTestTilingInterfaceTransformDialectExtension( |
421 | DialectRegistry ®istry) { |
422 | registry.addExtensions<TestTilingInterfaceDialectExtension>(); |
423 | } |
424 | } // namespace test |
425 | |