1//===- TestTilingInterfaceTransformOps.cpp - Test `TilingInterface` ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file defines transform dialect operations used for testing the
// `TilingInterface`.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Affine/IR/AffineOps.h"
15#include "mlir/Dialect/Index/IR/IndexDialect.h"
16#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
17#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
18#include "mlir/Dialect/Transform/IR/TransformDialect.h"
19#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
20#include "mlir/Dialect/Utils/StaticValueUtils.h"
21#include "mlir/IR/Dominance.h"
22#include "mlir/IR/OpImplementation.h"
23#include "mlir/Interfaces/TilingInterface.h"
24
25#define GET_OP_CLASSES
26#include "TestTilingInterfaceTransformOps.h.inc"
27
28using namespace mlir;
29using namespace mlir::transform;
30
31//===----------------------------------------------------------------------===//
32// TestFuseAndYieldOp
33//===----------------------------------------------------------------------===//
34
35static llvm::SmallDenseSet<Operation *> collectTiledAndFusedOps(Operation *op) {
36 SmallVector<Operation *> worklist;
37 llvm::SmallDenseSet<Operation *> producers;
38 worklist.push_back(op);
39 producers.insert(op);
40 while (!worklist.empty()) {
41 Operation *current = worklist.pop_back_val();
42 for (OpOperand &operand : current->getOpOperands()) {
43 Operation *producer = operand.get().getDefiningOp();
44 if (!producer || !isa<TilingInterface>(producer) ||
45 producers.contains(producer))
46 continue;
47 worklist.push_back(producer);
48 producers.insert(producer);
49 }
50 }
51 return producers;
52}
53
54/// Apply a tile and fuse transformation to all payload ops and store both the
55/// tiled operation as well as the created tile loops.
56template <typename Range>
57static LogicalResult
58applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
59 Range &&payloadOps, unsigned numLoops,
60 ArrayRef<OpFoldResult> tileSizes,
61 ArrayRef<int64_t> interchange, bool useForall,
62 TransformResults &transformResults) {
63 SmallVector<Operation *> tiledOps;
64 SmallVector<SmallVector<Operation *>> loopOps(numLoops);
65
66 for (Operation *target : payloadOps) {
67 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
68 if (!tilingInterfaceOp)
69 return transformOp->emitError(message: "only TilingInterface ops are supported");
70 DominanceInfo dominanceInfo(tilingInterfaceOp);
71
72 llvm::SmallDenseSet<Operation *> tiledAndFusedOps =
73 collectTiledAndFusedOps(tilingInterfaceOp);
74 llvm::DenseSet<Operation *> yieldReplacementsFor;
75 for (auto op : tiledAndFusedOps) {
76 if (llvm::any_of(op->getUsers(), [&](Operation *user) {
77 return dominanceInfo.properlyDominates(tilingInterfaceOp, user);
78 })) {
79 yieldReplacementsFor.insert(op);
80 }
81 }
82
83 scf::SCFTilingOptions tilingOptions;
84 tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
85 if (useForall) {
86 tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
87 }
88
89 scf::SCFTileAndFuseOptions tileAndFuseOptions;
90 tileAndFuseOptions.setTilingOptions(tilingOptions);
91
92 scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
93 [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
94 bool isDestinationOperand) {
95 Operation *owner = originalProducer.getOwner();
96 bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
97 return std::make_tuple(true, yieldProducerReplacement);
98 };
99 tileAndFuseOptions.setFusionControlFn(controlFn);
100
101 rewriter.setInsertionPoint(target);
102 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
103 scf::tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
104 tileAndFuseOptions);
105 if (failed(tiledResults))
106 return failure();
107
108 // Perform the replacement of tiled and fused values.
109 SmallVector<Operation *> opsToReplace{target};
110 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
111 for (Operation *toReplace : opsToReplace) {
112 for (OpResult res : toReplace->getResults())
113 if (auto replacement = tiledResults->replacements.lookup(res)) {
114 Operation *replacementOp = replacement.getDefiningOp();
115 rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) {
116 Operation *user = use.getOwner();
117 return dominanceInfo.properlyDominates(a: replacementOp, b: user) &&
118 user->getParentOp() == replacementOp->getParentOp();
119 });
120 }
121
122 if (toReplace->use_empty()) {
123 rewriter.eraseOp(op: toReplace);
124 }
125 }
126
127 // Report back the relevant handles to the transform op.
128 tiledOps.push_back(Elt: tiledResults->tiledAndFusedOps.front());
129 assert(tiledResults->loops.size() == numLoops &&
130 "Mismatched number of loops, tile and fuse transform should have "
131 "failed");
132 for (unsigned int i = 0; i < numLoops; ++i)
133 loopOps[i].push_back(Elt: tiledResults->loops[i]);
134 }
135
136 transformResults.set(value: transformOp->getOpResult(idx: 0), ops&: tiledOps);
137 for (unsigned int i = 0; i < numLoops; ++i)
138 transformResults.set(value: transformOp->getOpResult(idx: i + 1), ops&: loopOps[i]);
139
140 return success();
141}
142
143DiagnosedSilenceableFailure
144transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
145 TransformResults &transformResults,
146 TransformState &state) {
147 SmallVector<int64_t> tileSizes =
148 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
149 SmallVector<int64_t> tileInterchange =
150 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
151
152 SmallVector<OpFoldResult> tileSizesOfr =
153 getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
154
155 LogicalResult result = applyTileAndFuseToAll(
156 rewriter, getOperation(), state.getPayloadOps(getTarget()),
157 tileSizes.size() - llvm::count(tileSizes, 0), tileSizesOfr,
158 tileInterchange, getUseForall(), transformResults);
159 return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
160 : DiagnosedSilenceableFailure::success();
161}
162
163//===----------------------------------------------------------------------===//
164// TestTileUsingForallOp
165//===----------------------------------------------------------------------===//
166
167/// Apply a tiling transformation to all payload ops and store both the
168/// tiled operation as well as the created tile loops.
169template <typename Range>
170static LogicalResult
171applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
172 Range &&payloadOps, ArrayRef<OpFoldResult> tileSizes,
173 ArrayRef<int64_t> interchange, std::optional<ArrayAttr> mapping,
174 TransformResults &transformResults) {
175 SmallVector<Operation *> tiledOps;
176 SmallVector<Operation *> loopOps;
177
178 for (Operation *target : payloadOps) {
179 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
180 if (!tilingInterfaceOp)
181 return transformOp->emitError(message: "only TilingInterface ops are supported");
182 scf::SCFTilingOptions tilingOptions;
183 tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
184 if (mapping) {
185 auto mappingAttrs =
186 llvm::map_to_vector(mapping.value(), [](Attribute attr) {
187 return cast<DeviceMappingAttrInterface>(attr);
188 });
189 tilingOptions.setMapping(mappingAttrs);
190 }
191 tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
192
193 rewriter.setInsertionPoint(target);
194 FailureOr<scf::SCFTilingResult> tiledResults =
195 scf::tileUsingSCF(rewriter, tilingInterfaceOp, tilingOptions);
196 if (failed(tiledResults))
197 return failure();
198
199 // Perform the replacement of tiled and fused values.
200 rewriter.replaceOp(tilingInterfaceOp, tiledResults->replacements);
201
202 // Report back the relevant handles to the transform op.
203 tiledOps.push_back(Elt: tiledResults->tiledOps.front());
204 for (Operation *loop : tiledResults->loops)
205 loopOps.push_back(loop);
206 }
207
208 transformResults.set(value: transformOp->getOpResult(idx: 0), ops&: tiledOps);
209 for (auto [index, loop] : llvm::enumerate(First&: loopOps))
210 transformResults.set(value: transformOp->getOpResult(idx: index + 1), ops: {loop});
211
212 return success();
213}
214
215DiagnosedSilenceableFailure
216transform::TestTileUsingForallOp::apply(TransformRewriter &rewriter,
217 TransformResults &transformResults,
218 TransformState &state) {
219 SmallVector<int64_t> tileSizes =
220 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
221 SmallVector<int64_t> interchange =
222 extractFromIntegerArrayAttr<int64_t>(getInterchange());
223 SmallVector<OpFoldResult> tileSizesOfr =
224 getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
225
226 LogicalResult result =
227 applyTileToAll(rewriter, getOperation(), state.getPayloadOps(getTarget()),
228 tileSizesOfr, interchange, getMapping(), transformResults);
229 return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
230 : DiagnosedSilenceableFailure::success();
231}
232
233void transform::TestTileUsingForallOp::getEffects(
234 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
235 consumesHandle(getTarget(), effects);
236 producesHandle(getTiledOp(), effects);
237 producesHandle(getLoops(), effects);
238 modifiesPayload(effects);
239}
240
241//===----------------------------------------------------------------------===//
242// TestFuseUsingForallOp
243//===----------------------------------------------------------------------===//
244
245/// Apply a tiling transformation to all payload ops and store both the
246/// tiled operation as well as the created tile loops.
247template <typename Range>
248static LogicalResult applyTilingToAll(
249 RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
250 unsigned numLoops, TransformResults &transformResults,
251 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
252 applyFn) {
253 SmallVector<Operation *> tiledLinalgOps;
254 SmallVector<SmallVector<Operation *>> loopOps(1);
255
256 for (Operation *target : payloadOps) {
257 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
258 if (!tilingInterfaceOp)
259 return transformOp->emitError(message: "only TilingInterface ops are supported");
260
261 rewriter.setInsertionPoint(target);
262 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
263 applyFn(tilingInterfaceOp);
264 if (failed(result: tiledResults))
265 return failure();
266
267 // Perform the replacement of tiled and fused values.
268 SmallVector<Operation *> opsToReplace{target};
269 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
270 for (Operation *toReplace : opsToReplace) {
271 for (OpResult res : toReplace->getResults())
272 if (auto replacement = tiledResults->replacements.lookup(res))
273 rewriter.replaceAllUsesWith(res, replacement);
274 if (toReplace->use_empty())
275 rewriter.eraseOp(op: toReplace);
276 }
277
278 // Report back the relevant handles to the transform op.
279 tiledLinalgOps.push_back(Elt: tiledResults->tiledAndFusedOps.front());
280 assert(tiledResults->loops.size() == 1 &&
281 cast<scf::ForallOp>(tiledResults->loops[0]).getRank() == numLoops &&
282 "Mismatched number of loops, tile and fuse transform should have "
283 "failed");
284 loopOps[0] = {tiledResults->loops[0]};
285 }
286
287 transformResults.set(value: transformOp->getOpResult(idx: 0), ops&: tiledLinalgOps);
288 if (!loopOps.empty())
289 transformResults.set(value: transformOp->getOpResult(idx: 1), ops&: loopOps[0]);
290
291 return success();
292}
293
294DiagnosedSilenceableFailure
295transform::TestFuseUsingForallOp::apply(TransformRewriter &rewriter,
296 TransformResults &transformResults,
297 TransformState &state) {
298 SmallVector<int64_t> tileSizes =
299 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
300 SmallVector<int64_t> tileInterchange =
301 extractFromIntegerArrayAttr<int64_t>(getInterchange());
302
303 scf::SCFTilingOptions tilingOptions;
304 tilingOptions.interchangeVector = tileInterchange;
305 SmallVector<OpFoldResult> tileSizesOfr =
306 getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
307 tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
308 tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
309 scf::SCFTileAndFuseOptions tileAndFuseOptions;
310 tileAndFuseOptions.tilingOptions = tilingOptions;
311 LogicalResult result = applyTilingToAll(
312 rewriter, getOperation(), state.getPayloadOps(getRootOp()),
313 tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
314 [&](TilingInterface tilingInterfaceOp)
315 -> FailureOr<scf::SCFTileAndFuseResult> {
316 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
317 tileAndFuseOptions);
318 });
319 return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
320 : DiagnosedSilenceableFailure::success();
321}
322
323void transform::TestFuseUsingForallOp::getEffects(
324 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
325 consumesHandle(getRootOp(), effects);
326 producesHandle(getTiledOps(), effects);
327 producesHandle(getLoops(), effects);
328 modifiesPayload(effects);
329}
330
331#define GET_OP_CLASSES
332#include "TestTilingInterfaceTransformOps.cpp.inc"
333
334namespace {
335class TestTilingInterfaceDialectExtension
336 : public transform::TransformDialectExtension<
337 TestTilingInterfaceDialectExtension> {
338public:
339 using Base::Base;
340
341 void init() {
342 declareDependentDialect<affine::AffineDialect>();
343 declareDependentDialect<index::IndexDialect>();
344 declareDependentDialect<scf::SCFDialect>();
345 declareDependentDialect<tensor::TensorDialect>();
346
347 registerTransformOps<
348#define GET_OP_LIST
349#include "TestTilingInterfaceTransformOps.cpp.inc"
350 >();
351 }
352};
353} // namespace
354
355namespace test {
356void registerTestTilingInterfaceTransformDialectExtension(
357 DialectRegistry &registry) {
358 registry.addExtensions<TestTilingInterfaceDialectExtension>();
359}
360} // namespace test
361

source code of mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp