| 1 | //===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" |
| 10 | |
| 11 | #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" |
| 12 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 13 | #include "mlir/Dialect/Affine/LoopUtils.h" |
| 14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 15 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
| 16 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 17 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 18 | #include "mlir/Dialect/SCF/Transforms/Patterns.h" |
| 19 | #include "mlir/Dialect/SCF/Transforms/Transforms.h" |
| 20 | #include "mlir/Dialect/SCF/Utils/Utils.h" |
| 21 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
| 22 | #include "mlir/Dialect/Transform/IR/TransformOps.h" |
| 23 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
| 24 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
| 25 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 26 | #include "mlir/IR/BuiltinAttributes.h" |
| 27 | #include "mlir/IR/Dominance.h" |
| 28 | #include "mlir/IR/OpDefinition.h" |
| 29 | |
| 30 | using namespace mlir; |
| 31 | using namespace mlir::affine; |
| 32 | |
| 33 | //===----------------------------------------------------------------------===// |
| 34 | // Apply...PatternsOp |
| 35 | //===----------------------------------------------------------------------===// |
| 36 | |
| 37 | void transform::ApplyForLoopCanonicalizationPatternsOp::populatePatterns( |
| 38 | RewritePatternSet &patterns) { |
| 39 | scf::populateSCFForLoopCanonicalizationPatterns(patterns); |
| 40 | } |
| 41 | |
| 42 | void transform::ApplySCFStructuralConversionPatternsOp::populatePatterns( |
| 43 | TypeConverter &typeConverter, RewritePatternSet &patterns) { |
| 44 | scf::populateSCFStructuralTypeConversions(typeConverter, patterns); |
| 45 | } |
| 46 | |
| 47 | void transform::ApplySCFStructuralConversionPatternsOp:: |
| 48 | populateConversionTargetRules(const TypeConverter &typeConverter, |
| 49 | ConversionTarget &conversionTarget) { |
| 50 | scf::populateSCFStructuralTypeConversionTarget(typeConverter, |
| 51 | conversionTarget); |
| 52 | } |
| 53 | |
| 54 | void transform::ApplySCFToControlFlowPatternsOp::populatePatterns( |
| 55 | TypeConverter &typeConverter, RewritePatternSet &patterns) { |
| 56 | populateSCFToControlFlowConversionPatterns(patterns); |
| 57 | } |
| 58 | |
| 59 | //===----------------------------------------------------------------------===// |
| 60 | // ForallToForOp |
| 61 | //===----------------------------------------------------------------------===// |
| 62 | |
| 63 | DiagnosedSilenceableFailure |
| 64 | transform::ForallToForOp::apply(transform::TransformRewriter &rewriter, |
| 65 | transform::TransformResults &results, |
| 66 | transform::TransformState &state) { |
| 67 | auto payload = state.getPayloadOps(getTarget()); |
| 68 | if (!llvm::hasSingleElement(payload)) |
| 69 | return emitSilenceableError() << "expected a single payload op" ; |
| 70 | |
| 71 | auto target = dyn_cast<scf::ForallOp>(*payload.begin()); |
| 72 | if (!target) { |
| 73 | DiagnosedSilenceableFailure diag = |
| 74 | emitSilenceableError() << "expected the payload to be scf.forall" ; |
| 75 | diag.attachNote((*payload.begin())->getLoc()) << "payload op" ; |
| 76 | return diag; |
| 77 | } |
| 78 | |
| 79 | if (!target.getOutputs().empty()) { |
| 80 | return emitSilenceableError() |
| 81 | << "unsupported shared outputs (didn't bufferize?)" ; |
| 82 | } |
| 83 | |
| 84 | SmallVector<OpFoldResult> lbs = target.getMixedLowerBound(); |
| 85 | |
| 86 | if (getNumResults() != lbs.size()) { |
| 87 | DiagnosedSilenceableFailure diag = |
| 88 | emitSilenceableError() |
| 89 | << "op expects as many results (" << getNumResults() |
| 90 | << ") as payload has induction variables (" << lbs.size() << ")" ; |
| 91 | diag.attachNote(target.getLoc()) << "payload op" ; |
| 92 | return diag; |
| 93 | } |
| 94 | |
| 95 | SmallVector<Operation *> opResults; |
| 96 | if (failed(scf::forallToForLoop(rewriter, target, &opResults))) { |
| 97 | DiagnosedSilenceableFailure diag = emitSilenceableError() |
| 98 | << "failed to convert forall into for" ; |
| 99 | return diag; |
| 100 | } |
| 101 | |
| 102 | for (auto &&[i, res] : llvm::enumerate(opResults)) { |
| 103 | results.set(cast<OpResult>(getTransformed()[i]), {res}); |
| 104 | } |
| 105 | return DiagnosedSilenceableFailure::success(); |
| 106 | } |
| 107 | |
//===----------------------------------------------------------------------===//
// ForallToParallelOp
//===----------------------------------------------------------------------===//
| 111 | |
| 112 | DiagnosedSilenceableFailure |
| 113 | transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter, |
| 114 | transform::TransformResults &results, |
| 115 | transform::TransformState &state) { |
| 116 | auto payload = state.getPayloadOps(getTarget()); |
| 117 | if (!llvm::hasSingleElement(payload)) |
| 118 | return emitSilenceableError() << "expected a single payload op" ; |
| 119 | |
| 120 | auto target = dyn_cast<scf::ForallOp>(*payload.begin()); |
| 121 | if (!target) { |
| 122 | DiagnosedSilenceableFailure diag = |
| 123 | emitSilenceableError() << "expected the payload to be scf.forall" ; |
| 124 | diag.attachNote((*payload.begin())->getLoc()) << "payload op" ; |
| 125 | return diag; |
| 126 | } |
| 127 | |
| 128 | if (!target.getOutputs().empty()) { |
| 129 | return emitSilenceableError() |
| 130 | << "unsupported shared outputs (didn't bufferize?)" ; |
| 131 | } |
| 132 | |
| 133 | if (getNumResults() != 1) { |
| 134 | DiagnosedSilenceableFailure diag = emitSilenceableError() |
| 135 | << "op expects one result, given " |
| 136 | << getNumResults(); |
| 137 | diag.attachNote(target.getLoc()) << "payload op" ; |
| 138 | return diag; |
| 139 | } |
| 140 | |
| 141 | scf::ParallelOp opResult; |
| 142 | if (failed(scf::forallToParallelLoop(rewriter, target, &opResult))) { |
| 143 | DiagnosedSilenceableFailure diag = |
| 144 | emitSilenceableError() << "failed to convert forall into parallel" ; |
| 145 | return diag; |
| 146 | } |
| 147 | |
| 148 | results.set(cast<OpResult>(getTransformed()[0]), {opResult}); |
| 149 | return DiagnosedSilenceableFailure::success(); |
| 150 | } |
| 151 | |
| 152 | //===----------------------------------------------------------------------===// |
| 153 | // LoopOutlineOp |
| 154 | //===----------------------------------------------------------------------===// |
| 155 | |
| 156 | /// Wraps the given operation `op` into an `scf.execute_region` operation. Uses |
| 157 | /// the provided rewriter for all operations to remain compatible with the |
| 158 | /// rewriting infra, as opposed to just splicing the op in place. |
| 159 | static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b, |
| 160 | Operation *op) { |
| 161 | if (op->getNumRegions() != 1) |
| 162 | return nullptr; |
| 163 | OpBuilder::InsertionGuard g(b); |
| 164 | b.setInsertionPoint(op); |
| 165 | scf::ExecuteRegionOp executeRegionOp = |
| 166 | b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes()); |
| 167 | { |
| 168 | OpBuilder::InsertionGuard g(b); |
| 169 | b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock()); |
| 170 | Operation *clonedOp = b.cloneWithoutRegions(op&: *op); |
| 171 | Region &clonedRegion = clonedOp->getRegions().front(); |
| 172 | assert(clonedRegion.empty() && "expected empty region" ); |
| 173 | b.inlineRegionBefore(region&: op->getRegions().front(), parent&: clonedRegion, |
| 174 | before: clonedRegion.end()); |
| 175 | b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults()); |
| 176 | } |
| 177 | b.replaceOp(op, executeRegionOp.getResults()); |
| 178 | return executeRegionOp; |
| 179 | } |
| 180 | |
| 181 | DiagnosedSilenceableFailure |
| 182 | transform::LoopOutlineOp::apply(transform::TransformRewriter &rewriter, |
| 183 | transform::TransformResults &results, |
| 184 | transform::TransformState &state) { |
| 185 | SmallVector<Operation *> functions; |
| 186 | SmallVector<Operation *> calls; |
| 187 | DenseMap<Operation *, SymbolTable> symbolTables; |
| 188 | for (Operation *target : state.getPayloadOps(getTarget())) { |
| 189 | Location location = target->getLoc(); |
| 190 | Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target); |
| 191 | scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target); |
| 192 | if (!exec) { |
| 193 | DiagnosedSilenceableFailure diag = emitSilenceableError() |
| 194 | << "failed to outline" ; |
| 195 | diag.attachNote(target->getLoc()) << "target op" ; |
| 196 | return diag; |
| 197 | } |
| 198 | func::CallOp call; |
| 199 | FailureOr<func::FuncOp> outlined = outlineSingleBlockRegion( |
| 200 | rewriter, location, exec.getRegion(), getFuncName(), &call); |
| 201 | |
| 202 | if (failed(outlined)) |
| 203 | return emitDefaultDefiniteFailure(target); |
| 204 | |
| 205 | if (symbolTableOp) { |
| 206 | SymbolTable &symbolTable = |
| 207 | symbolTables.try_emplace(symbolTableOp, symbolTableOp) |
| 208 | .first->getSecond(); |
| 209 | symbolTable.insert(*outlined); |
| 210 | call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined)); |
| 211 | } |
| 212 | functions.push_back(*outlined); |
| 213 | calls.push_back(call); |
| 214 | } |
| 215 | results.set(cast<OpResult>(getFunction()), functions); |
| 216 | results.set(cast<OpResult>(getCall()), calls); |
| 217 | return DiagnosedSilenceableFailure::success(); |
| 218 | } |
| 219 | |
| 220 | //===----------------------------------------------------------------------===// |
| 221 | // LoopPeelOp |
| 222 | //===----------------------------------------------------------------------===// |
| 223 | |
| 224 | DiagnosedSilenceableFailure |
| 225 | transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter, |
| 226 | scf::ForOp target, |
| 227 | transform::ApplyToEachResultList &results, |
| 228 | transform::TransformState &state) { |
| 229 | scf::ForOp result; |
| 230 | if (getPeelFront()) { |
| 231 | LogicalResult status = |
| 232 | scf::peelForLoopFirstIteration(rewriter, target, result); |
| 233 | if (failed(status)) { |
| 234 | DiagnosedSilenceableFailure diag = |
| 235 | emitSilenceableError() << "failed to peel the first iteration" ; |
| 236 | return diag; |
| 237 | } |
| 238 | } else { |
| 239 | LogicalResult status = |
| 240 | scf::peelForLoopAndSimplifyBounds(rewriter, target, result); |
| 241 | if (failed(status)) { |
| 242 | DiagnosedSilenceableFailure diag = emitSilenceableError() |
| 243 | << "failed to peel the last iteration" ; |
| 244 | return diag; |
| 245 | } |
| 246 | } |
| 247 | |
| 248 | results.push_back(target); |
| 249 | results.push_back(result); |
| 250 | |
| 251 | return DiagnosedSilenceableFailure::success(); |
| 252 | } |
| 253 | |
| 254 | //===----------------------------------------------------------------------===// |
| 255 | // LoopPipelineOp |
| 256 | //===----------------------------------------------------------------------===// |
| 257 | |
| 258 | /// Callback for PipeliningOption. Populates `schedule` with the mapping from an |
| 259 | /// operation to its logical time position given the iteration interval and the |
| 260 | /// read latency. The latter is only relevant for vector transfers. |
| 261 | static void |
| 262 | loopScheduling(scf::ForOp forOp, |
| 263 | std::vector<std::pair<Operation *, unsigned>> &schedule, |
| 264 | unsigned iterationInterval, unsigned readLatency) { |
| 265 | auto getLatency = [&](Operation *op) -> unsigned { |
| 266 | if (isa<vector::TransferReadOp>(Val: op)) |
| 267 | return readLatency; |
| 268 | return 1; |
| 269 | }; |
| 270 | |
| 271 | std::optional<int64_t> ubConstant = |
| 272 | getConstantIntValue(forOp.getUpperBound()); |
| 273 | std::optional<int64_t> lbConstant = |
| 274 | getConstantIntValue(forOp.getLowerBound()); |
| 275 | DenseMap<Operation *, unsigned> opCycles; |
| 276 | std::map<unsigned, std::vector<Operation *>> wrappedSchedule; |
| 277 | for (Operation &op : forOp.getBody()->getOperations()) { |
| 278 | if (isa<scf::YieldOp>(op)) |
| 279 | continue; |
| 280 | unsigned earlyCycle = 0; |
| 281 | for (Value operand : op.getOperands()) { |
| 282 | Operation *def = operand.getDefiningOp(); |
| 283 | if (!def) |
| 284 | continue; |
| 285 | if (ubConstant && lbConstant) { |
| 286 | unsigned ubInt = ubConstant.value(); |
| 287 | unsigned lbInt = lbConstant.value(); |
| 288 | auto minLatency = std::min(ubInt - lbInt - 1, getLatency(def)); |
| 289 | earlyCycle = std::max(earlyCycle, opCycles[def] + minLatency); |
| 290 | } else { |
| 291 | earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def)); |
| 292 | } |
| 293 | } |
| 294 | opCycles[&op] = earlyCycle; |
| 295 | wrappedSchedule[earlyCycle % iterationInterval].push_back(&op); |
| 296 | } |
| 297 | for (const auto &it : wrappedSchedule) { |
| 298 | for (Operation *op : it.second) { |
| 299 | unsigned cycle = opCycles[op]; |
| 300 | schedule.emplace_back(args&: op, args: cycle / iterationInterval); |
| 301 | } |
| 302 | } |
| 303 | } |
| 304 | |
| 305 | DiagnosedSilenceableFailure |
| 306 | transform::LoopPipelineOp::applyToOne(transform::TransformRewriter &rewriter, |
| 307 | scf::ForOp target, |
| 308 | transform::ApplyToEachResultList &results, |
| 309 | transform::TransformState &state) { |
| 310 | scf::PipeliningOption options; |
| 311 | options.getScheduleFn = |
| 312 | [this](scf::ForOp forOp, |
| 313 | std::vector<std::pair<Operation *, unsigned>> &schedule) mutable { |
| 314 | loopScheduling(forOp, schedule, getIterationInterval(), |
| 315 | getReadLatency()); |
| 316 | }; |
| 317 | scf::ForLoopPipeliningPattern pattern(options, target->getContext()); |
| 318 | rewriter.setInsertionPoint(target); |
| 319 | FailureOr<scf::ForOp> patternResult = |
| 320 | scf::pipelineForLoop(rewriter, target, options); |
| 321 | if (succeeded(patternResult)) { |
| 322 | results.push_back(*patternResult); |
| 323 | return DiagnosedSilenceableFailure::success(); |
| 324 | } |
| 325 | return emitDefaultSilenceableFailure(target); |
| 326 | } |
| 327 | |
| 328 | //===----------------------------------------------------------------------===// |
| 329 | // LoopPromoteIfOneIterationOp |
| 330 | //===----------------------------------------------------------------------===// |
| 331 | |
| 332 | DiagnosedSilenceableFailure transform::LoopPromoteIfOneIterationOp::applyToOne( |
| 333 | transform::TransformRewriter &rewriter, LoopLikeOpInterface target, |
| 334 | transform::ApplyToEachResultList &results, |
| 335 | transform::TransformState &state) { |
| 336 | (void)target.promoteIfSingleIteration(rewriter); |
| 337 | return DiagnosedSilenceableFailure::success(); |
| 338 | } |
| 339 | |
| 340 | void transform::LoopPromoteIfOneIterationOp::getEffects( |
| 341 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| 342 | consumesHandle(getTargetMutable(), effects); |
| 343 | modifiesPayload(effects); |
| 344 | } |
| 345 | |
| 346 | //===----------------------------------------------------------------------===// |
| 347 | // LoopUnrollOp |
| 348 | //===----------------------------------------------------------------------===// |
| 349 | |
| 350 | DiagnosedSilenceableFailure |
| 351 | transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter, |
| 352 | Operation *op, |
| 353 | transform::ApplyToEachResultList &results, |
| 354 | transform::TransformState &state) { |
| 355 | LogicalResult result(failure()); |
| 356 | if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op)) |
| 357 | result = loopUnrollByFactor(scfFor, getFactor()); |
| 358 | else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op)) |
| 359 | result = loopUnrollByFactor(affineFor, getFactor()); |
| 360 | else |
| 361 | return emitSilenceableError() |
| 362 | << "failed to unroll, incorrect type of payload" ; |
| 363 | |
| 364 | if (failed(result)) |
| 365 | return emitSilenceableError() << "failed to unroll" ; |
| 366 | |
| 367 | return DiagnosedSilenceableFailure::success(); |
| 368 | } |
| 369 | |
| 370 | //===----------------------------------------------------------------------===// |
| 371 | // LoopUnrollAndJamOp |
| 372 | //===----------------------------------------------------------------------===// |
| 373 | |
| 374 | DiagnosedSilenceableFailure transform::LoopUnrollAndJamOp::applyToOne( |
| 375 | transform::TransformRewriter &rewriter, Operation *op, |
| 376 | transform::ApplyToEachResultList &results, |
| 377 | transform::TransformState &state) { |
| 378 | LogicalResult result(failure()); |
| 379 | if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op)) |
| 380 | result = loopUnrollJamByFactor(scfFor, getFactor()); |
| 381 | else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op)) |
| 382 | result = loopUnrollJamByFactor(affineFor, getFactor()); |
| 383 | else |
| 384 | return emitSilenceableError() |
| 385 | << "failed to unroll and jam, incorrect type of payload" ; |
| 386 | |
| 387 | if (failed(result)) |
| 388 | return emitSilenceableError() << "failed to unroll and jam" ; |
| 389 | |
| 390 | return DiagnosedSilenceableFailure::success(); |
| 391 | } |
| 392 | |
| 393 | //===----------------------------------------------------------------------===// |
| 394 | // LoopCoalesceOp |
| 395 | //===----------------------------------------------------------------------===// |
| 396 | |
| 397 | DiagnosedSilenceableFailure |
| 398 | transform::LoopCoalesceOp::applyToOne(transform::TransformRewriter &rewriter, |
| 399 | Operation *op, |
| 400 | transform::ApplyToEachResultList &results, |
| 401 | transform::TransformState &state) { |
| 402 | LogicalResult result(failure()); |
| 403 | if (scf::ForOp scfForOp = dyn_cast<scf::ForOp>(op)) |
| 404 | result = coalescePerfectlyNestedSCFForLoops(scfForOp); |
| 405 | else if (AffineForOp affineForOp = dyn_cast<AffineForOp>(op)) |
| 406 | result = coalescePerfectlyNestedAffineLoops(affineForOp); |
| 407 | |
| 408 | results.push_back(op); |
| 409 | if (failed(result)) { |
| 410 | DiagnosedSilenceableFailure diag = emitSilenceableError() |
| 411 | << "failed to coalesce" ; |
| 412 | return diag; |
| 413 | } |
| 414 | return DiagnosedSilenceableFailure::success(); |
| 415 | } |
| 416 | |
| 417 | //===----------------------------------------------------------------------===// |
| 418 | // TakeAssumedBranchOp |
| 419 | //===----------------------------------------------------------------------===// |
| 420 | /// Replaces the given op with the contents of the given single-block region, |
| 421 | /// using the operands of the block terminator to replace operation results. |
| 422 | static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op, |
| 423 | Region ®ion) { |
| 424 | assert(llvm::hasSingleElement(region) && "expected single-region block" ); |
| 425 | Block *block = ®ion.front(); |
| 426 | Operation *terminator = block->getTerminator(); |
| 427 | ValueRange results = terminator->getOperands(); |
| 428 | rewriter.inlineBlockBefore(source: block, op, /*blockArgs=*/argValues: {}); |
| 429 | rewriter.replaceOp(op, newValues: results); |
| 430 | rewriter.eraseOp(op: terminator); |
| 431 | } |
| 432 | |
| 433 | DiagnosedSilenceableFailure transform::TakeAssumedBranchOp::applyToOne( |
| 434 | transform::TransformRewriter &rewriter, scf::IfOp ifOp, |
| 435 | transform::ApplyToEachResultList &results, |
| 436 | transform::TransformState &state) { |
| 437 | rewriter.setInsertionPoint(ifOp); |
| 438 | Region ®ion = |
| 439 | getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion(); |
| 440 | if (!llvm::hasSingleElement(region)) { |
| 441 | return emitDefiniteFailure() |
| 442 | << "requires an scf.if op with a single-block " |
| 443 | << ((getTakeElseBranch()) ? "`else`" : "`then`" ) << " region" ; |
| 444 | } |
| 445 | replaceOpWithRegion(rewriter, ifOp, region); |
| 446 | return DiagnosedSilenceableFailure::success(); |
| 447 | } |
| 448 | |
| 449 | void transform::TakeAssumedBranchOp::getEffects( |
| 450 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| 451 | onlyReadsHandle(getTargetMutable(), effects); |
| 452 | modifiesPayload(effects); |
| 453 | } |
| 454 | |
| 455 | //===----------------------------------------------------------------------===// |
| 456 | // LoopFuseSiblingOp |
| 457 | //===----------------------------------------------------------------------===// |
| 458 | |
| 459 | /// Check if `target` and `source` are siblings, in the context that `target` |
| 460 | /// is being fused into `source`. |
| 461 | /// |
| 462 | /// This is a simple check that just checks if both operations are in the same |
| 463 | /// block and some checks to ensure that the fused IR does not violate |
| 464 | /// dominance. |
| 465 | static DiagnosedSilenceableFailure isOpSibling(Operation *target, |
| 466 | Operation *source) { |
| 467 | // Check if both operations are same. |
| 468 | if (target == source) |
| 469 | return emitSilenceableFailure(op: source) |
| 470 | << "target and source need to be different loops" ; |
| 471 | |
| 472 | // Check if both operations are in the same block. |
| 473 | if (target->getBlock() != source->getBlock()) |
| 474 | return emitSilenceableFailure(op: source) |
| 475 | << "target and source are not in the same block" ; |
| 476 | |
| 477 | // Check if fusion will violate dominance. |
| 478 | DominanceInfo domInfo(source); |
| 479 | if (target->isBeforeInBlock(other: source)) { |
| 480 | // Since `target` is before `source`, all users of results of `target` |
| 481 | // need to be dominated by `source`. |
| 482 | for (Operation *user : target->getUsers()) { |
| 483 | if (!domInfo.properlyDominates(a: source, b: user, /*enclosingOpOk=*/false)) { |
| 484 | return emitSilenceableFailure(op: target) |
| 485 | << "user of results of target should be properly dominated by " |
| 486 | "source" ; |
| 487 | } |
| 488 | } |
| 489 | } else { |
| 490 | // Since `target` is after `source`, all values used by `target` need |
| 491 | // to dominate `source`. |
| 492 | |
| 493 | // Check if operands of `target` are dominated by `source`. |
| 494 | for (Value operand : target->getOperands()) { |
| 495 | Operation *operandOp = operand.getDefiningOp(); |
| 496 | // Operands without defining operations are block arguments. When `target` |
| 497 | // and `source` occur in the same block, these operands dominate `source`. |
| 498 | if (!operandOp) |
| 499 | continue; |
| 500 | |
| 501 | // Operand's defining operation should properly dominate `source`. |
| 502 | if (!domInfo.properlyDominates(a: operandOp, b: source, |
| 503 | /*enclosingOpOk=*/false)) |
| 504 | return emitSilenceableFailure(op: target) |
| 505 | << "operands of target should be properly dominated by source" ; |
| 506 | } |
| 507 | |
| 508 | // Check if values used by `target` are dominated by `source`. |
| 509 | bool failed = false; |
| 510 | OpOperand *failedValue = nullptr; |
| 511 | visitUsedValuesDefinedAbove(regions: target->getRegions(), callback: [&](OpOperand *operand) { |
| 512 | Operation *operandOp = operand->get().getDefiningOp(); |
| 513 | if (operandOp && !domInfo.properlyDominates(a: operandOp, b: source, |
| 514 | /*enclosingOpOk=*/false)) { |
| 515 | // `operand` is not an argument of an enclosing block and the defining |
| 516 | // op of `operand` is outside `target` but does not dominate `source`. |
| 517 | failed = true; |
| 518 | failedValue = operand; |
| 519 | } |
| 520 | }); |
| 521 | |
| 522 | if (failed) |
| 523 | return emitSilenceableFailure(op: failedValue->getOwner()) |
| 524 | << "values used inside regions of target should be properly " |
| 525 | "dominated by source" ; |
| 526 | } |
| 527 | |
| 528 | return DiagnosedSilenceableFailure::success(); |
| 529 | } |
| 530 | |
| 531 | /// Check if `target` scf.forall can be fused into `source` scf.forall. |
| 532 | /// |
| 533 | /// This simply checks if both loops have the same bounds, steps and mapping. |
| 534 | /// No attempt is made at checking that the side effects of `target` and |
| 535 | /// `source` are independent of each other. |
| 536 | static bool isForallWithIdenticalConfiguration(Operation *target, |
| 537 | Operation *source) { |
| 538 | auto targetOp = dyn_cast<scf::ForallOp>(target); |
| 539 | auto sourceOp = dyn_cast<scf::ForallOp>(source); |
| 540 | if (!targetOp || !sourceOp) |
| 541 | return false; |
| 542 | |
| 543 | return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() && |
| 544 | targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() && |
| 545 | targetOp.getMixedStep() == sourceOp.getMixedStep() && |
| 546 | targetOp.getMapping() == sourceOp.getMapping(); |
| 547 | } |
| 548 | |
| 549 | /// Check if `target` scf.for can be fused into `source` scf.for. |
| 550 | /// |
| 551 | /// This simply checks if both loops have the same bounds and steps. No attempt |
| 552 | /// is made at checking that the side effects of `target` and `source` are |
| 553 | /// independent of each other. |
| 554 | static bool isForWithIdenticalConfiguration(Operation *target, |
| 555 | Operation *source) { |
| 556 | auto targetOp = dyn_cast<scf::ForOp>(target); |
| 557 | auto sourceOp = dyn_cast<scf::ForOp>(source); |
| 558 | if (!targetOp || !sourceOp) |
| 559 | return false; |
| 560 | |
| 561 | return targetOp.getLowerBound() == sourceOp.getLowerBound() && |
| 562 | targetOp.getUpperBound() == sourceOp.getUpperBound() && |
| 563 | targetOp.getStep() == sourceOp.getStep(); |
| 564 | } |
| 565 | |
| 566 | DiagnosedSilenceableFailure |
| 567 | transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter, |
| 568 | transform::TransformResults &results, |
| 569 | transform::TransformState &state) { |
| 570 | auto targetOps = state.getPayloadOps(getTarget()); |
| 571 | auto sourceOps = state.getPayloadOps(getSource()); |
| 572 | |
| 573 | if (!llvm::hasSingleElement(targetOps) || |
| 574 | !llvm::hasSingleElement(sourceOps)) { |
| 575 | return emitDefiniteFailure() |
| 576 | << "requires exactly one target handle (got " |
| 577 | << llvm::range_size(targetOps) << ") and exactly one " |
| 578 | << "source handle (got " << llvm::range_size(sourceOps) << ")" ; |
| 579 | } |
| 580 | |
| 581 | Operation *target = *targetOps.begin(); |
| 582 | Operation *source = *sourceOps.begin(); |
| 583 | |
| 584 | // Check if the target and source are siblings. |
| 585 | DiagnosedSilenceableFailure diag = isOpSibling(target, source); |
| 586 | if (!diag.succeeded()) |
| 587 | return diag; |
| 588 | |
| 589 | Operation *fusedLoop; |
| 590 | /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall. |
| 591 | if (isForWithIdenticalConfiguration(target, source)) { |
| 592 | fusedLoop = fuseIndependentSiblingForLoops( |
| 593 | cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter); |
| 594 | } else if (isForallWithIdenticalConfiguration(target, source)) { |
| 595 | fusedLoop = fuseIndependentSiblingForallLoops( |
| 596 | cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter); |
| 597 | } else { |
| 598 | return emitSilenceableFailure(target->getLoc()) |
| 599 | << "operations cannot be fused" ; |
| 600 | } |
| 601 | |
| 602 | assert(fusedLoop && "failed to fuse operations" ); |
| 603 | |
| 604 | results.set(cast<OpResult>(getFusedLoop()), {fusedLoop}); |
| 605 | return DiagnosedSilenceableFailure::success(); |
| 606 | } |
| 607 | |
| 608 | //===----------------------------------------------------------------------===// |
| 609 | // Transform op registration |
| 610 | //===----------------------------------------------------------------------===// |
| 611 | |
| 612 | namespace { |
| 613 | class SCFTransformDialectExtension |
| 614 | : public transform::TransformDialectExtension< |
| 615 | SCFTransformDialectExtension> { |
| 616 | public: |
| 617 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SCFTransformDialectExtension) |
| 618 | |
| 619 | using Base::Base; |
| 620 | |
| 621 | void init() { |
| 622 | declareGeneratedDialect<affine::AffineDialect>(); |
| 623 | declareGeneratedDialect<func::FuncDialect>(); |
| 624 | |
| 625 | registerTransformOps< |
| 626 | #define GET_OP_LIST |
| 627 | #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc" |
| 628 | >(); |
| 629 | } |
| 630 | }; |
| 631 | } // namespace |
| 632 | |
| 633 | #define GET_OP_CLASSES |
| 634 | #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc" |
| 635 | |
| 636 | void mlir::scf::registerTransformDialectExtension(DialectRegistry ®istry) { |
| 637 | registry.addExtensions<SCFTransformDialectExtension>(); |
| 638 | } |
| 639 | |