1//===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
10
11#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
12#include "mlir/Dialect/Affine/IR/AffineOps.h"
13#include "mlir/Dialect/Affine/LoopUtils.h"
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/Arith/Utils/Utils.h"
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/Dialect/SCF/IR/SCF.h"
18#include "mlir/Dialect/SCF/Transforms/Patterns.h"
19#include "mlir/Dialect/SCF/Transforms/Transforms.h"
20#include "mlir/Dialect/SCF/Utils/Utils.h"
21#include "mlir/Dialect/Transform/IR/TransformDialect.h"
22#include "mlir/Dialect/Transform/IR/TransformOps.h"
23#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
24#include "mlir/Dialect/Utils/StaticValueUtils.h"
25#include "mlir/Dialect/Vector/IR/VectorOps.h"
26#include "mlir/IR/BuiltinAttributes.h"
27#include "mlir/IR/Dominance.h"
28#include "mlir/IR/OpDefinition.h"
29
30using namespace mlir;
31using namespace mlir::affine;
32
33//===----------------------------------------------------------------------===//
34// Apply...PatternsOp
35//===----------------------------------------------------------------------===//
36
37void transform::ApplyForLoopCanonicalizationPatternsOp::populatePatterns(
38 RewritePatternSet &patterns) {
39 scf::populateSCFForLoopCanonicalizationPatterns(patterns);
40}
41
42void transform::ApplySCFStructuralConversionPatternsOp::populatePatterns(
43 TypeConverter &typeConverter, RewritePatternSet &patterns) {
44 scf::populateSCFStructuralTypeConversions(typeConverter, patterns);
45}
46
47void transform::ApplySCFStructuralConversionPatternsOp::
48 populateConversionTargetRules(const TypeConverter &typeConverter,
49 ConversionTarget &conversionTarget) {
50 scf::populateSCFStructuralTypeConversionTarget(typeConverter,
51 conversionTarget);
52}
53
54void transform::ApplySCFToControlFlowPatternsOp::populatePatterns(
55 TypeConverter &typeConverter, RewritePatternSet &patterns) {
56 populateSCFToControlFlowConversionPatterns(patterns);
57}
58
59//===----------------------------------------------------------------------===//
60// ForallToForOp
61//===----------------------------------------------------------------------===//
62
63DiagnosedSilenceableFailure
64transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
65 transform::TransformResults &results,
66 transform::TransformState &state) {
67 auto payload = state.getPayloadOps(getTarget());
68 if (!llvm::hasSingleElement(payload))
69 return emitSilenceableError() << "expected a single payload op";
70
71 auto target = dyn_cast<scf::ForallOp>(*payload.begin());
72 if (!target) {
73 DiagnosedSilenceableFailure diag =
74 emitSilenceableError() << "expected the payload to be scf.forall";
75 diag.attachNote((*payload.begin())->getLoc()) << "payload op";
76 return diag;
77 }
78
79 if (!target.getOutputs().empty()) {
80 return emitSilenceableError()
81 << "unsupported shared outputs (didn't bufferize?)";
82 }
83
84 SmallVector<OpFoldResult> lbs = target.getMixedLowerBound();
85
86 if (getNumResults() != lbs.size()) {
87 DiagnosedSilenceableFailure diag =
88 emitSilenceableError()
89 << "op expects as many results (" << getNumResults()
90 << ") as payload has induction variables (" << lbs.size() << ")";
91 diag.attachNote(target.getLoc()) << "payload op";
92 return diag;
93 }
94
95 SmallVector<Operation *> opResults;
96 if (failed(scf::forallToForLoop(rewriter, target, &opResults))) {
97 DiagnosedSilenceableFailure diag = emitSilenceableError()
98 << "failed to convert forall into for";
99 return diag;
100 }
101
102 for (auto &&[i, res] : llvm::enumerate(opResults)) {
103 results.set(cast<OpResult>(getTransformed()[i]), {res});
104 }
105 return DiagnosedSilenceableFailure::success();
106}
107
//===----------------------------------------------------------------------===//
// ForallToParallelOp
//===----------------------------------------------------------------------===//
111
112DiagnosedSilenceableFailure
113transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter,
114 transform::TransformResults &results,
115 transform::TransformState &state) {
116 auto payload = state.getPayloadOps(getTarget());
117 if (!llvm::hasSingleElement(payload))
118 return emitSilenceableError() << "expected a single payload op";
119
120 auto target = dyn_cast<scf::ForallOp>(*payload.begin());
121 if (!target) {
122 DiagnosedSilenceableFailure diag =
123 emitSilenceableError() << "expected the payload to be scf.forall";
124 diag.attachNote((*payload.begin())->getLoc()) << "payload op";
125 return diag;
126 }
127
128 if (!target.getOutputs().empty()) {
129 return emitSilenceableError()
130 << "unsupported shared outputs (didn't bufferize?)";
131 }
132
133 if (getNumResults() != 1) {
134 DiagnosedSilenceableFailure diag = emitSilenceableError()
135 << "op expects one result, given "
136 << getNumResults();
137 diag.attachNote(target.getLoc()) << "payload op";
138 return diag;
139 }
140
141 scf::ParallelOp opResult;
142 if (failed(scf::forallToParallelLoop(rewriter, target, &opResult))) {
143 DiagnosedSilenceableFailure diag =
144 emitSilenceableError() << "failed to convert forall into parallel";
145 return diag;
146 }
147
148 results.set(cast<OpResult>(getTransformed()[0]), {opResult});
149 return DiagnosedSilenceableFailure::success();
150}
151
152//===----------------------------------------------------------------------===//
153// LoopOutlineOp
154//===----------------------------------------------------------------------===//
155
156/// Wraps the given operation `op` into an `scf.execute_region` operation. Uses
157/// the provided rewriter for all operations to remain compatible with the
158/// rewriting infra, as opposed to just splicing the op in place.
159static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
160 Operation *op) {
161 if (op->getNumRegions() != 1)
162 return nullptr;
163 OpBuilder::InsertionGuard g(b);
164 b.setInsertionPoint(op);
165 scf::ExecuteRegionOp executeRegionOp =
166 b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes());
167 {
168 OpBuilder::InsertionGuard g(b);
169 b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock());
170 Operation *clonedOp = b.cloneWithoutRegions(op&: *op);
171 Region &clonedRegion = clonedOp->getRegions().front();
172 assert(clonedRegion.empty() && "expected empty region");
173 b.inlineRegionBefore(region&: op->getRegions().front(), parent&: clonedRegion,
174 before: clonedRegion.end());
175 b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults());
176 }
177 b.replaceOp(op, executeRegionOp.getResults());
178 return executeRegionOp;
179}
180
181DiagnosedSilenceableFailure
182transform::LoopOutlineOp::apply(transform::TransformRewriter &rewriter,
183 transform::TransformResults &results,
184 transform::TransformState &state) {
185 SmallVector<Operation *> functions;
186 SmallVector<Operation *> calls;
187 DenseMap<Operation *, SymbolTable> symbolTables;
188 for (Operation *target : state.getPayloadOps(getTarget())) {
189 Location location = target->getLoc();
190 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target);
191 scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target);
192 if (!exec) {
193 DiagnosedSilenceableFailure diag = emitSilenceableError()
194 << "failed to outline";
195 diag.attachNote(target->getLoc()) << "target op";
196 return diag;
197 }
198 func::CallOp call;
199 FailureOr<func::FuncOp> outlined = outlineSingleBlockRegion(
200 rewriter, location, exec.getRegion(), getFuncName(), &call);
201
202 if (failed(outlined))
203 return emitDefaultDefiniteFailure(target);
204
205 if (symbolTableOp) {
206 SymbolTable &symbolTable =
207 symbolTables.try_emplace(symbolTableOp, symbolTableOp)
208 .first->getSecond();
209 symbolTable.insert(*outlined);
210 call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined));
211 }
212 functions.push_back(*outlined);
213 calls.push_back(call);
214 }
215 results.set(cast<OpResult>(getFunction()), functions);
216 results.set(cast<OpResult>(getCall()), calls);
217 return DiagnosedSilenceableFailure::success();
218}
219
220//===----------------------------------------------------------------------===//
221// LoopPeelOp
222//===----------------------------------------------------------------------===//
223
224DiagnosedSilenceableFailure
225transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter,
226 scf::ForOp target,
227 transform::ApplyToEachResultList &results,
228 transform::TransformState &state) {
229 scf::ForOp result;
230 if (getPeelFront()) {
231 LogicalResult status =
232 scf::peelForLoopFirstIteration(rewriter, target, result);
233 if (failed(status)) {
234 DiagnosedSilenceableFailure diag =
235 emitSilenceableError() << "failed to peel the first iteration";
236 return diag;
237 }
238 } else {
239 LogicalResult status =
240 scf::peelForLoopAndSimplifyBounds(rewriter, target, result);
241 if (failed(status)) {
242 DiagnosedSilenceableFailure diag = emitSilenceableError()
243 << "failed to peel the last iteration";
244 return diag;
245 }
246 }
247
248 results.push_back(target);
249 results.push_back(result);
250
251 return DiagnosedSilenceableFailure::success();
252}
253
254//===----------------------------------------------------------------------===//
255// LoopPipelineOp
256//===----------------------------------------------------------------------===//
257
258/// Callback for PipeliningOption. Populates `schedule` with the mapping from an
259/// operation to its logical time position given the iteration interval and the
260/// read latency. The latter is only relevant for vector transfers.
261static void
262loopScheduling(scf::ForOp forOp,
263 std::vector<std::pair<Operation *, unsigned>> &schedule,
264 unsigned iterationInterval, unsigned readLatency) {
265 auto getLatency = [&](Operation *op) -> unsigned {
266 if (isa<vector::TransferReadOp>(Val: op))
267 return readLatency;
268 return 1;
269 };
270
271 std::optional<int64_t> ubConstant =
272 getConstantIntValue(forOp.getUpperBound());
273 std::optional<int64_t> lbConstant =
274 getConstantIntValue(forOp.getLowerBound());
275 DenseMap<Operation *, unsigned> opCycles;
276 std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
277 for (Operation &op : forOp.getBody()->getOperations()) {
278 if (isa<scf::YieldOp>(op))
279 continue;
280 unsigned earlyCycle = 0;
281 for (Value operand : op.getOperands()) {
282 Operation *def = operand.getDefiningOp();
283 if (!def)
284 continue;
285 if (ubConstant && lbConstant) {
286 unsigned ubInt = ubConstant.value();
287 unsigned lbInt = lbConstant.value();
288 auto minLatency = std::min(ubInt - lbInt - 1, getLatency(def));
289 earlyCycle = std::max(earlyCycle, opCycles[def] + minLatency);
290 } else {
291 earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def));
292 }
293 }
294 opCycles[&op] = earlyCycle;
295 wrappedSchedule[earlyCycle % iterationInterval].push_back(&op);
296 }
297 for (const auto &it : wrappedSchedule) {
298 for (Operation *op : it.second) {
299 unsigned cycle = opCycles[op];
300 schedule.emplace_back(args&: op, args: cycle / iterationInterval);
301 }
302 }
303}
304
305DiagnosedSilenceableFailure
306transform::LoopPipelineOp::applyToOne(transform::TransformRewriter &rewriter,
307 scf::ForOp target,
308 transform::ApplyToEachResultList &results,
309 transform::TransformState &state) {
310 scf::PipeliningOption options;
311 options.getScheduleFn =
312 [this](scf::ForOp forOp,
313 std::vector<std::pair<Operation *, unsigned>> &schedule) mutable {
314 loopScheduling(forOp, schedule, getIterationInterval(),
315 getReadLatency());
316 };
317 scf::ForLoopPipeliningPattern pattern(options, target->getContext());
318 rewriter.setInsertionPoint(target);
319 FailureOr<scf::ForOp> patternResult =
320 scf::pipelineForLoop(rewriter, target, options);
321 if (succeeded(patternResult)) {
322 results.push_back(*patternResult);
323 return DiagnosedSilenceableFailure::success();
324 }
325 return emitDefaultSilenceableFailure(target);
326}
327
328//===----------------------------------------------------------------------===//
329// LoopPromoteIfOneIterationOp
330//===----------------------------------------------------------------------===//
331
332DiagnosedSilenceableFailure transform::LoopPromoteIfOneIterationOp::applyToOne(
333 transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
334 transform::ApplyToEachResultList &results,
335 transform::TransformState &state) {
336 (void)target.promoteIfSingleIteration(rewriter);
337 return DiagnosedSilenceableFailure::success();
338}
339
340void transform::LoopPromoteIfOneIterationOp::getEffects(
341 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
342 consumesHandle(getTargetMutable(), effects);
343 modifiesPayload(effects);
344}
345
346//===----------------------------------------------------------------------===//
347// LoopUnrollOp
348//===----------------------------------------------------------------------===//
349
350DiagnosedSilenceableFailure
351transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
352 Operation *op,
353 transform::ApplyToEachResultList &results,
354 transform::TransformState &state) {
355 LogicalResult result(failure());
356 if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
357 result = loopUnrollByFactor(scfFor, getFactor());
358 else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
359 result = loopUnrollByFactor(affineFor, getFactor());
360 else
361 return emitSilenceableError()
362 << "failed to unroll, incorrect type of payload";
363
364 if (failed(result))
365 return emitSilenceableError() << "failed to unroll";
366
367 return DiagnosedSilenceableFailure::success();
368}
369
370//===----------------------------------------------------------------------===//
371// LoopUnrollAndJamOp
372//===----------------------------------------------------------------------===//
373
374DiagnosedSilenceableFailure transform::LoopUnrollAndJamOp::applyToOne(
375 transform::TransformRewriter &rewriter, Operation *op,
376 transform::ApplyToEachResultList &results,
377 transform::TransformState &state) {
378 LogicalResult result(failure());
379 if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
380 result = loopUnrollJamByFactor(scfFor, getFactor());
381 else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
382 result = loopUnrollJamByFactor(affineFor, getFactor());
383 else
384 return emitSilenceableError()
385 << "failed to unroll and jam, incorrect type of payload";
386
387 if (failed(result))
388 return emitSilenceableError() << "failed to unroll and jam";
389
390 return DiagnosedSilenceableFailure::success();
391}
392
393//===----------------------------------------------------------------------===//
394// LoopCoalesceOp
395//===----------------------------------------------------------------------===//
396
397DiagnosedSilenceableFailure
398transform::LoopCoalesceOp::applyToOne(transform::TransformRewriter &rewriter,
399 Operation *op,
400 transform::ApplyToEachResultList &results,
401 transform::TransformState &state) {
402 LogicalResult result(failure());
403 if (scf::ForOp scfForOp = dyn_cast<scf::ForOp>(op))
404 result = coalescePerfectlyNestedSCFForLoops(scfForOp);
405 else if (AffineForOp affineForOp = dyn_cast<AffineForOp>(op))
406 result = coalescePerfectlyNestedAffineLoops(affineForOp);
407
408 results.push_back(op);
409 if (failed(result)) {
410 DiagnosedSilenceableFailure diag = emitSilenceableError()
411 << "failed to coalesce";
412 return diag;
413 }
414 return DiagnosedSilenceableFailure::success();
415}
416
417//===----------------------------------------------------------------------===//
418// TakeAssumedBranchOp
419//===----------------------------------------------------------------------===//
420/// Replaces the given op with the contents of the given single-block region,
421/// using the operands of the block terminator to replace operation results.
422static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op,
423 Region &region) {
424 assert(llvm::hasSingleElement(region) && "expected single-region block");
425 Block *block = &region.front();
426 Operation *terminator = block->getTerminator();
427 ValueRange results = terminator->getOperands();
428 rewriter.inlineBlockBefore(source: block, op, /*blockArgs=*/argValues: {});
429 rewriter.replaceOp(op, newValues: results);
430 rewriter.eraseOp(op: terminator);
431}
432
433DiagnosedSilenceableFailure transform::TakeAssumedBranchOp::applyToOne(
434 transform::TransformRewriter &rewriter, scf::IfOp ifOp,
435 transform::ApplyToEachResultList &results,
436 transform::TransformState &state) {
437 rewriter.setInsertionPoint(ifOp);
438 Region &region =
439 getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion();
440 if (!llvm::hasSingleElement(region)) {
441 return emitDefiniteFailure()
442 << "requires an scf.if op with a single-block "
443 << ((getTakeElseBranch()) ? "`else`" : "`then`") << " region";
444 }
445 replaceOpWithRegion(rewriter, ifOp, region);
446 return DiagnosedSilenceableFailure::success();
447}
448
449void transform::TakeAssumedBranchOp::getEffects(
450 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
451 onlyReadsHandle(getTargetMutable(), effects);
452 modifiesPayload(effects);
453}
454
455//===----------------------------------------------------------------------===//
456// LoopFuseSiblingOp
457//===----------------------------------------------------------------------===//
458
459/// Check if `target` and `source` are siblings, in the context that `target`
460/// is being fused into `source`.
461///
462/// This is a simple check that just checks if both operations are in the same
463/// block and some checks to ensure that the fused IR does not violate
464/// dominance.
465static DiagnosedSilenceableFailure isOpSibling(Operation *target,
466 Operation *source) {
467 // Check if both operations are same.
468 if (target == source)
469 return emitSilenceableFailure(op: source)
470 << "target and source need to be different loops";
471
472 // Check if both operations are in the same block.
473 if (target->getBlock() != source->getBlock())
474 return emitSilenceableFailure(op: source)
475 << "target and source are not in the same block";
476
477 // Check if fusion will violate dominance.
478 DominanceInfo domInfo(source);
479 if (target->isBeforeInBlock(other: source)) {
480 // Since `target` is before `source`, all users of results of `target`
481 // need to be dominated by `source`.
482 for (Operation *user : target->getUsers()) {
483 if (!domInfo.properlyDominates(a: source, b: user, /*enclosingOpOk=*/false)) {
484 return emitSilenceableFailure(op: target)
485 << "user of results of target should be properly dominated by "
486 "source";
487 }
488 }
489 } else {
490 // Since `target` is after `source`, all values used by `target` need
491 // to dominate `source`.
492
493 // Check if operands of `target` are dominated by `source`.
494 for (Value operand : target->getOperands()) {
495 Operation *operandOp = operand.getDefiningOp();
496 // Operands without defining operations are block arguments. When `target`
497 // and `source` occur in the same block, these operands dominate `source`.
498 if (!operandOp)
499 continue;
500
501 // Operand's defining operation should properly dominate `source`.
502 if (!domInfo.properlyDominates(a: operandOp, b: source,
503 /*enclosingOpOk=*/false))
504 return emitSilenceableFailure(op: target)
505 << "operands of target should be properly dominated by source";
506 }
507
508 // Check if values used by `target` are dominated by `source`.
509 bool failed = false;
510 OpOperand *failedValue = nullptr;
511 visitUsedValuesDefinedAbove(regions: target->getRegions(), callback: [&](OpOperand *operand) {
512 Operation *operandOp = operand->get().getDefiningOp();
513 if (operandOp && !domInfo.properlyDominates(a: operandOp, b: source,
514 /*enclosingOpOk=*/false)) {
515 // `operand` is not an argument of an enclosing block and the defining
516 // op of `operand` is outside `target` but does not dominate `source`.
517 failed = true;
518 failedValue = operand;
519 }
520 });
521
522 if (failed)
523 return emitSilenceableFailure(op: failedValue->getOwner())
524 << "values used inside regions of target should be properly "
525 "dominated by source";
526 }
527
528 return DiagnosedSilenceableFailure::success();
529}
530
531/// Check if `target` scf.forall can be fused into `source` scf.forall.
532///
533/// This simply checks if both loops have the same bounds, steps and mapping.
534/// No attempt is made at checking that the side effects of `target` and
535/// `source` are independent of each other.
536static bool isForallWithIdenticalConfiguration(Operation *target,
537 Operation *source) {
538 auto targetOp = dyn_cast<scf::ForallOp>(target);
539 auto sourceOp = dyn_cast<scf::ForallOp>(source);
540 if (!targetOp || !sourceOp)
541 return false;
542
543 return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
544 targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
545 targetOp.getMixedStep() == sourceOp.getMixedStep() &&
546 targetOp.getMapping() == sourceOp.getMapping();
547}
548
549/// Check if `target` scf.for can be fused into `source` scf.for.
550///
551/// This simply checks if both loops have the same bounds and steps. No attempt
552/// is made at checking that the side effects of `target` and `source` are
553/// independent of each other.
554static bool isForWithIdenticalConfiguration(Operation *target,
555 Operation *source) {
556 auto targetOp = dyn_cast<scf::ForOp>(target);
557 auto sourceOp = dyn_cast<scf::ForOp>(source);
558 if (!targetOp || !sourceOp)
559 return false;
560
561 return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
562 targetOp.getUpperBound() == sourceOp.getUpperBound() &&
563 targetOp.getStep() == sourceOp.getStep();
564}
565
566DiagnosedSilenceableFailure
567transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
568 transform::TransformResults &results,
569 transform::TransformState &state) {
570 auto targetOps = state.getPayloadOps(getTarget());
571 auto sourceOps = state.getPayloadOps(getSource());
572
573 if (!llvm::hasSingleElement(targetOps) ||
574 !llvm::hasSingleElement(sourceOps)) {
575 return emitDefiniteFailure()
576 << "requires exactly one target handle (got "
577 << llvm::range_size(targetOps) << ") and exactly one "
578 << "source handle (got " << llvm::range_size(sourceOps) << ")";
579 }
580
581 Operation *target = *targetOps.begin();
582 Operation *source = *sourceOps.begin();
583
584 // Check if the target and source are siblings.
585 DiagnosedSilenceableFailure diag = isOpSibling(target, source);
586 if (!diag.succeeded())
587 return diag;
588
589 Operation *fusedLoop;
590 /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
591 if (isForWithIdenticalConfiguration(target, source)) {
592 fusedLoop = fuseIndependentSiblingForLoops(
593 cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
594 } else if (isForallWithIdenticalConfiguration(target, source)) {
595 fusedLoop = fuseIndependentSiblingForallLoops(
596 cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
597 } else {
598 return emitSilenceableFailure(target->getLoc())
599 << "operations cannot be fused";
600 }
601
602 assert(fusedLoop && "failed to fuse operations");
603
604 results.set(cast<OpResult>(getFusedLoop()), {fusedLoop});
605 return DiagnosedSilenceableFailure::success();
606}
607
608//===----------------------------------------------------------------------===//
609// Transform op registration
610//===----------------------------------------------------------------------===//
611
612namespace {
613class SCFTransformDialectExtension
614 : public transform::TransformDialectExtension<
615 SCFTransformDialectExtension> {
616public:
617 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SCFTransformDialectExtension)
618
619 using Base::Base;
620
621 void init() {
622 declareGeneratedDialect<affine::AffineDialect>();
623 declareGeneratedDialect<func::FuncDialect>();
624
625 registerTransformOps<
626#define GET_OP_LIST
627#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
628 >();
629 }
630};
631} // namespace
632
633#define GET_OP_CLASSES
634#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
635
636void mlir::scf::registerTransformDialectExtension(DialectRegistry &registry) {
637 registry.addExtensions<SCFTransformDialectExtension>();
638}
639

// Source: mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp (LLVM Project),
// as rendered by the KDAB code browser.