1//===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
10
11#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
12#include "mlir/Dialect/Affine/IR/AffineOps.h"
13#include "mlir/Dialect/Affine/LoopUtils.h"
14#include "mlir/Dialect/Func/IR/FuncOps.h"
15#include "mlir/Dialect/SCF/IR/SCF.h"
16#include "mlir/Dialect/SCF/Transforms/Patterns.h"
17#include "mlir/Dialect/SCF/Transforms/Transforms.h"
18#include "mlir/Dialect/SCF/Utils/Utils.h"
19#include "mlir/Dialect/Transform/IR/TransformDialect.h"
20#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
21#include "mlir/Dialect/Utils/StaticValueUtils.h"
22#include "mlir/Dialect/Vector/IR/VectorOps.h"
23#include "mlir/IR/BuiltinAttributes.h"
24#include "mlir/IR/Dominance.h"
25#include "mlir/IR/OpDefinition.h"
26
27using namespace mlir;
28using namespace mlir::affine;
29
30//===----------------------------------------------------------------------===//
31// Apply...PatternsOp
32//===----------------------------------------------------------------------===//
33
34void transform::ApplyForLoopCanonicalizationPatternsOp::populatePatterns(
35 RewritePatternSet &patterns) {
36 scf::populateSCFForLoopCanonicalizationPatterns(patterns);
37}
38
39void transform::ApplySCFStructuralConversionPatternsOp::populatePatterns(
40 TypeConverter &typeConverter, RewritePatternSet &patterns) {
41 scf::populateSCFStructuralTypeConversions(typeConverter, patterns);
42}
43
44void transform::ApplySCFStructuralConversionPatternsOp::
45 populateConversionTargetRules(const TypeConverter &typeConverter,
46 ConversionTarget &conversionTarget) {
47 scf::populateSCFStructuralTypeConversionTarget(typeConverter,
48 target&: conversionTarget);
49}
50
51void transform::ApplySCFToControlFlowPatternsOp::populatePatterns(
52 TypeConverter &typeConverter, RewritePatternSet &patterns) {
53 populateSCFToControlFlowConversionPatterns(patterns);
54}
55
56//===----------------------------------------------------------------------===//
57// ForallToForOp
58//===----------------------------------------------------------------------===//
59
60DiagnosedSilenceableFailure
61transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
62 transform::TransformResults &results,
63 transform::TransformState &state) {
64 auto payload = state.getPayloadOps(value: getTarget());
65 if (!llvm::hasSingleElement(C&: payload))
66 return emitSilenceableError() << "expected a single payload op";
67
68 auto target = dyn_cast<scf::ForallOp>(Val: *payload.begin());
69 if (!target) {
70 DiagnosedSilenceableFailure diag =
71 emitSilenceableError() << "expected the payload to be scf.forall";
72 diag.attachNote(loc: (*payload.begin())->getLoc()) << "payload op";
73 return diag;
74 }
75
76 if (!target.getOutputs().empty()) {
77 return emitSilenceableError()
78 << "unsupported shared outputs (didn't bufferize?)";
79 }
80
81 SmallVector<OpFoldResult> lbs = target.getMixedLowerBound();
82
83 if (getNumResults() != lbs.size()) {
84 DiagnosedSilenceableFailure diag =
85 emitSilenceableError()
86 << "op expects as many results (" << getNumResults()
87 << ") as payload has induction variables (" << lbs.size() << ")";
88 diag.attachNote(loc: target.getLoc()) << "payload op";
89 return diag;
90 }
91
92 SmallVector<Operation *> opResults;
93 if (failed(Result: scf::forallToForLoop(rewriter, forallOp: target, results: &opResults))) {
94 DiagnosedSilenceableFailure diag = emitSilenceableError()
95 << "failed to convert forall into for";
96 return diag;
97 }
98
99 for (auto &&[i, res] : llvm::enumerate(First&: opResults)) {
100 results.set(value: cast<OpResult>(Val: getTransformed()[i]), ops: {res});
101 }
102 return DiagnosedSilenceableFailure::success();
103}
104
105//===----------------------------------------------------------------------===//
// ForallToParallelOp
107//===----------------------------------------------------------------------===//
108
109DiagnosedSilenceableFailure
110transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter,
111 transform::TransformResults &results,
112 transform::TransformState &state) {
113 auto payload = state.getPayloadOps(value: getTarget());
114 if (!llvm::hasSingleElement(C&: payload))
115 return emitSilenceableError() << "expected a single payload op";
116
117 auto target = dyn_cast<scf::ForallOp>(Val: *payload.begin());
118 if (!target) {
119 DiagnosedSilenceableFailure diag =
120 emitSilenceableError() << "expected the payload to be scf.forall";
121 diag.attachNote(loc: (*payload.begin())->getLoc()) << "payload op";
122 return diag;
123 }
124
125 if (!target.getOutputs().empty()) {
126 return emitSilenceableError()
127 << "unsupported shared outputs (didn't bufferize?)";
128 }
129
130 if (getNumResults() != 1) {
131 DiagnosedSilenceableFailure diag = emitSilenceableError()
132 << "op expects one result, given "
133 << getNumResults();
134 diag.attachNote(loc: target.getLoc()) << "payload op";
135 return diag;
136 }
137
138 scf::ParallelOp opResult;
139 if (failed(Result: scf::forallToParallelLoop(rewriter, forallOp: target, result: &opResult))) {
140 DiagnosedSilenceableFailure diag =
141 emitSilenceableError() << "failed to convert forall into parallel";
142 return diag;
143 }
144
145 results.set(value: cast<OpResult>(Val: getTransformed()[0]), ops: {opResult});
146 return DiagnosedSilenceableFailure::success();
147}
148
149//===----------------------------------------------------------------------===//
150// LoopOutlineOp
151//===----------------------------------------------------------------------===//
152
153/// Wraps the given operation `op` into an `scf.execute_region` operation. Uses
154/// the provided rewriter for all operations to remain compatible with the
155/// rewriting infra, as opposed to just splicing the op in place.
156static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
157 Operation *op) {
158 if (op->getNumRegions() != 1)
159 return nullptr;
160 OpBuilder::InsertionGuard g(b);
161 b.setInsertionPoint(op);
162 scf::ExecuteRegionOp executeRegionOp =
163 b.create<scf::ExecuteRegionOp>(location: op->getLoc(), args: op->getResultTypes());
164 {
165 OpBuilder::InsertionGuard g(b);
166 b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock());
167 Operation *clonedOp = b.cloneWithoutRegions(op&: *op);
168 Region &clonedRegion = clonedOp->getRegions().front();
169 assert(clonedRegion.empty() && "expected empty region");
170 b.inlineRegionBefore(region&: op->getRegions().front(), parent&: clonedRegion,
171 before: clonedRegion.end());
172 b.create<scf::YieldOp>(location: op->getLoc(), args: clonedOp->getResults());
173 }
174 b.replaceOp(op, newValues: executeRegionOp.getResults());
175 return executeRegionOp;
176}
177
178DiagnosedSilenceableFailure
179transform::LoopOutlineOp::apply(transform::TransformRewriter &rewriter,
180 transform::TransformResults &results,
181 transform::TransformState &state) {
182 SmallVector<Operation *> functions;
183 SmallVector<Operation *> calls;
184 DenseMap<Operation *, SymbolTable> symbolTables;
185 for (Operation *target : state.getPayloadOps(value: getTarget())) {
186 Location location = target->getLoc();
187 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from: target);
188 scf::ExecuteRegionOp exec = wrapInExecuteRegion(b&: rewriter, op: target);
189 if (!exec) {
190 DiagnosedSilenceableFailure diag = emitSilenceableError()
191 << "failed to outline";
192 diag.attachNote(loc: target->getLoc()) << "target op";
193 return diag;
194 }
195 func::CallOp call;
196 FailureOr<func::FuncOp> outlined = outlineSingleBlockRegion(
197 rewriter, loc: location, region&: exec.getRegion(), funcName: getFuncName(), callOp: &call);
198
199 if (failed(Result: outlined))
200 return emitDefaultDefiniteFailure(target);
201
202 if (symbolTableOp) {
203 SymbolTable &symbolTable =
204 symbolTables.try_emplace(Key: symbolTableOp, Args&: symbolTableOp)
205 .first->getSecond();
206 symbolTable.insert(symbol: *outlined);
207 call.setCalleeAttr(FlatSymbolRefAttr::get(symbol: *outlined));
208 }
209 functions.push_back(Elt: *outlined);
210 calls.push_back(Elt: call);
211 }
212 results.set(value: cast<OpResult>(Val: getFunction()), ops&: functions);
213 results.set(value: cast<OpResult>(Val: getCall()), ops&: calls);
214 return DiagnosedSilenceableFailure::success();
215}
216
217//===----------------------------------------------------------------------===//
218// LoopPeelOp
219//===----------------------------------------------------------------------===//
220
221DiagnosedSilenceableFailure
222transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter,
223 scf::ForOp target,
224 transform::ApplyToEachResultList &results,
225 transform::TransformState &state) {
226 scf::ForOp result;
227 if (getPeelFront()) {
228 LogicalResult status =
229 scf::peelForLoopFirstIteration(rewriter, forOp: target, partialIteration&: result);
230 if (failed(Result: status)) {
231 DiagnosedSilenceableFailure diag =
232 emitSilenceableError() << "failed to peel the first iteration";
233 return diag;
234 }
235 } else {
236 LogicalResult status =
237 scf::peelForLoopAndSimplifyBounds(rewriter, forOp: target, partialIteration&: result);
238 if (failed(Result: status)) {
239 DiagnosedSilenceableFailure diag = emitSilenceableError()
240 << "failed to peel the last iteration";
241 return diag;
242 }
243 }
244
245 results.push_back(op: target);
246 results.push_back(op: result);
247
248 return DiagnosedSilenceableFailure::success();
249}
250
251//===----------------------------------------------------------------------===//
252// LoopPipelineOp
253//===----------------------------------------------------------------------===//
254
255/// Callback for PipeliningOption. Populates `schedule` with the mapping from an
256/// operation to its logical time position given the iteration interval and the
257/// read latency. The latter is only relevant for vector transfers.
258static void
259loopScheduling(scf::ForOp forOp,
260 std::vector<std::pair<Operation *, unsigned>> &schedule,
261 unsigned iterationInterval, unsigned readLatency) {
262 auto getLatency = [&](Operation *op) -> unsigned {
263 if (isa<vector::TransferReadOp>(Val: op))
264 return readLatency;
265 return 1;
266 };
267
268 std::optional<int64_t> ubConstant =
269 getConstantIntValue(ofr: forOp.getUpperBound());
270 std::optional<int64_t> lbConstant =
271 getConstantIntValue(ofr: forOp.getLowerBound());
272 DenseMap<Operation *, unsigned> opCycles;
273 std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
274 for (Operation &op : forOp.getBody()->getOperations()) {
275 if (isa<scf::YieldOp>(Val: op))
276 continue;
277 unsigned earlyCycle = 0;
278 for (Value operand : op.getOperands()) {
279 Operation *def = operand.getDefiningOp();
280 if (!def)
281 continue;
282 if (ubConstant && lbConstant) {
283 unsigned ubInt = ubConstant.value();
284 unsigned lbInt = lbConstant.value();
285 auto minLatency = std::min(a: ubInt - lbInt - 1, b: getLatency(def));
286 earlyCycle = std::max(a: earlyCycle, b: opCycles[def] + minLatency);
287 } else {
288 earlyCycle = std::max(a: earlyCycle, b: opCycles[def] + getLatency(def));
289 }
290 }
291 opCycles[&op] = earlyCycle;
292 wrappedSchedule[earlyCycle % iterationInterval].push_back(x: &op);
293 }
294 for (const auto &it : wrappedSchedule) {
295 for (Operation *op : it.second) {
296 unsigned cycle = opCycles[op];
297 schedule.emplace_back(args&: op, args: cycle / iterationInterval);
298 }
299 }
300}
301
302DiagnosedSilenceableFailure
303transform::LoopPipelineOp::applyToOne(transform::TransformRewriter &rewriter,
304 scf::ForOp target,
305 transform::ApplyToEachResultList &results,
306 transform::TransformState &state) {
307 scf::PipeliningOption options;
308 options.getScheduleFn =
309 [this](scf::ForOp forOp,
310 std::vector<std::pair<Operation *, unsigned>> &schedule) mutable {
311 loopScheduling(forOp, schedule, iterationInterval: getIterationInterval(),
312 readLatency: getReadLatency());
313 };
314 scf::ForLoopPipeliningPattern pattern(options, target->getContext());
315 rewriter.setInsertionPoint(target);
316 FailureOr<scf::ForOp> patternResult =
317 scf::pipelineForLoop(rewriter, forOp: target, options);
318 if (succeeded(Result: patternResult)) {
319 results.push_back(op: *patternResult);
320 return DiagnosedSilenceableFailure::success();
321 }
322 return emitDefaultSilenceableFailure(target);
323}
324
325//===----------------------------------------------------------------------===//
326// LoopPromoteIfOneIterationOp
327//===----------------------------------------------------------------------===//
328
329DiagnosedSilenceableFailure transform::LoopPromoteIfOneIterationOp::applyToOne(
330 transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
331 transform::ApplyToEachResultList &results,
332 transform::TransformState &state) {
333 (void)target.promoteIfSingleIteration(rewriter);
334 return DiagnosedSilenceableFailure::success();
335}
336
337void transform::LoopPromoteIfOneIterationOp::getEffects(
338 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
339 consumesHandle(handles: getTargetMutable(), effects);
340 modifiesPayload(effects);
341}
342
343//===----------------------------------------------------------------------===//
344// LoopUnrollOp
345//===----------------------------------------------------------------------===//
346
347DiagnosedSilenceableFailure
348transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
349 Operation *op,
350 transform::ApplyToEachResultList &results,
351 transform::TransformState &state) {
352 LogicalResult result(failure());
353 if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(Val: op))
354 result = loopUnrollByFactor(forOp: scfFor, unrollFactor: getFactor());
355 else if (AffineForOp affineFor = dyn_cast<AffineForOp>(Val: op))
356 result = loopUnrollByFactor(forOp: affineFor, unrollFactor: getFactor());
357 else
358 return emitSilenceableError()
359 << "failed to unroll, incorrect type of payload";
360
361 if (failed(Result: result))
362 return emitSilenceableError() << "failed to unroll";
363
364 return DiagnosedSilenceableFailure::success();
365}
366
367//===----------------------------------------------------------------------===//
368// LoopUnrollAndJamOp
369//===----------------------------------------------------------------------===//
370
371DiagnosedSilenceableFailure transform::LoopUnrollAndJamOp::applyToOne(
372 transform::TransformRewriter &rewriter, Operation *op,
373 transform::ApplyToEachResultList &results,
374 transform::TransformState &state) {
375 LogicalResult result(failure());
376 if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(Val: op))
377 result = loopUnrollJamByFactor(forOp: scfFor, unrollFactor: getFactor());
378 else if (AffineForOp affineFor = dyn_cast<AffineForOp>(Val: op))
379 result = loopUnrollJamByFactor(forOp: affineFor, unrollJamFactor: getFactor());
380 else
381 return emitSilenceableError()
382 << "failed to unroll and jam, incorrect type of payload";
383
384 if (failed(Result: result))
385 return emitSilenceableError() << "failed to unroll and jam";
386
387 return DiagnosedSilenceableFailure::success();
388}
389
390//===----------------------------------------------------------------------===//
391// LoopCoalesceOp
392//===----------------------------------------------------------------------===//
393
394DiagnosedSilenceableFailure
395transform::LoopCoalesceOp::applyToOne(transform::TransformRewriter &rewriter,
396 Operation *op,
397 transform::ApplyToEachResultList &results,
398 transform::TransformState &state) {
399 LogicalResult result(failure());
400 if (scf::ForOp scfForOp = dyn_cast<scf::ForOp>(Val: op))
401 result = coalescePerfectlyNestedSCFForLoops(op: scfForOp);
402 else if (AffineForOp affineForOp = dyn_cast<AffineForOp>(Val: op))
403 result = coalescePerfectlyNestedAffineLoops(op: affineForOp);
404
405 results.push_back(op);
406 if (failed(Result: result)) {
407 DiagnosedSilenceableFailure diag = emitSilenceableError()
408 << "failed to coalesce";
409 return diag;
410 }
411 return DiagnosedSilenceableFailure::success();
412}
413
414//===----------------------------------------------------------------------===//
415// TakeAssumedBranchOp
416//===----------------------------------------------------------------------===//
417/// Replaces the given op with the contents of the given single-block region,
418/// using the operands of the block terminator to replace operation results.
419static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op,
420 Region &region) {
421 assert(llvm::hasSingleElement(region) && "expected single-region block");
422 Block *block = &region.front();
423 Operation *terminator = block->getTerminator();
424 ValueRange results = terminator->getOperands();
425 rewriter.inlineBlockBefore(source: block, op, /*blockArgs=*/argValues: {});
426 rewriter.replaceOp(op, newValues: results);
427 rewriter.eraseOp(op: terminator);
428}
429
430DiagnosedSilenceableFailure transform::TakeAssumedBranchOp::applyToOne(
431 transform::TransformRewriter &rewriter, scf::IfOp ifOp,
432 transform::ApplyToEachResultList &results,
433 transform::TransformState &state) {
434 rewriter.setInsertionPoint(ifOp);
435 Region &region =
436 getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion();
437 if (!llvm::hasSingleElement(C&: region)) {
438 return emitDefiniteFailure()
439 << "requires an scf.if op with a single-block "
440 << ((getTakeElseBranch()) ? "`else`" : "`then`") << " region";
441 }
442 replaceOpWithRegion(rewriter, op: ifOp, region);
443 return DiagnosedSilenceableFailure::success();
444}
445
446void transform::TakeAssumedBranchOp::getEffects(
447 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
448 onlyReadsHandle(handles: getTargetMutable(), effects);
449 modifiesPayload(effects);
450}
451
452//===----------------------------------------------------------------------===//
453// LoopFuseSiblingOp
454//===----------------------------------------------------------------------===//
455
456/// Check if `target` and `source` are siblings, in the context that `target`
457/// is being fused into `source`.
458///
459/// This is a simple check that just checks if both operations are in the same
460/// block and some checks to ensure that the fused IR does not violate
461/// dominance.
462static DiagnosedSilenceableFailure isOpSibling(Operation *target,
463 Operation *source) {
464 // Check if both operations are same.
465 if (target == source)
466 return emitSilenceableFailure(op: source)
467 << "target and source need to be different loops";
468
469 // Check if both operations are in the same block.
470 if (target->getBlock() != source->getBlock())
471 return emitSilenceableFailure(op: source)
472 << "target and source are not in the same block";
473
474 // Check if fusion will violate dominance.
475 DominanceInfo domInfo(source);
476 if (target->isBeforeInBlock(other: source)) {
477 // Since `target` is before `source`, all users of results of `target`
478 // need to be dominated by `source`.
479 for (Operation *user : target->getUsers()) {
480 if (!domInfo.properlyDominates(a: source, b: user, /*enclosingOpOk=*/false)) {
481 return emitSilenceableFailure(op: target)
482 << "user of results of target should be properly dominated by "
483 "source";
484 }
485 }
486 } else {
487 // Since `target` is after `source`, all values used by `target` need
488 // to dominate `source`.
489
490 // Check if operands of `target` are dominated by `source`.
491 for (Value operand : target->getOperands()) {
492 Operation *operandOp = operand.getDefiningOp();
493 // Operands without defining operations are block arguments. When `target`
494 // and `source` occur in the same block, these operands dominate `source`.
495 if (!operandOp)
496 continue;
497
498 // Operand's defining operation should properly dominate `source`.
499 if (!domInfo.properlyDominates(a: operandOp, b: source,
500 /*enclosingOpOk=*/false))
501 return emitSilenceableFailure(op: target)
502 << "operands of target should be properly dominated by source";
503 }
504
505 // Check if values used by `target` are dominated by `source`.
506 bool failed = false;
507 OpOperand *failedValue = nullptr;
508 visitUsedValuesDefinedAbove(regions: target->getRegions(), callback: [&](OpOperand *operand) {
509 Operation *operandOp = operand->get().getDefiningOp();
510 if (operandOp && !domInfo.properlyDominates(a: operandOp, b: source,
511 /*enclosingOpOk=*/false)) {
512 // `operand` is not an argument of an enclosing block and the defining
513 // op of `operand` is outside `target` but does not dominate `source`.
514 failed = true;
515 failedValue = operand;
516 }
517 });
518
519 if (failed)
520 return emitSilenceableFailure(op: failedValue->getOwner())
521 << "values used inside regions of target should be properly "
522 "dominated by source";
523 }
524
525 return DiagnosedSilenceableFailure::success();
526}
527
528/// Check if `target` scf.forall can be fused into `source` scf.forall.
529///
530/// This simply checks if both loops have the same bounds, steps and mapping.
531/// No attempt is made at checking that the side effects of `target` and
532/// `source` are independent of each other.
533static bool isForallWithIdenticalConfiguration(Operation *target,
534 Operation *source) {
535 auto targetOp = dyn_cast<scf::ForallOp>(Val: target);
536 auto sourceOp = dyn_cast<scf::ForallOp>(Val: source);
537 if (!targetOp || !sourceOp)
538 return false;
539
540 return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
541 targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
542 targetOp.getMixedStep() == sourceOp.getMixedStep() &&
543 targetOp.getMapping() == sourceOp.getMapping();
544}
545
546/// Check if `target` scf.for can be fused into `source` scf.for.
547///
548/// This simply checks if both loops have the same bounds and steps. No attempt
549/// is made at checking that the side effects of `target` and `source` are
550/// independent of each other.
551static bool isForWithIdenticalConfiguration(Operation *target,
552 Operation *source) {
553 auto targetOp = dyn_cast<scf::ForOp>(Val: target);
554 auto sourceOp = dyn_cast<scf::ForOp>(Val: source);
555 if (!targetOp || !sourceOp)
556 return false;
557
558 return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
559 targetOp.getUpperBound() == sourceOp.getUpperBound() &&
560 targetOp.getStep() == sourceOp.getStep();
561}
562
563DiagnosedSilenceableFailure
564transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
565 transform::TransformResults &results,
566 transform::TransformState &state) {
567 auto targetOps = state.getPayloadOps(value: getTarget());
568 auto sourceOps = state.getPayloadOps(value: getSource());
569
570 if (!llvm::hasSingleElement(C&: targetOps) ||
571 !llvm::hasSingleElement(C&: sourceOps)) {
572 return emitDefiniteFailure()
573 << "requires exactly one target handle (got "
574 << llvm::range_size(Range&: targetOps) << ") and exactly one "
575 << "source handle (got " << llvm::range_size(Range&: sourceOps) << ")";
576 }
577
578 Operation *target = *targetOps.begin();
579 Operation *source = *sourceOps.begin();
580
581 // Check if the target and source are siblings.
582 DiagnosedSilenceableFailure diag = isOpSibling(target, source);
583 if (!diag.succeeded())
584 return diag;
585
586 Operation *fusedLoop;
587 /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
588 if (isForWithIdenticalConfiguration(target, source)) {
589 fusedLoop = fuseIndependentSiblingForLoops(
590 target: cast<scf::ForOp>(Val: target), source: cast<scf::ForOp>(Val: source), rewriter);
591 } else if (isForallWithIdenticalConfiguration(target, source)) {
592 fusedLoop = fuseIndependentSiblingForallLoops(
593 target: cast<scf::ForallOp>(Val: target), source: cast<scf::ForallOp>(Val: source), rewriter);
594 } else {
595 return emitSilenceableFailure(loc: target->getLoc())
596 << "operations cannot be fused";
597 }
598
599 assert(fusedLoop && "failed to fuse operations");
600
601 results.set(value: cast<OpResult>(Val: getFusedLoop()), ops: {fusedLoop});
602 return DiagnosedSilenceableFailure::success();
603}
604
605//===----------------------------------------------------------------------===//
606// Transform op registration
607//===----------------------------------------------------------------------===//
608
609namespace {
610class SCFTransformDialectExtension
611 : public transform::TransformDialectExtension<
612 SCFTransformDialectExtension> {
613public:
614 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SCFTransformDialectExtension)
615
616 using Base::Base;
617
618 void init() {
619 declareGeneratedDialect<affine::AffineDialect>();
620 declareGeneratedDialect<func::FuncDialect>();
621
622 registerTransformOps<
623#define GET_OP_LIST
624#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
625 >();
626 }
627};
628} // namespace
629
630#define GET_OP_CLASSES
631#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
632
633void mlir::scf::registerTransformDialectExtension(DialectRegistry &registry) {
634 registry.addExtensions<SCFTransformDialectExtension>();
635}
636

// Source code of mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp