1//===- TestTilingInterfaceTransformOps.cpp - Test `TilingInterface` ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file defines transform dialect operations used for testing the
// TilingInterface.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Affine/IR/AffineOps.h"
15#include "mlir/Dialect/Index/IR/IndexDialect.h"
16#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
17#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
18#include "mlir/Dialect/Transform/IR/TransformDialect.h"
19#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
20#include "mlir/Dialect/Utils/StaticValueUtils.h"
21#include "mlir/IR/Dominance.h"
22#include "mlir/IR/OpImplementation.h"
23#include "mlir/Interfaces/TilingInterface.h"
24
25#define GET_OP_CLASSES
26#include "TestTilingInterfaceTransformOps.h.inc"
27
28using namespace mlir;
29using namespace mlir::transform;
30
31//===----------------------------------------------------------------------===//
32// TestFuseAndYieldOp
33//===----------------------------------------------------------------------===//
34
35static llvm::SmallDenseSet<Operation *> collectTiledAndFusedOps(Operation *op) {
36 SmallVector<Operation *> worklist;
37 llvm::SmallDenseSet<Operation *> producers;
38 worklist.push_back(op);
39 producers.insert(op);
40 while (!worklist.empty()) {
41 Operation *current = worklist.pop_back_val();
42 for (OpOperand &operand : current->getOpOperands()) {
43 Operation *producer = operand.get().getDefiningOp();
44 if (!producer || !isa<TilingInterface>(producer) ||
45 producers.contains(producer))
46 continue;
47 worklist.push_back(producer);
48 producers.insert(producer);
49 }
50 }
51 return producers;
52}
53
54/// Apply a tile and fuse transformation to all payload ops and store both the
55/// tiled operation as well as the created tile loops.
56template <typename Range>
57static LogicalResult
58applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
59 Range &&payloadOps, unsigned numLoops,
60 ArrayRef<OpFoldResult> tileSizes,
61 ArrayRef<int64_t> interchange, bool useForall,
62 TransformResults &transformResults) {
63 SmallVector<Operation *> tiledOps;
64 SmallVector<SmallVector<Operation *>> loopOps(numLoops);
65
66 for (Operation *target : payloadOps) {
67 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
68 if (!tilingInterfaceOp)
69 return transformOp->emitError(message: "only TilingInterface ops are supported");
70 DominanceInfo dominanceInfo(tilingInterfaceOp);
71
72 llvm::SmallDenseSet<Operation *> tiledAndFusedOps =
73 collectTiledAndFusedOps(tilingInterfaceOp);
74 llvm::DenseSet<Operation *> yieldReplacementsFor;
75 for (auto op : tiledAndFusedOps) {
76 if (llvm::any_of(op->getUsers(), [&](Operation *user) {
77 return dominanceInfo.properlyDominates(tilingInterfaceOp, user);
78 })) {
79 yieldReplacementsFor.insert(op);
80 }
81 }
82
83 scf::SCFTilingOptions tilingOptions;
84 tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
85 if (useForall) {
86 tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
87 }
88
89 scf::SCFTileAndFuseOptions tileAndFuseOptions;
90 tileAndFuseOptions.setTilingOptions(tilingOptions);
91
92 scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
93 [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
94 bool isDestinationOperand)
95 -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
96 Operation *owner = originalProducer.getOwner();
97 bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
98 return scf::SCFTileAndFuseOptions::ControlFnResult{
99 yieldProducerReplacement};
100 };
101 tileAndFuseOptions.setFusionControlFn(controlFn);
102
103 rewriter.setInsertionPoint(target);
104 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
105 scf::tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
106 tileAndFuseOptions);
107 if (failed(tiledResults))
108 return failure();
109
110 // Perform the replacement of tiled and fused values.
111 SmallVector<Operation *> opsToReplace{target};
112 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
113 for (Operation *toReplace : opsToReplace) {
114 for (OpResult res : toReplace->getResults())
115 if (auto replacement = tiledResults->replacements.lookup(res)) {
116 Operation *replacementOp = replacement.getDefiningOp();
117 rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) {
118 Operation *user = use.getOwner();
119 return dominanceInfo.properlyDominates(a: replacementOp, b: user) &&
120 user->getParentOp() == replacementOp->getParentOp();
121 });
122 }
123
124 if (toReplace->use_empty()) {
125 rewriter.eraseOp(op: toReplace);
126 }
127 }
128
129 // Report back the relevant handles to the transform op.
130 tiledOps.push_back(Elt: tiledResults->tiledAndFusedOps.front());
131 assert(tiledResults->loops.size() == numLoops &&
132 "Mismatched number of loops, tile and fuse transform should have "
133 "failed");
134 for (unsigned int i = 0; i < numLoops; ++i)
135 loopOps[i].push_back(Elt: tiledResults->loops[i]);
136 }
137
138 transformResults.set(value: transformOp->getOpResult(idx: 0), ops&: tiledOps);
139 for (unsigned int i = 0; i < numLoops; ++i)
140 transformResults.set(value: transformOp->getOpResult(idx: i + 1), ops&: loopOps[i]);
141
142 return success();
143}
144
145DiagnosedSilenceableFailure
146transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
147 TransformResults &transformResults,
148 TransformState &state) {
149 SmallVector<int64_t> tileSizes =
150 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
151 SmallVector<int64_t> tileInterchange =
152 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
153
154 SmallVector<OpFoldResult> tileSizesOfr =
155 getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
156
157 LogicalResult result = applyTileAndFuseToAll(
158 rewriter, getOperation(), state.getPayloadOps(getTarget()),
159 tileSizes.size() - llvm::count(tileSizes, 0), tileSizesOfr,
160 tileInterchange, getUseForall(), transformResults);
161 return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
162 : DiagnosedSilenceableFailure::success();
163}
164
165//===----------------------------------------------------------------------===//
166// TestFuseConsumerOp
167//===----------------------------------------------------------------------===//
168
169/// Apply fusing of consumer transformation to all payload ops and store both
170/// the original consumer operation as well as the fused consumer operation.
171template <typename Range>
172static LogicalResult applyFuseConsumer(
173 RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
174 MutableArrayRef<LoopLikeOpInterface> loops, uint32_t numConsumerToFuse,
175 TransformResults &transformResults) {
176 SmallVector<Operation *> originalConsumerOps;
177 SmallVector<Operation *> fusedConsumerOps;
178
179 for (Operation *target : payloadOps) {
180 rewriter.setInsertionPoint(target);
181
182 while (numConsumerToFuse--) {
183 FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
184 scf::tileAndFuseConsumerOfSlice(rewriter, target, loops);
185
186 if (failed(Result: fuseConsumerResults))
187 return failure();
188
189 // Report back the relevant handles to the transform op.
190 originalConsumerOps.push_back(
191 Elt: fuseConsumerResults->origConsumerOperand->getOwner());
192 fusedConsumerOps.push_back(
193 Elt: fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner());
194 }
195 }
196
197 transformResults.set(value: transformOp->getOpResult(idx: 0), ops&: originalConsumerOps);
198 transformResults.set(value: transformOp->getOpResult(idx: 1), ops&: fusedConsumerOps);
199 return success();
200}
201
202DiagnosedSilenceableFailure
203transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
204 TransformResults &transformResults,
205 TransformState &state) {
206 SmallVector<LoopLikeOpInterface> loops;
207 for (auto op : llvm::reverse(getLoops())) {
208 auto loopLikeOp =
209 dyn_cast<LoopLikeOpInterface>(*state.getPayloadOps(op).begin());
210 if (!loopLikeOp) {
211 return DiagnosedSilenceableFailure::definiteFailure();
212 }
213 loops.push_back(loopLikeOp);
214 }
215 LogicalResult result = applyFuseConsumer(
216 rewriter, getOperation(), state.getPayloadOps(getTarget()), loops,
217 getNumConsumerToFuse(), transformResults);
218 return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
219 : DiagnosedSilenceableFailure::success();
220}
221
222void transform::TestFuseConsumerOp::getEffects(
223 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
224 consumesHandle(getTargetMutable(), effects);
225 consumesHandle(getLoopsMutable(), effects);
226 producesHandle(getOperation()->getOpResults(), effects);
227 modifiesPayload(effects);
228}
229
230//===----------------------------------------------------------------------===//
231// TestTileUsingForallOp
232//===----------------------------------------------------------------------===//
233
234/// Apply a tiling transformation to all payload ops and store both the
235/// tiled operation as well as the created tile loops.
236template <typename Range>
237static LogicalResult
238applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
239 Range &&payloadOps, ArrayRef<OpFoldResult> tileSizes,
240 ArrayRef<int64_t> interchange, std::optional<ArrayAttr> mapping,
241 TransformResults &transformResults) {
242 SmallVector<Operation *> tiledOps;
243 SmallVector<Operation *> loopOps;
244
245 for (Operation *target : payloadOps) {
246 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
247 if (!tilingInterfaceOp)
248 return transformOp->emitError(message: "only TilingInterface ops are supported");
249 scf::SCFTilingOptions tilingOptions;
250 tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
251 if (mapping) {
252 tilingOptions.setMapping(mapping.value().getValue());
253 }
254 tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
255
256 rewriter.setInsertionPoint(target);
257 FailureOr<scf::SCFTilingResult> tiledResults =
258 scf::tileUsingSCF(rewriter, tilingInterfaceOp, tilingOptions);
259 if (failed(tiledResults))
260 return failure();
261
262 // Perform the replacement of tiled and fused values.
263 rewriter.replaceOp(tilingInterfaceOp, tiledResults->replacements);
264
265 // Report back the relevant handles to the transform op.
266 tiledOps.push_back(Elt: tiledResults->tiledOps.front());
267 for (Operation *loop : tiledResults->loops)
268 loopOps.push_back(loop);
269 }
270
271 transformResults.set(value: transformOp->getOpResult(idx: 0), ops&: tiledOps);
272 for (auto [index, loop] : llvm::enumerate(First&: loopOps))
273 transformResults.set(value: transformOp->getOpResult(idx: index + 1), ops: {loop});
274
275 return success();
276}
277
278DiagnosedSilenceableFailure
279transform::TestTileUsingForallOp::apply(TransformRewriter &rewriter,
280 TransformResults &transformResults,
281 TransformState &state) {
282 SmallVector<int64_t> tileSizes =
283 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
284 SmallVector<int64_t> interchange =
285 extractFromIntegerArrayAttr<int64_t>(getInterchange());
286 SmallVector<OpFoldResult> tileSizesOfr =
287 getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
288
289 LogicalResult result =
290 applyTileToAll(rewriter, getOperation(), state.getPayloadOps(getTarget()),
291 tileSizesOfr, interchange, getMapping(), transformResults);
292 return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
293 : DiagnosedSilenceableFailure::success();
294}
295
296void transform::TestTileUsingForallOp::getEffects(
297 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
298 consumesHandle(getTargetMutable(), effects);
299 producesHandle(getOperation()->getOpResults(), effects);
300 modifiesPayload(effects);
301}
302
303//===----------------------------------------------------------------------===//
304// TestFuseUsingForallOp
305//===----------------------------------------------------------------------===//
306
307/// Apply a tiling transformation to all payload ops and store both the
308/// tiled operation as well as the created tile loops.
309template <typename Range>
310static LogicalResult applyTilingToAll(
311 RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
312 unsigned numLoops, TransformResults &transformResults,
313 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
314 applyFn) {
315 SmallVector<Operation *> tiledLinalgOps;
316 SmallVector<SmallVector<Operation *>> loopOps(1);
317
318 for (Operation *target : payloadOps) {
319 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
320 if (!tilingInterfaceOp)
321 return transformOp->emitError(message: "only TilingInterface ops are supported");
322
323 rewriter.setInsertionPoint(target);
324 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
325 applyFn(tilingInterfaceOp);
326 if (failed(Result: tiledResults))
327 return failure();
328
329 // Perform the replacement of tiled and fused values.
330 SmallVector<Operation *> opsToReplace{target};
331 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
332 for (Operation *toReplace : opsToReplace) {
333 for (OpResult res : toReplace->getResults())
334 if (auto replacement = tiledResults->replacements.lookup(res))
335 rewriter.replaceAllUsesWith(res, replacement);
336 if (toReplace->use_empty())
337 rewriter.eraseOp(op: toReplace);
338 }
339
340 // Report back the relevant handles to the transform op.
341 tiledLinalgOps.push_back(Elt: tiledResults->tiledAndFusedOps.front());
342 assert(tiledResults->loops.size() == 1 &&
343 cast<scf::ForallOp>(tiledResults->loops[0]).getRank() == numLoops &&
344 "Mismatched number of loops, tile and fuse transform should have "
345 "failed");
346 loopOps[0] = {tiledResults->loops[0]};
347 }
348
349 transformResults.set(value: transformOp->getOpResult(idx: 0), ops&: tiledLinalgOps);
350 if (!loopOps.empty())
351 transformResults.set(value: transformOp->getOpResult(idx: 1), ops&: loopOps[0]);
352
353 return success();
354}
355
356DiagnosedSilenceableFailure
357transform::TestFuseUsingForallOp::apply(TransformRewriter &rewriter,
358 TransformResults &transformResults,
359 TransformState &state) {
360 SmallVector<int64_t> tileSizes =
361 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
362 SmallVector<int64_t> tileInterchange =
363 extractFromIntegerArrayAttr<int64_t>(getInterchange());
364
365 scf::SCFTilingOptions tilingOptions;
366 tilingOptions.interchangeVector = tileInterchange;
367 SmallVector<OpFoldResult> tileSizesOfr =
368 getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
369 tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
370 tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
371 scf::SCFTileAndFuseOptions tileAndFuseOptions;
372 tileAndFuseOptions.tilingOptions = tilingOptions;
373 LogicalResult result = applyTilingToAll(
374 rewriter, getOperation(), state.getPayloadOps(getRootOp()),
375 tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
376 [&](TilingInterface tilingInterfaceOp)
377 -> FailureOr<scf::SCFTileAndFuseResult> {
378 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
379 tileAndFuseOptions);
380 });
381 return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
382 : DiagnosedSilenceableFailure::success();
383}
384
385void transform::TestFuseUsingForallOp::getEffects(
386 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
387 consumesHandle(getRootOpMutable(), effects);
388 producesHandle(getOperation()->getOpResults(), effects);
389 modifiesPayload(effects);
390}
391
392#define GET_OP_CLASSES
393#include "TestTilingInterfaceTransformOps.cpp.inc"
394
395namespace {
396class TestTilingInterfaceDialectExtension
397 : public transform::TransformDialectExtension<
398 TestTilingInterfaceDialectExtension> {
399public:
400 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
401 TestTilingInterfaceDialectExtension)
402
403 using Base::Base;
404
405 void init() {
406 declareDependentDialect<affine::AffineDialect>();
407 declareDependentDialect<index::IndexDialect>();
408 declareDependentDialect<scf::SCFDialect>();
409 declareDependentDialect<tensor::TensorDialect>();
410
411 registerTransformOps<
412#define GET_OP_LIST
413#include "TestTilingInterfaceTransformOps.cpp.inc"
414 >();
415 }
416};
417} // namespace
418
419namespace test {
420void registerTestTilingInterfaceTransformDialectExtension(
421 DialectRegistry &registry) {
422 registry.addExtensions<TestTilingInterfaceDialectExtension>();
423}
424} // namespace test
425

source code of mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp