1//===- TestTilingInterfaceTransformOps.cpp - Test `TilingInterface` ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines transform dialect operations used for testing
10// TilingInterface
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Affine/IR/AffineOps.h"
15#include "mlir/Dialect/Index/IR/IndexDialect.h"
16#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
17#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
18#include "mlir/Dialect/Transform/IR/TransformDialect.h"
19#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
20#include "mlir/Dialect/Utils/StaticValueUtils.h"
21#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
22#include "mlir/IR/Dominance.h"
23#include "mlir/IR/OpImplementation.h"
24#include "mlir/Interfaces/TilingInterface.h"
25#include "llvm/Support/Debug.h"
26
27#define DEBUG_TYPE "test-tiling-interface"
28
29#define GET_OP_CLASSES
30#include "TestTilingInterfaceTransformOps.h.inc"
31
32using namespace mlir;
33using namespace mlir::transform;
34
35//===----------------------------------------------------------------------===//
36// TestFuseAndYieldOp
37//===----------------------------------------------------------------------===//
38
39static llvm::SmallDenseSet<Operation *> collectTiledAndFusedOps(Operation *op) {
40 SmallVector<Operation *> worklist;
41 llvm::SmallDenseSet<Operation *> producers;
42 worklist.push_back(Elt: op);
43 producers.insert(V: op);
44 while (!worklist.empty()) {
45 Operation *current = worklist.pop_back_val();
46 for (OpOperand &operand : current->getOpOperands()) {
47 Operation *producer = operand.get().getDefiningOp();
48 if (!producer || !isa<TilingInterface>(Val: producer) ||
49 producers.contains(V: producer))
50 continue;
51 worklist.push_back(Elt: producer);
52 producers.insert(V: producer);
53 }
54 }
55 return producers;
56}
57
58/// Apply a tile and fuse transformation to all payload ops and store both the
59/// tiled operation as well as the created tile loops.
60template <typename Range>
61static LogicalResult
62applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
63 Range &&payloadOps, unsigned numLoops,
64 scf::SCFTilingOptions tilingOptions,
65 TransformResults &transformResults) {
66 SmallVector<Operation *> tiledOps;
67 SmallVector<SmallVector<Operation *>> loopOps(numLoops);
68
69 for (Operation *target : payloadOps) {
70 auto tilingInterfaceOp = dyn_cast<TilingInterface>(Val: target);
71 if (!tilingInterfaceOp)
72 return transformOp->emitError(message: "only TilingInterface ops are supported");
73 DominanceInfo dominanceInfo(tilingInterfaceOp);
74
75 llvm::SmallDenseSet<Operation *> tiledAndFusedOps =
76 collectTiledAndFusedOps(op: tilingInterfaceOp);
77 llvm::DenseSet<Operation *> yieldReplacementsFor;
78 for (auto op : tiledAndFusedOps) {
79 if (llvm::any_of(op->getUsers(), [&](Operation *user) {
80 return dominanceInfo.properlyDominates(a: tilingInterfaceOp, b: user);
81 })) {
82 yieldReplacementsFor.insert(V: op);
83 }
84 }
85
86 scf::SCFTileAndFuseOptions tileAndFuseOptions;
87 tileAndFuseOptions.setTilingOptions(tilingOptions);
88
89 scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
90 [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
91 bool isDestinationOperand)
92 -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
93 Operation *owner = originalProducer.getOwner();
94 bool yieldProducerReplacement = yieldReplacementsFor.contains(V: owner);
95 return scf::SCFTileAndFuseOptions::ControlFnResult{
96 .yieldProducerReplacement: yieldProducerReplacement};
97 };
98 tileAndFuseOptions.setFusionControlFn(controlFn);
99
100 rewriter.setInsertionPoint(target);
101 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
102 scf::tileConsumerAndFuseProducersUsingSCF(rewriter, consumer: tilingInterfaceOp,
103 options: tileAndFuseOptions);
104 if (failed(Result: tiledResults))
105 return failure();
106
107 // Perform the replacement of tiled and fused values.
108 SmallVector<Operation *> opsToReplace{target};
109 llvm::append_range(C&: opsToReplace, R&: tiledResults->fusedProducers);
110 for (Operation *toReplace : opsToReplace) {
111 for (OpResult res : toReplace->getResults())
112 if (auto replacement = tiledResults->replacements.lookup(Val: res)) {
113 Operation *replacementOp = replacement.getDefiningOp();
114 rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) {
115 Operation *user = use.getOwner();
116 return dominanceInfo.properlyDominates(a: replacementOp, b: user) &&
117 user->getParentOp() == replacementOp->getParentOp();
118 });
119 }
120
121 if (toReplace->use_empty()) {
122 rewriter.eraseOp(op: toReplace);
123 }
124 }
125
126 // Report back the relevant handles to the transform op.
127 tiledOps.push_back(Elt: tiledResults->tiledAndFusedOps.front());
128 assert(tiledResults->loops.size() == numLoops &&
129 "Mismatched number of loops, tile and fuse transform should have "
130 "failed");
131 for (unsigned int i = 0; i < numLoops; ++i)
132 loopOps[i].push_back(Elt: tiledResults->loops[i]);
133 }
134
135 transformResults.set(value: transformOp->getOpResult(idx: 0), ops&: tiledOps);
136 for (unsigned int i = 0; i < numLoops; ++i)
137 transformResults.set(value: transformOp->getOpResult(idx: i + 1), ops&: loopOps[i]);
138
139 return success();
140}
141
142DiagnosedSilenceableFailure
143transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
144 TransformResults &transformResults,
145 TransformState &state) {
146 SmallVector<int64_t> tileSizes =
147 extractFromIntegerArrayAttr<int64_t>(attr: getTileSizes());
148 SmallVector<int64_t> tileInterchange =
149 extractFromIntegerArrayAttr<int64_t>(attr: getTileInterchange());
150
151 SmallVector<OpFoldResult> tileSizesOfr =
152 getAsIndexOpFoldResult(ctx: rewriter.getContext(), values: tileSizes);
153
154 scf::SCFTilingOptions tilingOptions;
155 tilingOptions.setTileSizes(tileSizesOfr).setInterchange(tileInterchange);
156 if (getUseForall()) {
157 tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
158 }
159
160 LogicalResult result = applyTileAndFuseToAll(
161 rewriter, transformOp: getOperation(), payloadOps: state.getPayloadOps(value: getTarget()),
162 numLoops: tileSizes.size() - llvm::count(Range&: tileSizes, Element: 0), tilingOptions,
163 transformResults);
164 return failed(Result: result) ? DiagnosedSilenceableFailure::definiteFailure()
165 : DiagnosedSilenceableFailure::success();
166}
167
168//===----------------------------------------------------------------------===//
169// TestFuseConsumerOp
170//===----------------------------------------------------------------------===//
171
172/// Apply fusing of consumer transformation to all payload ops and store both
173/// the original consumer operation as well as the fused consumer operation.
174static LogicalResult applyFuseConsumer(
175 RewriterBase &rewriter, Operation *transformOp,
176 ArrayRef<Operation *> slices, MutableArrayRef<LoopLikeOpInterface> loops,
177 uint32_t numConsumerToFuse, TransformResults &transformResults) {
178 SmallVector<Operation *> originalConsumerOps;
179 SmallVector<Operation *> fusedConsumerOps;
180
181 rewriter.setInsertionPoint(slices.front());
182
183 while (numConsumerToFuse--) {
184 FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
185 scf::tileAndFuseConsumerOfSlices(rewriter, candidateSlices: slices, loops);
186
187 if (failed(Result: fuseConsumerResults))
188 return slices.front()->emitOpError(message: "failed to fuse consumer of slice");
189
190 // Report back the relevant handles to the transform op.
191 for (OpOperand *origConsumerOperand :
192 fuseConsumerResults->origConsumerOperands) {
193 originalConsumerOps.push_back(Elt: origConsumerOperand->getOwner());
194 }
195 for (OpOperand *tiledAndFusedConsumerOperand :
196 fuseConsumerResults->tiledAndFusedConsumerOperands) {
197 fusedConsumerOps.push_back(Elt: tiledAndFusedConsumerOperand->getOwner());
198 }
199 }
200
201 transformResults.set(value: transformOp->getOpResult(idx: 0), ops&: originalConsumerOps);
202 transformResults.set(value: transformOp->getOpResult(idx: 1), ops&: fusedConsumerOps);
203 return success();
204}
205
206DiagnosedSilenceableFailure
207transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
208 TransformResults &transformResults,
209 TransformState &state) {
210 SmallVector<Operation *> slices;
211 for (auto op : getTargets()) {
212 auto sliceOp = *state.getPayloadOps(value: op).begin();
213 slices.push_back(Elt: sliceOp);
214 }
215
216 SmallVector<LoopLikeOpInterface> loops;
217 for (auto op : llvm::reverse(C: getLoops())) {
218 auto loopLikeOp =
219 dyn_cast<LoopLikeOpInterface>(Val: *state.getPayloadOps(value: op).begin());
220 if (!loopLikeOp) {
221 return DiagnosedSilenceableFailure::definiteFailure();
222 }
223 loops.push_back(Elt: loopLikeOp);
224 }
225 LogicalResult result =
226 applyFuseConsumer(rewriter, transformOp: getOperation(), slices, loops,
227 numConsumerToFuse: getNumConsumerToFuse(), transformResults);
228 return failed(Result: result) ? DiagnosedSilenceableFailure::definiteFailure()
229 : DiagnosedSilenceableFailure::success();
230}
231
232void transform::TestFuseConsumerOp::getEffects(
233 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
234 consumesHandle(handles: getTargetsMutable(), effects);
235 consumesHandle(handles: getLoopsMutable(), effects);
236 producesHandle(handles: getOperation()->getOpResults(), effects);
237 modifiesPayload(effects);
238}
239
240//===----------------------------------------------------------------------===//
241// TestTileUsingForallOp
242//===----------------------------------------------------------------------===//
243
244/// Apply a tiling transformation to all payload ops and store both the
245/// tiled operation as well as the created tile loops.
246template <typename Range>
247static LogicalResult
248applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
249 Range &&payloadOps, ArrayRef<OpFoldResult> tileSizes,
250 ArrayRef<int64_t> interchange, std::optional<ArrayAttr> mapping,
251 TransformResults &transformResults) {
252 SmallVector<Operation *> tiledOps;
253 SmallVector<Operation *> loopOps;
254
255 for (Operation *target : payloadOps) {
256 auto tilingInterfaceOp = dyn_cast<TilingInterface>(Val: target);
257 if (!tilingInterfaceOp)
258 return transformOp->emitError(message: "only TilingInterface ops are supported");
259 scf::SCFTilingOptions tilingOptions;
260 tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
261 if (mapping) {
262 tilingOptions.setMapping(mapping.value().getValue());
263 }
264 tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
265
266 rewriter.setInsertionPoint(target);
267 FailureOr<scf::SCFTilingResult> tiledResults =
268 scf::tileUsingSCF(rewriter, op: tilingInterfaceOp, options: tilingOptions);
269 if (failed(Result: tiledResults))
270 return failure();
271
272 // Perform the replacement of tiled and fused values.
273 rewriter.replaceOp(op: tilingInterfaceOp, newValues: tiledResults->replacements);
274
275 // Report back the relevant handles to the transform op.
276 tiledOps.push_back(Elt: tiledResults->tiledOps.front());
277 for (Operation *loop : tiledResults->loops)
278 loopOps.push_back(Elt: loop);
279 }
280
281 transformResults.set(value: transformOp->getOpResult(idx: 0), ops&: tiledOps);
282 for (auto [index, loop] : llvm::enumerate(First&: loopOps))
283 transformResults.set(value: transformOp->getOpResult(idx: index + 1), ops: {loop});
284
285 return success();
286}
287
288DiagnosedSilenceableFailure
289transform::TestTileUsingForallOp::apply(TransformRewriter &rewriter,
290 TransformResults &transformResults,
291 TransformState &state) {
292 SmallVector<int64_t> tileSizes =
293 extractFromIntegerArrayAttr<int64_t>(attr: getTileSizes());
294 SmallVector<int64_t> interchange =
295 extractFromIntegerArrayAttr<int64_t>(attr: getInterchange());
296 SmallVector<OpFoldResult> tileSizesOfr =
297 getAsIndexOpFoldResult(ctx: rewriter.getContext(), values: tileSizes);
298
299 LogicalResult result =
300 applyTileToAll(rewriter, transformOp: getOperation(), payloadOps: state.getPayloadOps(value: getTarget()),
301 tileSizes: tileSizesOfr, interchange, mapping: getMapping(), transformResults);
302 return failed(Result: result) ? DiagnosedSilenceableFailure::definiteFailure()
303 : DiagnosedSilenceableFailure::success();
304}
305
306void transform::TestTileUsingForallOp::getEffects(
307 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
308 consumesHandle(handles: getTargetMutable(), effects);
309 producesHandle(handles: getOperation()->getOpResults(), effects);
310 modifiesPayload(effects);
311}
312
313//===----------------------------------------------------------------------===//
314// TestFuseUsingForallOp
315//===----------------------------------------------------------------------===//
316
317/// Apply a tiling transformation to all payload ops and store both the
318/// tiled operation as well as the created tile loops.
319template <typename Range>
320static LogicalResult applyTilingToAll(
321 RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
322 unsigned numLoops, TransformResults &transformResults,
323 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
324 applyFn) {
325 SmallVector<Operation *> tiledLinalgOps;
326 SmallVector<SmallVector<Operation *>> loopOps(1);
327
328 for (Operation *target : payloadOps) {
329 auto tilingInterfaceOp = dyn_cast<TilingInterface>(Val: target);
330 if (!tilingInterfaceOp)
331 return transformOp->emitError(message: "only TilingInterface ops are supported");
332
333 rewriter.setInsertionPoint(target);
334 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
335 applyFn(tilingInterfaceOp);
336 if (failed(Result: tiledResults))
337 return failure();
338
339 // Perform the replacement of tiled and fused values.
340 SmallVector<Operation *> opsToReplace{target};
341 llvm::append_range(C&: opsToReplace, R&: tiledResults->fusedProducers);
342 for (Operation *toReplace : opsToReplace) {
343 for (OpResult res : toReplace->getResults())
344 if (auto replacement = tiledResults->replacements.lookup(Val: res))
345 rewriter.replaceAllUsesWith(from: res, to: replacement);
346 if (toReplace->use_empty())
347 rewriter.eraseOp(op: toReplace);
348 }
349
350 // Report back the relevant handles to the transform op.
351 tiledLinalgOps.push_back(Elt: tiledResults->tiledAndFusedOps.front());
352 assert(tiledResults->loops.size() == 1 &&
353 cast<scf::ForallOp>(tiledResults->loops[0]).getRank() == numLoops &&
354 "Mismatched number of loops, tile and fuse transform should have "
355 "failed");
356 loopOps[0] = {tiledResults->loops[0]};
357 }
358
359 transformResults.set(value: transformOp->getOpResult(idx: 0), ops&: tiledLinalgOps);
360 if (!loopOps.empty())
361 transformResults.set(value: transformOp->getOpResult(idx: 1), ops&: loopOps[0]);
362
363 return success();
364}
365
366DiagnosedSilenceableFailure
367transform::TestFuseUsingForallOp::apply(TransformRewriter &rewriter,
368 TransformResults &transformResults,
369 TransformState &state) {
370 SmallVector<int64_t> tileSizes =
371 extractFromIntegerArrayAttr<int64_t>(attr: getTileSizes());
372 SmallVector<int64_t> tileInterchange =
373 extractFromIntegerArrayAttr<int64_t>(attr: getInterchange());
374
375 scf::SCFTilingOptions tilingOptions;
376 tilingOptions.interchangeVector = tileInterchange;
377 SmallVector<OpFoldResult> tileSizesOfr =
378 getAsIndexOpFoldResult(ctx: rewriter.getContext(), values: tileSizes);
379 tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
380 tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
381 scf::SCFTileAndFuseOptions tileAndFuseOptions;
382 tileAndFuseOptions.tilingOptions = tilingOptions;
383 LogicalResult result = applyTilingToAll(
384 rewriter, transformOp: getOperation(), payloadOps: state.getPayloadOps(value: getRootOp()),
385 numLoops: tileSizes.size() - llvm::count(Range&: tileSizes, Element: 0), transformResults,
386 applyFn: [&](TilingInterface tilingInterfaceOp)
387 -> FailureOr<scf::SCFTileAndFuseResult> {
388 return tileConsumerAndFuseProducersUsingSCF(rewriter, consumer: tilingInterfaceOp,
389 options: tileAndFuseOptions);
390 });
391 return failed(Result: result) ? DiagnosedSilenceableFailure::definiteFailure()
392 : DiagnosedSilenceableFailure::success();
393}
394
395void transform::TestFuseUsingForallOp::getEffects(
396 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
397 consumesHandle(handles: getRootOpMutable(), effects);
398 producesHandle(handles: getOperation()->getOpResults(), effects);
399 modifiesPayload(effects);
400}
401
402//===----------------------------------------------------------------------===//
403// TestTileAndFuseOuterParallelPartialReduction
404//===----------------------------------------------------------------------===//
405
406DiagnosedSilenceableFailure
407transform::TestTileAndFuseOuterParallelPartialReductionOp::apply(
408 TransformRewriter &rewriter, TransformResults &transformResults,
409 TransformState &state) {
410 auto target =
411 dyn_cast<TilingInterface>(Val: *state.getPayloadOps(value: getRootOp()).begin());
412 if (!target) {
413 emitOpError(message: "expected root operation to implement `TilingInterface`");
414 return DiagnosedSilenceableFailure::definiteFailure();
415 }
416
417 SmallVector<unsigned> reductionDims =
418 extractFromIntegerArrayAttr<unsigned>(attr: getReductionDims());
419 if (reductionDims.empty()) {
420 for (auto [index, iterator] :
421 llvm::enumerate(First: target.getLoopIteratorTypes()))
422 if (iterator == utils::IteratorType::reduction)
423 reductionDims.push_back(Elt: index);
424 }
425
426 if (reductionDims.empty()) {
427 emitOpError(
428 message: "no reduction dimension specified or found in the target operation");
429 return DiagnosedSilenceableFailure::definiteFailure();
430 }
431
432 SmallVector<int64_t> reductionTileSizes =
433 extractFromIntegerArrayAttr<int64_t>(attr: getTileSizes());
434 if (reductionTileSizes.size() != reductionDims.size()) {
435 emitOpError(
436 message: "missing tile sizes for reduction dimensions that are to be tiled");
437 return DiagnosedSilenceableFailure::definiteFailure();
438 }
439
440 // Adjust tile sizes so that it corresponds to the reduction iterator types.
441 SmallVector<OpFoldResult> tileSizes;
442 int reductionTileSizeNum = 0;
443 OpFoldResult zero = rewriter.getIndexAttr(value: 0);
444 for (auto iterator : target.getLoopIteratorTypes()) {
445 if (iterator == utils::IteratorType::parallel) {
446 tileSizes.push_back(Elt: zero);
447 continue;
448 }
449 tileSizes.push_back(
450 Elt: rewriter.getIndexAttr(value: reductionTileSizes[reductionTileSizeNum++]));
451 }
452
453 scf::SCFTilingOptions tilingOptions;
454 tilingOptions.setTileSizes(tileSizes)
455 .setLoopType(scf::SCFTilingOptions::LoopType::ForallOp)
456 .setReductionTilingStrategy(
457 ReductionTilingStrategy::PartialReductionOuterParallel)
458 .setReductionDims(reductionDims);
459 if (auto mapping = getMapping()) {
460 tilingOptions.setMapping(getMapping().value());
461 }
462
463 LogicalResult result = applyTileAndFuseToAll(
464 rewriter, transformOp: getOperation(), payloadOps: state.getPayloadOps(value: getRootOp()),
465 /*numLoops =*/1, tilingOptions, transformResults);
466
467 return failed(Result: result) ? DiagnosedSilenceableFailure::definiteFailure()
468 : DiagnosedSilenceableFailure::success();
469}
470
471#define GET_OP_CLASSES
472#include "TestTilingInterfaceTransformOps.cpp.inc"
473
474namespace {
475class TestTilingInterfaceDialectExtension
476 : public transform::TransformDialectExtension<
477 TestTilingInterfaceDialectExtension> {
478public:
479 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
480 TestTilingInterfaceDialectExtension)
481
482 using Base::Base;
483
484 void init() {
485 declareDependentDialect<affine::AffineDialect>();
486 declareDependentDialect<index::IndexDialect>();
487 declareDependentDialect<scf::SCFDialect>();
488 declareDependentDialect<tensor::TensorDialect>();
489
490 registerTransformOps<
491#define GET_OP_LIST
492#include "TestTilingInterfaceTransformOps.cpp.inc"
493 >();
494 }
495};
496} // namespace
497
498namespace test {
499void registerTestTilingInterfaceTransformDialectExtension(
500 DialectRegistry &registry) {
501 registry.addExtensions<TestTilingInterfaceDialectExtension>();
502}
503} // namespace test
504

source code of mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp