1 | //===- TestTilingInterfaceTransformOps.cpp - Test `TilingInterface` ------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines transform dialect operations used for testing |
10 | // TilingInterface |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
15 | #include "mlir/Dialect/Index/IR/IndexDialect.h" |
16 | #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" |
17 | #include "mlir/Dialect/Transform/IR/TransformAttrs.h" |
18 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
19 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
20 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
21 | #include "mlir/IR/Dominance.h" |
22 | #include "mlir/IR/OpImplementation.h" |
23 | #include "mlir/Interfaces/TilingInterface.h" |
24 | |
25 | #define GET_OP_CLASSES |
26 | #include "TestTilingInterfaceTransformOps.h.inc" |
27 | |
28 | using namespace mlir; |
29 | using namespace mlir::transform; |
30 | |
31 | //===----------------------------------------------------------------------===// |
32 | // TestFuseAndYieldOp |
33 | //===----------------------------------------------------------------------===// |
34 | |
35 | static llvm::SmallDenseSet<Operation *> collectTiledAndFusedOps(Operation *op) { |
36 | SmallVector<Operation *> worklist; |
37 | llvm::SmallDenseSet<Operation *> producers; |
38 | worklist.push_back(op); |
39 | producers.insert(op); |
40 | while (!worklist.empty()) { |
41 | Operation *current = worklist.pop_back_val(); |
42 | for (OpOperand &operand : current->getOpOperands()) { |
43 | Operation *producer = operand.get().getDefiningOp(); |
44 | if (!producer || !isa<TilingInterface>(producer) || |
45 | producers.contains(producer)) |
46 | continue; |
47 | worklist.push_back(producer); |
48 | producers.insert(producer); |
49 | } |
50 | } |
51 | return producers; |
52 | } |
53 | |
54 | /// Apply a tile and fuse transformation to all payload ops and store both the |
55 | /// tiled operation as well as the created tile loops. |
56 | template <typename Range> |
57 | static LogicalResult |
58 | applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp, |
59 | Range &&payloadOps, unsigned numLoops, |
60 | ArrayRef<OpFoldResult> tileSizes, |
61 | ArrayRef<int64_t> interchange, bool useForall, |
62 | TransformResults &transformResults) { |
63 | SmallVector<Operation *> tiledOps; |
64 | SmallVector<SmallVector<Operation *>> loopOps(numLoops); |
65 | |
66 | for (Operation *target : payloadOps) { |
67 | auto tilingInterfaceOp = dyn_cast<TilingInterface>(target); |
68 | if (!tilingInterfaceOp) |
69 | return transformOp->emitError(message: "only TilingInterface ops are supported" ); |
70 | DominanceInfo dominanceInfo(tilingInterfaceOp); |
71 | |
72 | llvm::SmallDenseSet<Operation *> tiledAndFusedOps = |
73 | collectTiledAndFusedOps(tilingInterfaceOp); |
74 | llvm::DenseSet<Operation *> yieldReplacementsFor; |
75 | for (auto op : tiledAndFusedOps) { |
76 | if (llvm::any_of(op->getUsers(), [&](Operation *user) { |
77 | return dominanceInfo.properlyDominates(tilingInterfaceOp, user); |
78 | })) { |
79 | yieldReplacementsFor.insert(op); |
80 | } |
81 | } |
82 | |
83 | scf::SCFTilingOptions tilingOptions; |
84 | tilingOptions.setTileSizes(tileSizes).setInterchange(interchange); |
85 | if (useForall) { |
86 | tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); |
87 | } |
88 | |
89 | scf::SCFTileAndFuseOptions tileAndFuseOptions; |
90 | tileAndFuseOptions.setTilingOptions(tilingOptions); |
91 | |
92 | scf::SCFTileAndFuseOptions::ControlFnTy controlFn = |
93 | [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer, |
94 | bool isDestinationOperand) { |
95 | Operation *owner = originalProducer.getOwner(); |
96 | bool yieldProducerReplacement = yieldReplacementsFor.contains(owner); |
97 | return std::make_tuple(true, yieldProducerReplacement); |
98 | }; |
99 | tileAndFuseOptions.setFusionControlFn(controlFn); |
100 | |
101 | rewriter.setInsertionPoint(target); |
102 | FailureOr<scf::SCFTileAndFuseResult> tiledResults = |
103 | scf::tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp, |
104 | tileAndFuseOptions); |
105 | if (failed(tiledResults)) |
106 | return failure(); |
107 | |
108 | // Perform the replacement of tiled and fused values. |
109 | SmallVector<Operation *> opsToReplace{target}; |
110 | llvm::append_range(opsToReplace, tiledResults->fusedProducers); |
111 | for (Operation *toReplace : opsToReplace) { |
112 | for (OpResult res : toReplace->getResults()) |
113 | if (auto replacement = tiledResults->replacements.lookup(res)) { |
114 | Operation *replacementOp = replacement.getDefiningOp(); |
115 | rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) { |
116 | Operation *user = use.getOwner(); |
117 | return dominanceInfo.properlyDominates(a: replacementOp, b: user) && |
118 | user->getParentOp() == replacementOp->getParentOp(); |
119 | }); |
120 | } |
121 | |
122 | if (toReplace->use_empty()) { |
123 | rewriter.eraseOp(op: toReplace); |
124 | } |
125 | } |
126 | |
127 | // Report back the relevant handles to the transform op. |
128 | tiledOps.push_back(Elt: tiledResults->tiledAndFusedOps.front()); |
129 | assert(tiledResults->loops.size() == numLoops && |
130 | "Mismatched number of loops, tile and fuse transform should have " |
131 | "failed" ); |
132 | for (unsigned int i = 0; i < numLoops; ++i) |
133 | loopOps[i].push_back(Elt: tiledResults->loops[i]); |
134 | } |
135 | |
136 | transformResults.set(value: transformOp->getOpResult(idx: 0), ops&: tiledOps); |
137 | for (unsigned int i = 0; i < numLoops; ++i) |
138 | transformResults.set(value: transformOp->getOpResult(idx: i + 1), ops&: loopOps[i]); |
139 | |
140 | return success(); |
141 | } |
142 | |
143 | DiagnosedSilenceableFailure |
144 | transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter, |
145 | TransformResults &transformResults, |
146 | TransformState &state) { |
147 | SmallVector<int64_t> tileSizes = |
148 | extractFromIntegerArrayAttr<int64_t>(getTileSizes()); |
149 | SmallVector<int64_t> tileInterchange = |
150 | extractFromIntegerArrayAttr<int64_t>(getTileInterchange()); |
151 | |
152 | SmallVector<OpFoldResult> tileSizesOfr = |
153 | getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); |
154 | |
155 | LogicalResult result = applyTileAndFuseToAll( |
156 | rewriter, getOperation(), state.getPayloadOps(getTarget()), |
157 | tileSizes.size() - llvm::count(tileSizes, 0), tileSizesOfr, |
158 | tileInterchange, getUseForall(), transformResults); |
159 | return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() |
160 | : DiagnosedSilenceableFailure::success(); |
161 | } |
162 | |
163 | //===----------------------------------------------------------------------===// |
164 | // TestTileUsingForallOp |
165 | //===----------------------------------------------------------------------===// |
166 | |
167 | /// Apply a tiling transformation to all payload ops and store both the |
168 | /// tiled operation as well as the created tile loops. |
169 | template <typename Range> |
170 | static LogicalResult |
171 | applyTileToAll(RewriterBase &rewriter, Operation *transformOp, |
172 | Range &&payloadOps, ArrayRef<OpFoldResult> tileSizes, |
173 | ArrayRef<int64_t> interchange, std::optional<ArrayAttr> mapping, |
174 | TransformResults &transformResults) { |
175 | SmallVector<Operation *> tiledOps; |
176 | SmallVector<Operation *> loopOps; |
177 | |
178 | for (Operation *target : payloadOps) { |
179 | auto tilingInterfaceOp = dyn_cast<TilingInterface>(target); |
180 | if (!tilingInterfaceOp) |
181 | return transformOp->emitError(message: "only TilingInterface ops are supported" ); |
182 | scf::SCFTilingOptions tilingOptions; |
183 | tilingOptions.setTileSizes(tileSizes).setInterchange(interchange); |
184 | if (mapping) { |
185 | auto mappingAttrs = |
186 | llvm::map_to_vector(mapping.value(), [](Attribute attr) { |
187 | return cast<DeviceMappingAttrInterface>(attr); |
188 | }); |
189 | tilingOptions.setMapping(mappingAttrs); |
190 | } |
191 | tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); |
192 | |
193 | rewriter.setInsertionPoint(target); |
194 | FailureOr<scf::SCFTilingResult> tiledResults = |
195 | scf::tileUsingSCF(rewriter, tilingInterfaceOp, tilingOptions); |
196 | if (failed(tiledResults)) |
197 | return failure(); |
198 | |
199 | // Perform the replacement of tiled and fused values. |
200 | rewriter.replaceOp(tilingInterfaceOp, tiledResults->replacements); |
201 | |
202 | // Report back the relevant handles to the transform op. |
203 | tiledOps.push_back(Elt: tiledResults->tiledOps.front()); |
204 | for (Operation *loop : tiledResults->loops) |
205 | loopOps.push_back(loop); |
206 | } |
207 | |
208 | transformResults.set(value: transformOp->getOpResult(idx: 0), ops&: tiledOps); |
209 | for (auto [index, loop] : llvm::enumerate(First&: loopOps)) |
210 | transformResults.set(value: transformOp->getOpResult(idx: index + 1), ops: {loop}); |
211 | |
212 | return success(); |
213 | } |
214 | |
215 | DiagnosedSilenceableFailure |
216 | transform::TestTileUsingForallOp::apply(TransformRewriter &rewriter, |
217 | TransformResults &transformResults, |
218 | TransformState &state) { |
219 | SmallVector<int64_t> tileSizes = |
220 | extractFromIntegerArrayAttr<int64_t>(getTileSizes()); |
221 | SmallVector<int64_t> interchange = |
222 | extractFromIntegerArrayAttr<int64_t>(getInterchange()); |
223 | SmallVector<OpFoldResult> tileSizesOfr = |
224 | getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); |
225 | |
226 | LogicalResult result = |
227 | applyTileToAll(rewriter, getOperation(), state.getPayloadOps(getTarget()), |
228 | tileSizesOfr, interchange, getMapping(), transformResults); |
229 | return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() |
230 | : DiagnosedSilenceableFailure::success(); |
231 | } |
232 | |
233 | void transform::TestTileUsingForallOp::getEffects( |
234 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
235 | consumesHandle(getTarget(), effects); |
236 | producesHandle(getTiledOp(), effects); |
237 | producesHandle(getLoops(), effects); |
238 | modifiesPayload(effects); |
239 | } |
240 | |
241 | //===----------------------------------------------------------------------===// |
242 | // TestFuseUsingForallOp |
243 | //===----------------------------------------------------------------------===// |
244 | |
245 | /// Apply a tiling transformation to all payload ops and store both the |
246 | /// tiled operation as well as the created tile loops. |
247 | template <typename Range> |
248 | static LogicalResult applyTilingToAll( |
249 | RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, |
250 | unsigned numLoops, TransformResults &transformResults, |
251 | function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)> |
252 | applyFn) { |
253 | SmallVector<Operation *> tiledLinalgOps; |
254 | SmallVector<SmallVector<Operation *>> loopOps(1); |
255 | |
256 | for (Operation *target : payloadOps) { |
257 | auto tilingInterfaceOp = dyn_cast<TilingInterface>(target); |
258 | if (!tilingInterfaceOp) |
259 | return transformOp->emitError(message: "only TilingInterface ops are supported" ); |
260 | |
261 | rewriter.setInsertionPoint(target); |
262 | FailureOr<scf::SCFTileAndFuseResult> tiledResults = |
263 | applyFn(tilingInterfaceOp); |
264 | if (failed(result: tiledResults)) |
265 | return failure(); |
266 | |
267 | // Perform the replacement of tiled and fused values. |
268 | SmallVector<Operation *> opsToReplace{target}; |
269 | llvm::append_range(opsToReplace, tiledResults->fusedProducers); |
270 | for (Operation *toReplace : opsToReplace) { |
271 | for (OpResult res : toReplace->getResults()) |
272 | if (auto replacement = tiledResults->replacements.lookup(res)) |
273 | rewriter.replaceAllUsesWith(res, replacement); |
274 | if (toReplace->use_empty()) |
275 | rewriter.eraseOp(op: toReplace); |
276 | } |
277 | |
278 | // Report back the relevant handles to the transform op. |
279 | tiledLinalgOps.push_back(Elt: tiledResults->tiledAndFusedOps.front()); |
280 | assert(tiledResults->loops.size() == 1 && |
281 | cast<scf::ForallOp>(tiledResults->loops[0]).getRank() == numLoops && |
282 | "Mismatched number of loops, tile and fuse transform should have " |
283 | "failed" ); |
284 | loopOps[0] = {tiledResults->loops[0]}; |
285 | } |
286 | |
287 | transformResults.set(value: transformOp->getOpResult(idx: 0), ops&: tiledLinalgOps); |
288 | if (!loopOps.empty()) |
289 | transformResults.set(value: transformOp->getOpResult(idx: 1), ops&: loopOps[0]); |
290 | |
291 | return success(); |
292 | } |
293 | |
294 | DiagnosedSilenceableFailure |
295 | transform::TestFuseUsingForallOp::apply(TransformRewriter &rewriter, |
296 | TransformResults &transformResults, |
297 | TransformState &state) { |
298 | SmallVector<int64_t> tileSizes = |
299 | extractFromIntegerArrayAttr<int64_t>(getTileSizes()); |
300 | SmallVector<int64_t> tileInterchange = |
301 | extractFromIntegerArrayAttr<int64_t>(getInterchange()); |
302 | |
303 | scf::SCFTilingOptions tilingOptions; |
304 | tilingOptions.interchangeVector = tileInterchange; |
305 | SmallVector<OpFoldResult> tileSizesOfr = |
306 | getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); |
307 | tilingOptions = tilingOptions.setTileSizes(tileSizesOfr); |
308 | tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); |
309 | scf::SCFTileAndFuseOptions tileAndFuseOptions; |
310 | tileAndFuseOptions.tilingOptions = tilingOptions; |
311 | LogicalResult result = applyTilingToAll( |
312 | rewriter, getOperation(), state.getPayloadOps(getRootOp()), |
313 | tileSizes.size() - llvm::count(tileSizes, 0), transformResults, |
314 | [&](TilingInterface tilingInterfaceOp) |
315 | -> FailureOr<scf::SCFTileAndFuseResult> { |
316 | return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp, |
317 | tileAndFuseOptions); |
318 | }); |
319 | return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() |
320 | : DiagnosedSilenceableFailure::success(); |
321 | } |
322 | |
323 | void transform::TestFuseUsingForallOp::getEffects( |
324 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
325 | consumesHandle(getRootOp(), effects); |
326 | producesHandle(getTiledOps(), effects); |
327 | producesHandle(getLoops(), effects); |
328 | modifiesPayload(effects); |
329 | } |
330 | |
331 | #define GET_OP_CLASSES |
332 | #include "TestTilingInterfaceTransformOps.cpp.inc" |
333 | |
namespace {
/// Transform dialect extension that registers the test transform ops defined
/// in this file. It declares the Affine, Index, SCF and Tensor dialects as
/// dependent dialects.
class TestTilingInterfaceDialectExtension
    : public transform::TransformDialectExtension<
          TestTilingInterfaceDialectExtension> {
public:
  using Base::Base;

  /// Declares the dependent dialects and registers the transform ops.
  void init() {
    declareDependentDialect<affine::AffineDialect>();
    declareDependentDialect<index::IndexDialect>();
    declareDependentDialect<scf::SCFDialect>();
    declareDependentDialect<tensor::TensorDialect>();

    // The op list is expanded from the generated `.cpp.inc` file as the
    // template arguments of `registerTransformOps`.
    registerTransformOps<
#define GET_OP_LIST
#include "TestTilingInterfaceTransformOps.cpp.inc"
        >();
  }
};
} // namespace
354 | |
355 | namespace test { |
356 | void registerTestTilingInterfaceTransformDialectExtension( |
357 | DialectRegistry ®istry) { |
358 | registry.addExtensions<TestTilingInterfaceDialectExtension>(); |
359 | } |
360 | } // namespace test |
361 | |