1 | //===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" |
10 | |
11 | #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" |
12 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
13 | #include "mlir/Dialect/Affine/LoopUtils.h" |
14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
15 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
16 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
17 | #include "mlir/Dialect/SCF/IR/SCF.h" |
18 | #include "mlir/Dialect/SCF/Transforms/Patterns.h" |
19 | #include "mlir/Dialect/SCF/Transforms/Transforms.h" |
20 | #include "mlir/Dialect/SCF/Utils/Utils.h" |
21 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
22 | #include "mlir/Dialect/Transform/IR/TransformOps.h" |
23 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
24 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
25 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
26 | #include "mlir/IR/BuiltinAttributes.h" |
27 | #include "mlir/IR/Dominance.h" |
28 | #include "mlir/IR/OpDefinition.h" |
29 | |
30 | using namespace mlir; |
31 | using namespace mlir::affine; |
32 | |
33 | //===----------------------------------------------------------------------===// |
34 | // Apply...PatternsOp |
35 | //===----------------------------------------------------------------------===// |
36 | |
37 | void transform::ApplyForLoopCanonicalizationPatternsOp::populatePatterns( |
38 | RewritePatternSet &patterns) { |
39 | scf::populateSCFForLoopCanonicalizationPatterns(patterns); |
40 | } |
41 | |
42 | void transform::ApplySCFStructuralConversionPatternsOp::populatePatterns( |
43 | TypeConverter &typeConverter, RewritePatternSet &patterns) { |
44 | scf::populateSCFStructuralTypeConversions(typeConverter, patterns); |
45 | } |
46 | |
47 | void transform::ApplySCFStructuralConversionPatternsOp:: |
48 | populateConversionTargetRules(const TypeConverter &typeConverter, |
49 | ConversionTarget &conversionTarget) { |
50 | scf::populateSCFStructuralTypeConversionTarget(typeConverter, |
51 | conversionTarget); |
52 | } |
53 | |
54 | void transform::ApplySCFToControlFlowPatternsOp::populatePatterns( |
55 | TypeConverter &typeConverter, RewritePatternSet &patterns) { |
56 | populateSCFToControlFlowConversionPatterns(patterns); |
57 | } |
58 | |
59 | //===----------------------------------------------------------------------===// |
60 | // ForallToForOp |
61 | //===----------------------------------------------------------------------===// |
62 | |
63 | DiagnosedSilenceableFailure |
64 | transform::ForallToForOp::apply(transform::TransformRewriter &rewriter, |
65 | transform::TransformResults &results, |
66 | transform::TransformState &state) { |
67 | auto payload = state.getPayloadOps(getTarget()); |
68 | if (!llvm::hasSingleElement(payload)) |
69 | return emitSilenceableError() << "expected a single payload op" ; |
70 | |
71 | auto target = dyn_cast<scf::ForallOp>(*payload.begin()); |
72 | if (!target) { |
73 | DiagnosedSilenceableFailure diag = |
74 | emitSilenceableError() << "expected the payload to be scf.forall" ; |
75 | diag.attachNote((*payload.begin())->getLoc()) << "payload op" ; |
76 | return diag; |
77 | } |
78 | |
79 | if (!target.getOutputs().empty()) { |
80 | return emitSilenceableError() |
81 | << "unsupported shared outputs (didn't bufferize?)" ; |
82 | } |
83 | |
84 | SmallVector<OpFoldResult> lbs = target.getMixedLowerBound(); |
85 | |
86 | if (getNumResults() != lbs.size()) { |
87 | DiagnosedSilenceableFailure diag = |
88 | emitSilenceableError() |
89 | << "op expects as many results (" << getNumResults() |
90 | << ") as payload has induction variables (" << lbs.size() << ")" ; |
91 | diag.attachNote(target.getLoc()) << "payload op" ; |
92 | return diag; |
93 | } |
94 | |
95 | SmallVector<Operation *> opResults; |
96 | if (failed(scf::forallToForLoop(rewriter, target, &opResults))) { |
97 | DiagnosedSilenceableFailure diag = emitSilenceableError() |
98 | << "failed to convert forall into for" ; |
99 | return diag; |
100 | } |
101 | |
102 | for (auto &&[i, res] : llvm::enumerate(opResults)) { |
103 | results.set(cast<OpResult>(getTransformed()[i]), {res}); |
104 | } |
105 | return DiagnosedSilenceableFailure::success(); |
106 | } |
107 | |
108 | //===----------------------------------------------------------------------===// |
// ForallToParallelOp
110 | //===----------------------------------------------------------------------===// |
111 | |
112 | DiagnosedSilenceableFailure |
113 | transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter, |
114 | transform::TransformResults &results, |
115 | transform::TransformState &state) { |
116 | auto payload = state.getPayloadOps(getTarget()); |
117 | if (!llvm::hasSingleElement(payload)) |
118 | return emitSilenceableError() << "expected a single payload op" ; |
119 | |
120 | auto target = dyn_cast<scf::ForallOp>(*payload.begin()); |
121 | if (!target) { |
122 | DiagnosedSilenceableFailure diag = |
123 | emitSilenceableError() << "expected the payload to be scf.forall" ; |
124 | diag.attachNote((*payload.begin())->getLoc()) << "payload op" ; |
125 | return diag; |
126 | } |
127 | |
128 | if (!target.getOutputs().empty()) { |
129 | return emitSilenceableError() |
130 | << "unsupported shared outputs (didn't bufferize?)" ; |
131 | } |
132 | |
133 | if (getNumResults() != 1) { |
134 | DiagnosedSilenceableFailure diag = emitSilenceableError() |
135 | << "op expects one result, given " |
136 | << getNumResults(); |
137 | diag.attachNote(target.getLoc()) << "payload op" ; |
138 | return diag; |
139 | } |
140 | |
141 | scf::ParallelOp opResult; |
142 | if (failed(scf::forallToParallelLoop(rewriter, target, &opResult))) { |
143 | DiagnosedSilenceableFailure diag = |
144 | emitSilenceableError() << "failed to convert forall into parallel" ; |
145 | return diag; |
146 | } |
147 | |
148 | results.set(cast<OpResult>(getTransformed()[0]), {opResult}); |
149 | return DiagnosedSilenceableFailure::success(); |
150 | } |
151 | |
152 | //===----------------------------------------------------------------------===// |
153 | // LoopOutlineOp |
154 | //===----------------------------------------------------------------------===// |
155 | |
156 | /// Wraps the given operation `op` into an `scf.execute_region` operation. Uses |
157 | /// the provided rewriter for all operations to remain compatible with the |
158 | /// rewriting infra, as opposed to just splicing the op in place. |
159 | static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b, |
160 | Operation *op) { |
161 | if (op->getNumRegions() != 1) |
162 | return nullptr; |
163 | OpBuilder::InsertionGuard g(b); |
164 | b.setInsertionPoint(op); |
165 | scf::ExecuteRegionOp executeRegionOp = |
166 | b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes()); |
167 | { |
168 | OpBuilder::InsertionGuard g(b); |
169 | b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock()); |
170 | Operation *clonedOp = b.cloneWithoutRegions(op&: *op); |
171 | Region &clonedRegion = clonedOp->getRegions().front(); |
172 | assert(clonedRegion.empty() && "expected empty region" ); |
173 | b.inlineRegionBefore(region&: op->getRegions().front(), parent&: clonedRegion, |
174 | before: clonedRegion.end()); |
175 | b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults()); |
176 | } |
177 | b.replaceOp(op, executeRegionOp.getResults()); |
178 | return executeRegionOp; |
179 | } |
180 | |
181 | DiagnosedSilenceableFailure |
182 | transform::LoopOutlineOp::apply(transform::TransformRewriter &rewriter, |
183 | transform::TransformResults &results, |
184 | transform::TransformState &state) { |
185 | SmallVector<Operation *> functions; |
186 | SmallVector<Operation *> calls; |
187 | DenseMap<Operation *, SymbolTable> symbolTables; |
188 | for (Operation *target : state.getPayloadOps(getTarget())) { |
189 | Location location = target->getLoc(); |
190 | Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target); |
191 | scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target); |
192 | if (!exec) { |
193 | DiagnosedSilenceableFailure diag = emitSilenceableError() |
194 | << "failed to outline" ; |
195 | diag.attachNote(target->getLoc()) << "target op" ; |
196 | return diag; |
197 | } |
198 | func::CallOp call; |
199 | FailureOr<func::FuncOp> outlined = outlineSingleBlockRegion( |
200 | rewriter, location, exec.getRegion(), getFuncName(), &call); |
201 | |
202 | if (failed(outlined)) |
203 | return emitDefaultDefiniteFailure(target); |
204 | |
205 | if (symbolTableOp) { |
206 | SymbolTable &symbolTable = |
207 | symbolTables.try_emplace(symbolTableOp, symbolTableOp) |
208 | .first->getSecond(); |
209 | symbolTable.insert(*outlined); |
210 | call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined)); |
211 | } |
212 | functions.push_back(*outlined); |
213 | calls.push_back(call); |
214 | } |
215 | results.set(cast<OpResult>(getFunction()), functions); |
216 | results.set(cast<OpResult>(getCall()), calls); |
217 | return DiagnosedSilenceableFailure::success(); |
218 | } |
219 | |
220 | //===----------------------------------------------------------------------===// |
221 | // LoopPeelOp |
222 | //===----------------------------------------------------------------------===// |
223 | |
224 | DiagnosedSilenceableFailure |
225 | transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter, |
226 | scf::ForOp target, |
227 | transform::ApplyToEachResultList &results, |
228 | transform::TransformState &state) { |
229 | scf::ForOp result; |
230 | if (getPeelFront()) { |
231 | LogicalResult status = |
232 | scf::peelForLoopFirstIteration(rewriter, target, result); |
233 | if (failed(status)) { |
234 | DiagnosedSilenceableFailure diag = |
235 | emitSilenceableError() << "failed to peel the first iteration" ; |
236 | return diag; |
237 | } |
238 | } else { |
239 | LogicalResult status = |
240 | scf::peelForLoopAndSimplifyBounds(rewriter, target, result); |
241 | if (failed(status)) { |
242 | DiagnosedSilenceableFailure diag = emitSilenceableError() |
243 | << "failed to peel the last iteration" ; |
244 | return diag; |
245 | } |
246 | } |
247 | |
248 | results.push_back(target); |
249 | results.push_back(result); |
250 | |
251 | return DiagnosedSilenceableFailure::success(); |
252 | } |
253 | |
254 | //===----------------------------------------------------------------------===// |
255 | // LoopPipelineOp |
256 | //===----------------------------------------------------------------------===// |
257 | |
258 | /// Callback for PipeliningOption. Populates `schedule` with the mapping from an |
259 | /// operation to its logical time position given the iteration interval and the |
260 | /// read latency. The latter is only relevant for vector transfers. |
261 | static void |
262 | loopScheduling(scf::ForOp forOp, |
263 | std::vector<std::pair<Operation *, unsigned>> &schedule, |
264 | unsigned iterationInterval, unsigned readLatency) { |
265 | auto getLatency = [&](Operation *op) -> unsigned { |
266 | if (isa<vector::TransferReadOp>(Val: op)) |
267 | return readLatency; |
268 | return 1; |
269 | }; |
270 | |
271 | std::optional<int64_t> ubConstant = |
272 | getConstantIntValue(forOp.getUpperBound()); |
273 | std::optional<int64_t> lbConstant = |
274 | getConstantIntValue(forOp.getLowerBound()); |
275 | DenseMap<Operation *, unsigned> opCycles; |
276 | std::map<unsigned, std::vector<Operation *>> wrappedSchedule; |
277 | for (Operation &op : forOp.getBody()->getOperations()) { |
278 | if (isa<scf::YieldOp>(op)) |
279 | continue; |
280 | unsigned earlyCycle = 0; |
281 | for (Value operand : op.getOperands()) { |
282 | Operation *def = operand.getDefiningOp(); |
283 | if (!def) |
284 | continue; |
285 | if (ubConstant && lbConstant) { |
286 | unsigned ubInt = ubConstant.value(); |
287 | unsigned lbInt = lbConstant.value(); |
288 | auto minLatency = std::min(ubInt - lbInt - 1, getLatency(def)); |
289 | earlyCycle = std::max(earlyCycle, opCycles[def] + minLatency); |
290 | } else { |
291 | earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def)); |
292 | } |
293 | } |
294 | opCycles[&op] = earlyCycle; |
295 | wrappedSchedule[earlyCycle % iterationInterval].push_back(&op); |
296 | } |
297 | for (const auto &it : wrappedSchedule) { |
298 | for (Operation *op : it.second) { |
299 | unsigned cycle = opCycles[op]; |
300 | schedule.emplace_back(args&: op, args: cycle / iterationInterval); |
301 | } |
302 | } |
303 | } |
304 | |
305 | DiagnosedSilenceableFailure |
306 | transform::LoopPipelineOp::applyToOne(transform::TransformRewriter &rewriter, |
307 | scf::ForOp target, |
308 | transform::ApplyToEachResultList &results, |
309 | transform::TransformState &state) { |
310 | scf::PipeliningOption options; |
311 | options.getScheduleFn = |
312 | [this](scf::ForOp forOp, |
313 | std::vector<std::pair<Operation *, unsigned>> &schedule) mutable { |
314 | loopScheduling(forOp, schedule, getIterationInterval(), |
315 | getReadLatency()); |
316 | }; |
317 | scf::ForLoopPipeliningPattern pattern(options, target->getContext()); |
318 | rewriter.setInsertionPoint(target); |
319 | FailureOr<scf::ForOp> patternResult = |
320 | scf::pipelineForLoop(rewriter, target, options); |
321 | if (succeeded(patternResult)) { |
322 | results.push_back(*patternResult); |
323 | return DiagnosedSilenceableFailure::success(); |
324 | } |
325 | return emitDefaultSilenceableFailure(target); |
326 | } |
327 | |
328 | //===----------------------------------------------------------------------===// |
329 | // LoopPromoteIfOneIterationOp |
330 | //===----------------------------------------------------------------------===// |
331 | |
332 | DiagnosedSilenceableFailure transform::LoopPromoteIfOneIterationOp::applyToOne( |
333 | transform::TransformRewriter &rewriter, LoopLikeOpInterface target, |
334 | transform::ApplyToEachResultList &results, |
335 | transform::TransformState &state) { |
336 | (void)target.promoteIfSingleIteration(rewriter); |
337 | return DiagnosedSilenceableFailure::success(); |
338 | } |
339 | |
340 | void transform::LoopPromoteIfOneIterationOp::getEffects( |
341 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
342 | consumesHandle(getTargetMutable(), effects); |
343 | modifiesPayload(effects); |
344 | } |
345 | |
346 | //===----------------------------------------------------------------------===// |
347 | // LoopUnrollOp |
348 | //===----------------------------------------------------------------------===// |
349 | |
350 | DiagnosedSilenceableFailure |
351 | transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter, |
352 | Operation *op, |
353 | transform::ApplyToEachResultList &results, |
354 | transform::TransformState &state) { |
355 | LogicalResult result(failure()); |
356 | if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op)) |
357 | result = loopUnrollByFactor(scfFor, getFactor()); |
358 | else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op)) |
359 | result = loopUnrollByFactor(affineFor, getFactor()); |
360 | else |
361 | return emitSilenceableError() |
362 | << "failed to unroll, incorrect type of payload" ; |
363 | |
364 | if (failed(result)) |
365 | return emitSilenceableError() << "failed to unroll" ; |
366 | |
367 | return DiagnosedSilenceableFailure::success(); |
368 | } |
369 | |
370 | //===----------------------------------------------------------------------===// |
371 | // LoopUnrollAndJamOp |
372 | //===----------------------------------------------------------------------===// |
373 | |
374 | DiagnosedSilenceableFailure transform::LoopUnrollAndJamOp::applyToOne( |
375 | transform::TransformRewriter &rewriter, Operation *op, |
376 | transform::ApplyToEachResultList &results, |
377 | transform::TransformState &state) { |
378 | LogicalResult result(failure()); |
379 | if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op)) |
380 | result = loopUnrollJamByFactor(scfFor, getFactor()); |
381 | else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op)) |
382 | result = loopUnrollJamByFactor(affineFor, getFactor()); |
383 | else |
384 | return emitSilenceableError() |
385 | << "failed to unroll and jam, incorrect type of payload" ; |
386 | |
387 | if (failed(result)) |
388 | return emitSilenceableError() << "failed to unroll and jam" ; |
389 | |
390 | return DiagnosedSilenceableFailure::success(); |
391 | } |
392 | |
393 | //===----------------------------------------------------------------------===// |
394 | // LoopCoalesceOp |
395 | //===----------------------------------------------------------------------===// |
396 | |
397 | DiagnosedSilenceableFailure |
398 | transform::LoopCoalesceOp::applyToOne(transform::TransformRewriter &rewriter, |
399 | Operation *op, |
400 | transform::ApplyToEachResultList &results, |
401 | transform::TransformState &state) { |
402 | LogicalResult result(failure()); |
403 | if (scf::ForOp scfForOp = dyn_cast<scf::ForOp>(op)) |
404 | result = coalescePerfectlyNestedSCFForLoops(scfForOp); |
405 | else if (AffineForOp affineForOp = dyn_cast<AffineForOp>(op)) |
406 | result = coalescePerfectlyNestedAffineLoops(affineForOp); |
407 | |
408 | results.push_back(op); |
409 | if (failed(result)) { |
410 | DiagnosedSilenceableFailure diag = emitSilenceableError() |
411 | << "failed to coalesce" ; |
412 | return diag; |
413 | } |
414 | return DiagnosedSilenceableFailure::success(); |
415 | } |
416 | |
417 | //===----------------------------------------------------------------------===// |
418 | // TakeAssumedBranchOp |
419 | //===----------------------------------------------------------------------===// |
420 | /// Replaces the given op with the contents of the given single-block region, |
421 | /// using the operands of the block terminator to replace operation results. |
422 | static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op, |
423 | Region ®ion) { |
424 | assert(llvm::hasSingleElement(region) && "expected single-region block" ); |
425 | Block *block = ®ion.front(); |
426 | Operation *terminator = block->getTerminator(); |
427 | ValueRange results = terminator->getOperands(); |
428 | rewriter.inlineBlockBefore(source: block, op, /*blockArgs=*/argValues: {}); |
429 | rewriter.replaceOp(op, newValues: results); |
430 | rewriter.eraseOp(op: terminator); |
431 | } |
432 | |
433 | DiagnosedSilenceableFailure transform::TakeAssumedBranchOp::applyToOne( |
434 | transform::TransformRewriter &rewriter, scf::IfOp ifOp, |
435 | transform::ApplyToEachResultList &results, |
436 | transform::TransformState &state) { |
437 | rewriter.setInsertionPoint(ifOp); |
438 | Region ®ion = |
439 | getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion(); |
440 | if (!llvm::hasSingleElement(region)) { |
441 | return emitDefiniteFailure() |
442 | << "requires an scf.if op with a single-block " |
443 | << ((getTakeElseBranch()) ? "`else`" : "`then`" ) << " region" ; |
444 | } |
445 | replaceOpWithRegion(rewriter, ifOp, region); |
446 | return DiagnosedSilenceableFailure::success(); |
447 | } |
448 | |
449 | void transform::TakeAssumedBranchOp::getEffects( |
450 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
451 | onlyReadsHandle(getTargetMutable(), effects); |
452 | modifiesPayload(effects); |
453 | } |
454 | |
455 | //===----------------------------------------------------------------------===// |
456 | // LoopFuseSiblingOp |
457 | //===----------------------------------------------------------------------===// |
458 | |
459 | /// Check if `target` and `source` are siblings, in the context that `target` |
460 | /// is being fused into `source`. |
461 | /// |
462 | /// This is a simple check that just checks if both operations are in the same |
463 | /// block and some checks to ensure that the fused IR does not violate |
464 | /// dominance. |
465 | static DiagnosedSilenceableFailure isOpSibling(Operation *target, |
466 | Operation *source) { |
467 | // Check if both operations are same. |
468 | if (target == source) |
469 | return emitSilenceableFailure(op: source) |
470 | << "target and source need to be different loops" ; |
471 | |
472 | // Check if both operations are in the same block. |
473 | if (target->getBlock() != source->getBlock()) |
474 | return emitSilenceableFailure(op: source) |
475 | << "target and source are not in the same block" ; |
476 | |
477 | // Check if fusion will violate dominance. |
478 | DominanceInfo domInfo(source); |
479 | if (target->isBeforeInBlock(other: source)) { |
480 | // Since `target` is before `source`, all users of results of `target` |
481 | // need to be dominated by `source`. |
482 | for (Operation *user : target->getUsers()) { |
483 | if (!domInfo.properlyDominates(a: source, b: user, /*enclosingOpOk=*/false)) { |
484 | return emitSilenceableFailure(op: target) |
485 | << "user of results of target should be properly dominated by " |
486 | "source" ; |
487 | } |
488 | } |
489 | } else { |
490 | // Since `target` is after `source`, all values used by `target` need |
491 | // to dominate `source`. |
492 | |
493 | // Check if operands of `target` are dominated by `source`. |
494 | for (Value operand : target->getOperands()) { |
495 | Operation *operandOp = operand.getDefiningOp(); |
496 | // Operands without defining operations are block arguments. When `target` |
497 | // and `source` occur in the same block, these operands dominate `source`. |
498 | if (!operandOp) |
499 | continue; |
500 | |
501 | // Operand's defining operation should properly dominate `source`. |
502 | if (!domInfo.properlyDominates(a: operandOp, b: source, |
503 | /*enclosingOpOk=*/false)) |
504 | return emitSilenceableFailure(op: target) |
505 | << "operands of target should be properly dominated by source" ; |
506 | } |
507 | |
508 | // Check if values used by `target` are dominated by `source`. |
509 | bool failed = false; |
510 | OpOperand *failedValue = nullptr; |
511 | visitUsedValuesDefinedAbove(regions: target->getRegions(), callback: [&](OpOperand *operand) { |
512 | Operation *operandOp = operand->get().getDefiningOp(); |
513 | if (operandOp && !domInfo.properlyDominates(a: operandOp, b: source, |
514 | /*enclosingOpOk=*/false)) { |
515 | // `operand` is not an argument of an enclosing block and the defining |
516 | // op of `operand` is outside `target` but does not dominate `source`. |
517 | failed = true; |
518 | failedValue = operand; |
519 | } |
520 | }); |
521 | |
522 | if (failed) |
523 | return emitSilenceableFailure(op: failedValue->getOwner()) |
524 | << "values used inside regions of target should be properly " |
525 | "dominated by source" ; |
526 | } |
527 | |
528 | return DiagnosedSilenceableFailure::success(); |
529 | } |
530 | |
531 | /// Check if `target` scf.forall can be fused into `source` scf.forall. |
532 | /// |
533 | /// This simply checks if both loops have the same bounds, steps and mapping. |
534 | /// No attempt is made at checking that the side effects of `target` and |
535 | /// `source` are independent of each other. |
536 | static bool isForallWithIdenticalConfiguration(Operation *target, |
537 | Operation *source) { |
538 | auto targetOp = dyn_cast<scf::ForallOp>(target); |
539 | auto sourceOp = dyn_cast<scf::ForallOp>(source); |
540 | if (!targetOp || !sourceOp) |
541 | return false; |
542 | |
543 | return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() && |
544 | targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() && |
545 | targetOp.getMixedStep() == sourceOp.getMixedStep() && |
546 | targetOp.getMapping() == sourceOp.getMapping(); |
547 | } |
548 | |
549 | /// Check if `target` scf.for can be fused into `source` scf.for. |
550 | /// |
551 | /// This simply checks if both loops have the same bounds and steps. No attempt |
552 | /// is made at checking that the side effects of `target` and `source` are |
553 | /// independent of each other. |
554 | static bool isForWithIdenticalConfiguration(Operation *target, |
555 | Operation *source) { |
556 | auto targetOp = dyn_cast<scf::ForOp>(target); |
557 | auto sourceOp = dyn_cast<scf::ForOp>(source); |
558 | if (!targetOp || !sourceOp) |
559 | return false; |
560 | |
561 | return targetOp.getLowerBound() == sourceOp.getLowerBound() && |
562 | targetOp.getUpperBound() == sourceOp.getUpperBound() && |
563 | targetOp.getStep() == sourceOp.getStep(); |
564 | } |
565 | |
566 | DiagnosedSilenceableFailure |
567 | transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter, |
568 | transform::TransformResults &results, |
569 | transform::TransformState &state) { |
570 | auto targetOps = state.getPayloadOps(getTarget()); |
571 | auto sourceOps = state.getPayloadOps(getSource()); |
572 | |
573 | if (!llvm::hasSingleElement(targetOps) || |
574 | !llvm::hasSingleElement(sourceOps)) { |
575 | return emitDefiniteFailure() |
576 | << "requires exactly one target handle (got " |
577 | << llvm::range_size(targetOps) << ") and exactly one " |
578 | << "source handle (got " << llvm::range_size(sourceOps) << ")" ; |
579 | } |
580 | |
581 | Operation *target = *targetOps.begin(); |
582 | Operation *source = *sourceOps.begin(); |
583 | |
584 | // Check if the target and source are siblings. |
585 | DiagnosedSilenceableFailure diag = isOpSibling(target, source); |
586 | if (!diag.succeeded()) |
587 | return diag; |
588 | |
589 | Operation *fusedLoop; |
590 | /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall. |
591 | if (isForWithIdenticalConfiguration(target, source)) { |
592 | fusedLoop = fuseIndependentSiblingForLoops( |
593 | cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter); |
594 | } else if (isForallWithIdenticalConfiguration(target, source)) { |
595 | fusedLoop = fuseIndependentSiblingForallLoops( |
596 | cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter); |
597 | } else { |
598 | return emitSilenceableFailure(target->getLoc()) |
599 | << "operations cannot be fused" ; |
600 | } |
601 | |
602 | assert(fusedLoop && "failed to fuse operations" ); |
603 | |
604 | results.set(cast<OpResult>(getFusedLoop()), {fusedLoop}); |
605 | return DiagnosedSilenceableFailure::success(); |
606 | } |
607 | |
608 | //===----------------------------------------------------------------------===// |
609 | // Transform op registration |
610 | //===----------------------------------------------------------------------===// |
611 | |
namespace {
/// Transform dialect extension that registers the SCF transform ops of this
/// file. The op list is spliced in from the generated `GET_OP_LIST` section of
/// SCFTransformOps.cpp.inc.
class SCFTransformDialectExtension
    : public transform::TransformDialectExtension<
          SCFTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SCFTransformDialectExtension)

  using Base::Base;

  /// Declares the dialects these transforms may generate ops from, then
  /// registers every generated SCF transform op.
  void init() {
    // Declared as generated dialects. The func dialect is needed at least by
    // the loop-outlining transform, which creates func.func / func.call ops.
    declareGeneratedDialect<affine::AffineDialect>();
    declareGeneratedDialect<func::FuncDialect>();

    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
        >();
  }
};
} // namespace
632 | |
633 | #define GET_OP_CLASSES |
634 | #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc" |
635 | |
636 | void mlir::scf::registerTransformDialectExtension(DialectRegistry ®istry) { |
637 | registry.addExtensions<SCFTransformDialectExtension>(); |
638 | } |
639 | |