1//===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
10#include "mlir/Dialect/Affine/IR/AffineOps.h"
11#include "mlir/Dialect/Affine/LoopUtils.h"
12#include "mlir/Dialect/Arith/IR/Arith.h"
13#include "mlir/Dialect/Arith/Utils/Utils.h"
14#include "mlir/Dialect/Func/IR/FuncOps.h"
15#include "mlir/Dialect/SCF/IR/SCF.h"
16#include "mlir/Dialect/SCF/Transforms/Patterns.h"
17#include "mlir/Dialect/SCF/Transforms/Transforms.h"
18#include "mlir/Dialect/SCF/Utils/Utils.h"
19#include "mlir/Dialect/Transform/IR/TransformDialect.h"
20#include "mlir/Dialect/Transform/IR/TransformOps.h"
21#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
22#include "mlir/Dialect/Utils/StaticValueUtils.h"
23#include "mlir/Dialect/Vector/IR/VectorOps.h"
24#include "mlir/IR/BuiltinAttributes.h"
25#include "mlir/IR/Dominance.h"
26#include "mlir/IR/OpDefinition.h"
27
28using namespace mlir;
29using namespace mlir::affine;
30
31//===----------------------------------------------------------------------===//
32// Apply...PatternsOp
33//===----------------------------------------------------------------------===//
34
/// Populates `patterns` by forwarding to
/// `scf::populateSCFForLoopCanonicalizationPatterns`.
void transform::ApplyForLoopCanonicalizationPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  scf::populateSCFForLoopCanonicalizationPatterns(patterns);
}
39
/// Populates `patterns` by forwarding `typeConverter` and `patterns` to
/// `scf::populateSCFStructuralTypeConversions`.
void transform::ApplySCFStructuralConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  scf::populateSCFStructuralTypeConversions(typeConverter, patterns);
}
44
/// Configures `conversionTarget` by forwarding `typeConverter` and
/// `conversionTarget` to `scf::populateSCFStructuralTypeConversionTarget`.
void transform::ApplySCFStructuralConversionPatternsOp::
    populateConversionTargetRules(const TypeConverter &typeConverter,
                                  ConversionTarget &conversionTarget) {
  scf::populateSCFStructuralTypeConversionTarget(typeConverter,
                                                 conversionTarget);
}
51
52//===----------------------------------------------------------------------===//
53// ForallToForOp
54//===----------------------------------------------------------------------===//
55
56DiagnosedSilenceableFailure
57transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
58 transform::TransformResults &results,
59 transform::TransformState &state) {
60 auto payload = state.getPayloadOps(getTarget());
61 if (!llvm::hasSingleElement(payload))
62 return emitSilenceableError() << "expected a single payload op";
63
64 auto target = dyn_cast<scf::ForallOp>(*payload.begin());
65 if (!target) {
66 DiagnosedSilenceableFailure diag =
67 emitSilenceableError() << "expected the payload to be scf.forall";
68 diag.attachNote((*payload.begin())->getLoc()) << "payload op";
69 return diag;
70 }
71
72 if (!target.getOutputs().empty()) {
73 return emitSilenceableError()
74 << "unsupported shared outputs (didn't bufferize?)";
75 }
76
77 SmallVector<OpFoldResult> lbs = target.getMixedLowerBound();
78
79 if (getNumResults() != lbs.size()) {
80 DiagnosedSilenceableFailure diag =
81 emitSilenceableError()
82 << "op expects as many results (" << getNumResults()
83 << ") as payload has induction variables (" << lbs.size() << ")";
84 diag.attachNote(target.getLoc()) << "payload op";
85 return diag;
86 }
87
88 SmallVector<Operation *> opResults;
89 if (failed(scf::forallToForLoop(rewriter, target, &opResults))) {
90 DiagnosedSilenceableFailure diag = emitSilenceableError()
91 << "failed to convert forall into for";
92 return diag;
93 }
94
95 for (auto &&[i, res] : llvm::enumerate(opResults)) {
96 results.set(cast<OpResult>(getTransformed()[i]), {res});
97 }
98 return DiagnosedSilenceableFailure::success();
99}
100
101//===----------------------------------------------------------------------===//
102// LoopOutlineOp
103//===----------------------------------------------------------------------===//
104
105/// Wraps the given operation `op` into an `scf.execute_region` operation. Uses
106/// the provided rewriter for all operations to remain compatible with the
107/// rewriting infra, as opposed to just splicing the op in place.
108static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
109 Operation *op) {
110 if (op->getNumRegions() != 1)
111 return nullptr;
112 OpBuilder::InsertionGuard g(b);
113 b.setInsertionPoint(op);
114 scf::ExecuteRegionOp executeRegionOp =
115 b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes());
116 {
117 OpBuilder::InsertionGuard g(b);
118 b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock());
119 Operation *clonedOp = b.cloneWithoutRegions(op&: *op);
120 Region &clonedRegion = clonedOp->getRegions().front();
121 assert(clonedRegion.empty() && "expected empty region");
122 b.inlineRegionBefore(region&: op->getRegions().front(), parent&: clonedRegion,
123 before: clonedRegion.end());
124 b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults());
125 }
126 b.replaceOp(op, executeRegionOp.getResults());
127 return executeRegionOp;
128}
129
130DiagnosedSilenceableFailure
131transform::LoopOutlineOp::apply(transform::TransformRewriter &rewriter,
132 transform::TransformResults &results,
133 transform::TransformState &state) {
134 SmallVector<Operation *> functions;
135 SmallVector<Operation *> calls;
136 DenseMap<Operation *, SymbolTable> symbolTables;
137 for (Operation *target : state.getPayloadOps(getTarget())) {
138 Location location = target->getLoc();
139 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target);
140 scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target);
141 if (!exec) {
142 DiagnosedSilenceableFailure diag = emitSilenceableError()
143 << "failed to outline";
144 diag.attachNote(target->getLoc()) << "target op";
145 return diag;
146 }
147 func::CallOp call;
148 FailureOr<func::FuncOp> outlined = outlineSingleBlockRegion(
149 rewriter, location, exec.getRegion(), getFuncName(), &call);
150
151 if (failed(outlined))
152 return emitDefaultDefiniteFailure(target);
153
154 if (symbolTableOp) {
155 SymbolTable &symbolTable =
156 symbolTables.try_emplace(symbolTableOp, symbolTableOp)
157 .first->getSecond();
158 symbolTable.insert(*outlined);
159 call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined));
160 }
161 functions.push_back(*outlined);
162 calls.push_back(call);
163 }
164 results.set(cast<OpResult>(getFunction()), functions);
165 results.set(cast<OpResult>(getCall()), calls);
166 return DiagnosedSilenceableFailure::success();
167}
168
169//===----------------------------------------------------------------------===//
170// LoopPeelOp
171//===----------------------------------------------------------------------===//
172
173DiagnosedSilenceableFailure
174transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter,
175 scf::ForOp target,
176 transform::ApplyToEachResultList &results,
177 transform::TransformState &state) {
178 scf::ForOp result;
179 if (getPeelFront()) {
180 LogicalResult status =
181 scf::peelForLoopFirstIteration(rewriter, target, result);
182 if (failed(status)) {
183 DiagnosedSilenceableFailure diag =
184 emitSilenceableError() << "failed to peel the first iteration";
185 return diag;
186 }
187 } else {
188 LogicalResult status =
189 scf::peelForLoopAndSimplifyBounds(rewriter, target, result);
190 if (failed(status)) {
191 DiagnosedSilenceableFailure diag = emitSilenceableError()
192 << "failed to peel the last iteration";
193 return diag;
194 }
195 }
196
197 results.push_back(target);
198 results.push_back(result);
199
200 return DiagnosedSilenceableFailure::success();
201}
202
203//===----------------------------------------------------------------------===//
204// LoopPipelineOp
205//===----------------------------------------------------------------------===//
206
207/// Callback for PipeliningOption. Populates `schedule` with the mapping from an
208/// operation to its logical time position given the iteration interval and the
209/// read latency. The latter is only relevant for vector transfers.
210static void
211loopScheduling(scf::ForOp forOp,
212 std::vector<std::pair<Operation *, unsigned>> &schedule,
213 unsigned iterationInterval, unsigned readLatency) {
214 auto getLatency = [&](Operation *op) -> unsigned {
215 if (isa<vector::TransferReadOp>(Val: op))
216 return readLatency;
217 return 1;
218 };
219
220 DenseMap<Operation *, unsigned> opCycles;
221 std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
222 for (Operation &op : forOp.getBody()->getOperations()) {
223 if (isa<scf::YieldOp>(op))
224 continue;
225 unsigned earlyCycle = 0;
226 for (Value operand : op.getOperands()) {
227 Operation *def = operand.getDefiningOp();
228 if (!def)
229 continue;
230 earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def));
231 }
232 opCycles[&op] = earlyCycle;
233 wrappedSchedule[earlyCycle % iterationInterval].push_back(&op);
234 }
235 for (const auto &it : wrappedSchedule) {
236 for (Operation *op : it.second) {
237 unsigned cycle = opCycles[op];
238 schedule.emplace_back(args&: op, args: cycle / iterationInterval);
239 }
240 }
241}
242
243DiagnosedSilenceableFailure
244transform::LoopPipelineOp::applyToOne(transform::TransformRewriter &rewriter,
245 scf::ForOp target,
246 transform::ApplyToEachResultList &results,
247 transform::TransformState &state) {
248 scf::PipeliningOption options;
249 options.getScheduleFn =
250 [this](scf::ForOp forOp,
251 std::vector<std::pair<Operation *, unsigned>> &schedule) mutable {
252 loopScheduling(forOp, schedule, getIterationInterval(),
253 getReadLatency());
254 };
255 scf::ForLoopPipeliningPattern pattern(options, target->getContext());
256 rewriter.setInsertionPoint(target);
257 FailureOr<scf::ForOp> patternResult =
258 scf::pipelineForLoop(rewriter, target, options);
259 if (succeeded(patternResult)) {
260 results.push_back(*patternResult);
261 return DiagnosedSilenceableFailure::success();
262 }
263 return emitDefaultSilenceableFailure(target);
264}
265
266//===----------------------------------------------------------------------===//
267// LoopPromoteIfOneIterationOp
268//===----------------------------------------------------------------------===//
269
270DiagnosedSilenceableFailure transform::LoopPromoteIfOneIterationOp::applyToOne(
271 transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
272 transform::ApplyToEachResultList &results,
273 transform::TransformState &state) {
274 (void)target.promoteIfSingleIteration(rewriter);
275 return DiagnosedSilenceableFailure::success();
276}
277
/// Declares the effects of this op: the target handle is consumed and the
/// payload IR is modified.
void transform::LoopPromoteIfOneIterationOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getTarget(), effects);
  modifiesPayload(effects);
}
283
284//===----------------------------------------------------------------------===//
285// LoopUnrollOp
286//===----------------------------------------------------------------------===//
287
288DiagnosedSilenceableFailure
289transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
290 Operation *op,
291 transform::ApplyToEachResultList &results,
292 transform::TransformState &state) {
293 LogicalResult result(failure());
294 if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
295 result = loopUnrollByFactor(scfFor, getFactor());
296 else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
297 result = loopUnrollByFactor(affineFor, getFactor());
298
299 if (failed(result)) {
300 DiagnosedSilenceableFailure diag = emitSilenceableError()
301 << "failed to unroll";
302 return diag;
303 }
304 return DiagnosedSilenceableFailure::success();
305}
306
307//===----------------------------------------------------------------------===//
308// LoopCoalesceOp
309//===----------------------------------------------------------------------===//
310
311DiagnosedSilenceableFailure
312transform::LoopCoalesceOp::applyToOne(transform::TransformRewriter &rewriter,
313 Operation *op,
314 transform::ApplyToEachResultList &results,
315 transform::TransformState &state) {
316 LogicalResult result(failure());
317 if (scf::ForOp scfForOp = dyn_cast<scf::ForOp>(op))
318 result = coalescePerfectlyNestedSCFForLoops(scfForOp);
319 else if (AffineForOp affineForOp = dyn_cast<AffineForOp>(op))
320 result = coalescePerfectlyNestedAffineLoops(affineForOp);
321
322 results.push_back(op);
323 if (failed(result)) {
324 DiagnosedSilenceableFailure diag = emitSilenceableError()
325 << "failed to coalesce";
326 return diag;
327 }
328 return DiagnosedSilenceableFailure::success();
329}
330
331//===----------------------------------------------------------------------===//
332// TakeAssumedBranchOp
333//===----------------------------------------------------------------------===//
334/// Replaces the given op with the contents of the given single-block region,
335/// using the operands of the block terminator to replace operation results.
336static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op,
337 Region &region) {
338 assert(llvm::hasSingleElement(region) && "expected single-region block");
339 Block *block = &region.front();
340 Operation *terminator = block->getTerminator();
341 ValueRange results = terminator->getOperands();
342 rewriter.inlineBlockBefore(source: block, op, /*blockArgs=*/argValues: {});
343 rewriter.replaceOp(op, newValues: results);
344 rewriter.eraseOp(op: terminator);
345}
346
347DiagnosedSilenceableFailure transform::TakeAssumedBranchOp::applyToOne(
348 transform::TransformRewriter &rewriter, scf::IfOp ifOp,
349 transform::ApplyToEachResultList &results,
350 transform::TransformState &state) {
351 rewriter.setInsertionPoint(ifOp);
352 Region &region =
353 getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion();
354 if (!llvm::hasSingleElement(region)) {
355 return emitDefiniteFailure()
356 << "requires an scf.if op with a single-block "
357 << ((getTakeElseBranch()) ? "`else`" : "`then`") << " region";
358 }
359 replaceOpWithRegion(rewriter, ifOp, region);
360 return DiagnosedSilenceableFailure::success();
361}
362
/// Declares the effects of this op: the target handle is only read and the
/// payload IR is modified.
void transform::TakeAssumedBranchOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getTarget(), effects);
  modifiesPayload(effects);
}
368
369//===----------------------------------------------------------------------===//
370// LoopFuseSiblingOp
371//===----------------------------------------------------------------------===//
372
373/// Check if `target` and `source` are siblings, in the context that `target`
374/// is being fused into `source`.
375///
376/// This is a simple check that just checks if both operations are in the same
377/// block and some checks to ensure that the fused IR does not violate
378/// dominance.
379static DiagnosedSilenceableFailure isOpSibling(Operation *target,
380 Operation *source) {
381 // Check if both operations are same.
382 if (target == source)
383 return emitSilenceableFailure(op: source)
384 << "target and source need to be different loops";
385
386 // Check if both operations are in the same block.
387 if (target->getBlock() != source->getBlock())
388 return emitSilenceableFailure(op: source)
389 << "target and source are not in the same block";
390
391 // Check if fusion will violate dominance.
392 DominanceInfo domInfo(source);
393 if (target->isBeforeInBlock(other: source)) {
394 // Since `target` is before `source`, all users of results of `target`
395 // need to be dominated by `source`.
396 for (Operation *user : target->getUsers()) {
397 if (!domInfo.properlyDominates(a: source, b: user, /*enclosingOpOk=*/false)) {
398 return emitSilenceableFailure(op: target)
399 << "user of results of target should be properly dominated by "
400 "source";
401 }
402 }
403 } else {
404 // Since `target` is after `source`, all values used by `target` need
405 // to dominate `source`.
406
407 // Check if operands of `target` are dominated by `source`.
408 for (Value operand : target->getOperands()) {
409 Operation *operandOp = operand.getDefiningOp();
410 // Operands without defining operations are block arguments. When `target`
411 // and `source` occur in the same block, these operands dominate `source`.
412 if (!operandOp)
413 continue;
414
415 // Operand's defining operation should properly dominate `source`.
416 if (!domInfo.properlyDominates(a: operandOp, b: source,
417 /*enclosingOpOk=*/false))
418 return emitSilenceableFailure(op: target)
419 << "operands of target should be properly dominated by source";
420 }
421
422 // Check if values used by `target` are dominated by `source`.
423 bool failed = false;
424 OpOperand *failedValue = nullptr;
425 visitUsedValuesDefinedAbove(regions: target->getRegions(), callback: [&](OpOperand *operand) {
426 Operation *operandOp = operand->get().getDefiningOp();
427 if (operandOp && !domInfo.properlyDominates(a: operandOp, b: source,
428 /*enclosingOpOk=*/false)) {
429 // `operand` is not an argument of an enclosing block and the defining
430 // op of `operand` is outside `target` but does not dominate `source`.
431 failed = true;
432 failedValue = operand;
433 }
434 });
435
436 if (failed)
437 return emitSilenceableFailure(op: failedValue->getOwner())
438 << "values used inside regions of target should be properly "
439 "dominated by source";
440 }
441
442 return DiagnosedSilenceableFailure::success();
443}
444
445/// Check if `target` scf.forall can be fused into `source` scf.forall.
446///
447/// This simply checks if both loops have the same bounds, steps and mapping.
448/// No attempt is made at checking that the side effects of `target` and
449/// `source` are independent of each other.
450static bool isForallWithIdenticalConfiguration(Operation *target,
451 Operation *source) {
452 auto targetOp = dyn_cast<scf::ForallOp>(target);
453 auto sourceOp = dyn_cast<scf::ForallOp>(source);
454 if (!targetOp || !sourceOp)
455 return false;
456
457 return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
458 targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
459 targetOp.getMixedStep() == sourceOp.getMixedStep() &&
460 targetOp.getMapping() == sourceOp.getMapping();
461}
462
463/// Check if `target` scf.for can be fused into `source` scf.for.
464///
465/// This simply checks if both loops have the same bounds and steps. No attempt
466/// is made at checking that the side effects of `target` and `source` are
467/// independent of each other.
468static bool isForWithIdenticalConfiguration(Operation *target,
469 Operation *source) {
470 auto targetOp = dyn_cast<scf::ForOp>(target);
471 auto sourceOp = dyn_cast<scf::ForOp>(source);
472 if (!targetOp || !sourceOp)
473 return false;
474
475 return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
476 targetOp.getUpperBound() == sourceOp.getUpperBound() &&
477 targetOp.getStep() == sourceOp.getStep();
478}
479
480DiagnosedSilenceableFailure
481transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
482 transform::TransformResults &results,
483 transform::TransformState &state) {
484 auto targetOps = state.getPayloadOps(getTarget());
485 auto sourceOps = state.getPayloadOps(getSource());
486
487 if (!llvm::hasSingleElement(targetOps) ||
488 !llvm::hasSingleElement(sourceOps)) {
489 return emitDefiniteFailure()
490 << "requires exactly one target handle (got "
491 << llvm::range_size(targetOps) << ") and exactly one "
492 << "source handle (got " << llvm::range_size(sourceOps) << ")";
493 }
494
495 Operation *target = *targetOps.begin();
496 Operation *source = *sourceOps.begin();
497
498 // Check if the target and source are siblings.
499 DiagnosedSilenceableFailure diag = isOpSibling(target, source);
500 if (!diag.succeeded())
501 return diag;
502
503 Operation *fusedLoop;
504 /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
505 if (isForWithIdenticalConfiguration(target, source)) {
506 fusedLoop = fuseIndependentSiblingForLoops(
507 cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
508 } else if (isForallWithIdenticalConfiguration(target, source)) {
509 fusedLoop = fuseIndependentSiblingForallLoops(
510 cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
511 } else
512 return emitSilenceableFailure(target->getLoc())
513 << "operations cannot be fused";
514
515 assert(fusedLoop && "failed to fuse operations");
516
517 results.set(cast<OpResult>(getFusedLoop()), {fusedLoop});
518 return DiagnosedSilenceableFailure::success();
519}
520
521//===----------------------------------------------------------------------===//
522// Transform op registration
523//===----------------------------------------------------------------------===//
524
namespace {
/// Transform dialect extension for the SCF transform ops implemented in this
/// file.
class SCFTransformDialectExtension
    : public transform::TransformDialectExtension<
          SCFTransformDialectExtension> {
public:
  using Base::Base;

  /// Declares the affine and func dialects as generated dialects and
  /// registers the transform ops listed in the generated `.cpp.inc` file.
  void init() {
    declareGeneratedDialect<affine::AffineDialect>();
    declareGeneratedDialect<func::FuncDialect>();

    // The op list is expanded by the preprocessor from the generated file
    // directly into the template argument list.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
        >();
  }
};
} // namespace
543
544#define GET_OP_CLASSES
545#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
546
/// Adds `SCFTransformDialectExtension` to `registry`.
void mlir::scf::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<SCFTransformDialectExtension>();
}
550

source code of mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp