| 1 | //===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" |
| 10 | |
| 11 | #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" |
| 12 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 13 | #include "mlir/Dialect/Affine/LoopUtils.h" |
| 14 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 15 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 16 | #include "mlir/Dialect/SCF/Transforms/Patterns.h" |
| 17 | #include "mlir/Dialect/SCF/Transforms/Transforms.h" |
| 18 | #include "mlir/Dialect/SCF/Utils/Utils.h" |
| 19 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
| 20 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
| 21 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
| 22 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 23 | #include "mlir/IR/BuiltinAttributes.h" |
| 24 | #include "mlir/IR/Dominance.h" |
| 25 | #include "mlir/IR/OpDefinition.h" |
| 26 | |
| 27 | using namespace mlir; |
| 28 | using namespace mlir::affine; |
| 29 | |
| 30 | //===----------------------------------------------------------------------===// |
| 31 | // Apply...PatternsOp |
| 32 | //===----------------------------------------------------------------------===// |
| 33 | |
| 34 | void transform::ApplyForLoopCanonicalizationPatternsOp::populatePatterns( |
| 35 | RewritePatternSet &patterns) { |
| 36 | scf::populateSCFForLoopCanonicalizationPatterns(patterns); |
| 37 | } |
| 38 | |
| 39 | void transform::ApplySCFStructuralConversionPatternsOp::populatePatterns( |
| 40 | TypeConverter &typeConverter, RewritePatternSet &patterns) { |
| 41 | scf::populateSCFStructuralTypeConversions(typeConverter, patterns); |
| 42 | } |
| 43 | |
| 44 | void transform::ApplySCFStructuralConversionPatternsOp:: |
| 45 | populateConversionTargetRules(const TypeConverter &typeConverter, |
| 46 | ConversionTarget &conversionTarget) { |
| 47 | scf::populateSCFStructuralTypeConversionTarget(typeConverter, |
| 48 | target&: conversionTarget); |
| 49 | } |
| 50 | |
| 51 | void transform::ApplySCFToControlFlowPatternsOp::populatePatterns( |
| 52 | TypeConverter &typeConverter, RewritePatternSet &patterns) { |
| 53 | populateSCFToControlFlowConversionPatterns(patterns); |
| 54 | } |
| 55 | |
| 56 | //===----------------------------------------------------------------------===// |
| 57 | // ForallToForOp |
| 58 | //===----------------------------------------------------------------------===// |
| 59 | |
| 60 | DiagnosedSilenceableFailure |
| 61 | transform::ForallToForOp::apply(transform::TransformRewriter &rewriter, |
| 62 | transform::TransformResults &results, |
| 63 | transform::TransformState &state) { |
| 64 | auto payload = state.getPayloadOps(value: getTarget()); |
| 65 | if (!llvm::hasSingleElement(C&: payload)) |
| 66 | return emitSilenceableError() << "expected a single payload op" ; |
| 67 | |
| 68 | auto target = dyn_cast<scf::ForallOp>(Val: *payload.begin()); |
| 69 | if (!target) { |
| 70 | DiagnosedSilenceableFailure diag = |
| 71 | emitSilenceableError() << "expected the payload to be scf.forall" ; |
| 72 | diag.attachNote(loc: (*payload.begin())->getLoc()) << "payload op" ; |
| 73 | return diag; |
| 74 | } |
| 75 | |
| 76 | if (!target.getOutputs().empty()) { |
| 77 | return emitSilenceableError() |
| 78 | << "unsupported shared outputs (didn't bufferize?)" ; |
| 79 | } |
| 80 | |
| 81 | SmallVector<OpFoldResult> lbs = target.getMixedLowerBound(); |
| 82 | |
| 83 | if (getNumResults() != lbs.size()) { |
| 84 | DiagnosedSilenceableFailure diag = |
| 85 | emitSilenceableError() |
| 86 | << "op expects as many results (" << getNumResults() |
| 87 | << ") as payload has induction variables (" << lbs.size() << ")" ; |
| 88 | diag.attachNote(loc: target.getLoc()) << "payload op" ; |
| 89 | return diag; |
| 90 | } |
| 91 | |
| 92 | SmallVector<Operation *> opResults; |
| 93 | if (failed(Result: scf::forallToForLoop(rewriter, forallOp: target, results: &opResults))) { |
| 94 | DiagnosedSilenceableFailure diag = emitSilenceableError() |
| 95 | << "failed to convert forall into for" ; |
| 96 | return diag; |
| 97 | } |
| 98 | |
| 99 | for (auto &&[i, res] : llvm::enumerate(First&: opResults)) { |
| 100 | results.set(value: cast<OpResult>(Val: getTransformed()[i]), ops: {res}); |
| 101 | } |
| 102 | return DiagnosedSilenceableFailure::success(); |
| 103 | } |
| 104 | |
| 105 | //===----------------------------------------------------------------------===// |
| 106 | // ForallToParallelOp |
| 107 | //===----------------------------------------------------------------------===// |
| 108 | |
| 109 | DiagnosedSilenceableFailure |
| 110 | transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter, |
| 111 | transform::TransformResults &results, |
| 112 | transform::TransformState &state) { |
| 113 | auto payload = state.getPayloadOps(value: getTarget()); |
| 114 | if (!llvm::hasSingleElement(C&: payload)) |
| 115 | return emitSilenceableError() << "expected a single payload op" ; |
| 116 | |
| 117 | auto target = dyn_cast<scf::ForallOp>(Val: *payload.begin()); |
| 118 | if (!target) { |
| 119 | DiagnosedSilenceableFailure diag = |
| 120 | emitSilenceableError() << "expected the payload to be scf.forall" ; |
| 121 | diag.attachNote(loc: (*payload.begin())->getLoc()) << "payload op" ; |
| 122 | return diag; |
| 123 | } |
| 124 | |
| 125 | if (!target.getOutputs().empty()) { |
| 126 | return emitSilenceableError() |
| 127 | << "unsupported shared outputs (didn't bufferize?)" ; |
| 128 | } |
| 129 | |
| 130 | if (getNumResults() != 1) { |
| 131 | DiagnosedSilenceableFailure diag = emitSilenceableError() |
| 132 | << "op expects one result, given " |
| 133 | << getNumResults(); |
| 134 | diag.attachNote(loc: target.getLoc()) << "payload op" ; |
| 135 | return diag; |
| 136 | } |
| 137 | |
| 138 | scf::ParallelOp opResult; |
| 139 | if (failed(Result: scf::forallToParallelLoop(rewriter, forallOp: target, result: &opResult))) { |
| 140 | DiagnosedSilenceableFailure diag = |
| 141 | emitSilenceableError() << "failed to convert forall into parallel" ; |
| 142 | return diag; |
| 143 | } |
| 144 | |
| 145 | results.set(value: cast<OpResult>(Val: getTransformed()[0]), ops: {opResult}); |
| 146 | return DiagnosedSilenceableFailure::success(); |
| 147 | } |
| 148 | |
| 149 | //===----------------------------------------------------------------------===// |
| 150 | // LoopOutlineOp |
| 151 | //===----------------------------------------------------------------------===// |
| 152 | |
| 153 | /// Wraps the given operation `op` into an `scf.execute_region` operation. Uses |
| 154 | /// the provided rewriter for all operations to remain compatible with the |
| 155 | /// rewriting infra, as opposed to just splicing the op in place. |
| 156 | static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b, |
| 157 | Operation *op) { |
| 158 | if (op->getNumRegions() != 1) |
| 159 | return nullptr; |
| 160 | OpBuilder::InsertionGuard g(b); |
| 161 | b.setInsertionPoint(op); |
| 162 | scf::ExecuteRegionOp executeRegionOp = |
| 163 | b.create<scf::ExecuteRegionOp>(location: op->getLoc(), args: op->getResultTypes()); |
| 164 | { |
| 165 | OpBuilder::InsertionGuard g(b); |
| 166 | b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock()); |
| 167 | Operation *clonedOp = b.cloneWithoutRegions(op&: *op); |
| 168 | Region &clonedRegion = clonedOp->getRegions().front(); |
| 169 | assert(clonedRegion.empty() && "expected empty region" ); |
| 170 | b.inlineRegionBefore(region&: op->getRegions().front(), parent&: clonedRegion, |
| 171 | before: clonedRegion.end()); |
| 172 | b.create<scf::YieldOp>(location: op->getLoc(), args: clonedOp->getResults()); |
| 173 | } |
| 174 | b.replaceOp(op, newValues: executeRegionOp.getResults()); |
| 175 | return executeRegionOp; |
| 176 | } |
| 177 | |
| 178 | DiagnosedSilenceableFailure |
| 179 | transform::LoopOutlineOp::apply(transform::TransformRewriter &rewriter, |
| 180 | transform::TransformResults &results, |
| 181 | transform::TransformState &state) { |
| 182 | SmallVector<Operation *> functions; |
| 183 | SmallVector<Operation *> calls; |
| 184 | DenseMap<Operation *, SymbolTable> symbolTables; |
| 185 | for (Operation *target : state.getPayloadOps(value: getTarget())) { |
| 186 | Location location = target->getLoc(); |
| 187 | Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from: target); |
| 188 | scf::ExecuteRegionOp exec = wrapInExecuteRegion(b&: rewriter, op: target); |
| 189 | if (!exec) { |
| 190 | DiagnosedSilenceableFailure diag = emitSilenceableError() |
| 191 | << "failed to outline" ; |
| 192 | diag.attachNote(loc: target->getLoc()) << "target op" ; |
| 193 | return diag; |
| 194 | } |
| 195 | func::CallOp call; |
| 196 | FailureOr<func::FuncOp> outlined = outlineSingleBlockRegion( |
| 197 | rewriter, loc: location, region&: exec.getRegion(), funcName: getFuncName(), callOp: &call); |
| 198 | |
| 199 | if (failed(Result: outlined)) |
| 200 | return emitDefaultDefiniteFailure(target); |
| 201 | |
| 202 | if (symbolTableOp) { |
| 203 | SymbolTable &symbolTable = |
| 204 | symbolTables.try_emplace(Key: symbolTableOp, Args&: symbolTableOp) |
| 205 | .first->getSecond(); |
| 206 | symbolTable.insert(symbol: *outlined); |
| 207 | call.setCalleeAttr(FlatSymbolRefAttr::get(symbol: *outlined)); |
| 208 | } |
| 209 | functions.push_back(Elt: *outlined); |
| 210 | calls.push_back(Elt: call); |
| 211 | } |
| 212 | results.set(value: cast<OpResult>(Val: getFunction()), ops&: functions); |
| 213 | results.set(value: cast<OpResult>(Val: getCall()), ops&: calls); |
| 214 | return DiagnosedSilenceableFailure::success(); |
| 215 | } |
| 216 | |
| 217 | //===----------------------------------------------------------------------===// |
| 218 | // LoopPeelOp |
| 219 | //===----------------------------------------------------------------------===// |
| 220 | |
| 221 | DiagnosedSilenceableFailure |
| 222 | transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter, |
| 223 | scf::ForOp target, |
| 224 | transform::ApplyToEachResultList &results, |
| 225 | transform::TransformState &state) { |
| 226 | scf::ForOp result; |
| 227 | if (getPeelFront()) { |
| 228 | LogicalResult status = |
| 229 | scf::peelForLoopFirstIteration(rewriter, forOp: target, partialIteration&: result); |
| 230 | if (failed(Result: status)) { |
| 231 | DiagnosedSilenceableFailure diag = |
| 232 | emitSilenceableError() << "failed to peel the first iteration" ; |
| 233 | return diag; |
| 234 | } |
| 235 | } else { |
| 236 | LogicalResult status = |
| 237 | scf::peelForLoopAndSimplifyBounds(rewriter, forOp: target, partialIteration&: result); |
| 238 | if (failed(Result: status)) { |
| 239 | DiagnosedSilenceableFailure diag = emitSilenceableError() |
| 240 | << "failed to peel the last iteration" ; |
| 241 | return diag; |
| 242 | } |
| 243 | } |
| 244 | |
| 245 | results.push_back(op: target); |
| 246 | results.push_back(op: result); |
| 247 | |
| 248 | return DiagnosedSilenceableFailure::success(); |
| 249 | } |
| 250 | |
| 251 | //===----------------------------------------------------------------------===// |
| 252 | // LoopPipelineOp |
| 253 | //===----------------------------------------------------------------------===// |
| 254 | |
| 255 | /// Callback for PipeliningOption. Populates `schedule` with the mapping from an |
| 256 | /// operation to its logical time position given the iteration interval and the |
| 257 | /// read latency. The latter is only relevant for vector transfers. |
| 258 | static void |
| 259 | loopScheduling(scf::ForOp forOp, |
| 260 | std::vector<std::pair<Operation *, unsigned>> &schedule, |
| 261 | unsigned iterationInterval, unsigned readLatency) { |
| 262 | auto getLatency = [&](Operation *op) -> unsigned { |
| 263 | if (isa<vector::TransferReadOp>(Val: op)) |
| 264 | return readLatency; |
| 265 | return 1; |
| 266 | }; |
| 267 | |
| 268 | std::optional<int64_t> ubConstant = |
| 269 | getConstantIntValue(ofr: forOp.getUpperBound()); |
| 270 | std::optional<int64_t> lbConstant = |
| 271 | getConstantIntValue(ofr: forOp.getLowerBound()); |
| 272 | DenseMap<Operation *, unsigned> opCycles; |
| 273 | std::map<unsigned, std::vector<Operation *>> wrappedSchedule; |
| 274 | for (Operation &op : forOp.getBody()->getOperations()) { |
| 275 | if (isa<scf::YieldOp>(Val: op)) |
| 276 | continue; |
| 277 | unsigned earlyCycle = 0; |
| 278 | for (Value operand : op.getOperands()) { |
| 279 | Operation *def = operand.getDefiningOp(); |
| 280 | if (!def) |
| 281 | continue; |
| 282 | if (ubConstant && lbConstant) { |
| 283 | unsigned ubInt = ubConstant.value(); |
| 284 | unsigned lbInt = lbConstant.value(); |
| 285 | auto minLatency = std::min(a: ubInt - lbInt - 1, b: getLatency(def)); |
| 286 | earlyCycle = std::max(a: earlyCycle, b: opCycles[def] + minLatency); |
| 287 | } else { |
| 288 | earlyCycle = std::max(a: earlyCycle, b: opCycles[def] + getLatency(def)); |
| 289 | } |
| 290 | } |
| 291 | opCycles[&op] = earlyCycle; |
| 292 | wrappedSchedule[earlyCycle % iterationInterval].push_back(x: &op); |
| 293 | } |
| 294 | for (const auto &it : wrappedSchedule) { |
| 295 | for (Operation *op : it.second) { |
| 296 | unsigned cycle = opCycles[op]; |
| 297 | schedule.emplace_back(args&: op, args: cycle / iterationInterval); |
| 298 | } |
| 299 | } |
| 300 | } |
| 301 | |
| 302 | DiagnosedSilenceableFailure |
| 303 | transform::LoopPipelineOp::applyToOne(transform::TransformRewriter &rewriter, |
| 304 | scf::ForOp target, |
| 305 | transform::ApplyToEachResultList &results, |
| 306 | transform::TransformState &state) { |
| 307 | scf::PipeliningOption options; |
| 308 | options.getScheduleFn = |
| 309 | [this](scf::ForOp forOp, |
| 310 | std::vector<std::pair<Operation *, unsigned>> &schedule) mutable { |
| 311 | loopScheduling(forOp, schedule, iterationInterval: getIterationInterval(), |
| 312 | readLatency: getReadLatency()); |
| 313 | }; |
| 314 | scf::ForLoopPipeliningPattern pattern(options, target->getContext()); |
| 315 | rewriter.setInsertionPoint(target); |
| 316 | FailureOr<scf::ForOp> patternResult = |
| 317 | scf::pipelineForLoop(rewriter, forOp: target, options); |
| 318 | if (succeeded(Result: patternResult)) { |
| 319 | results.push_back(op: *patternResult); |
| 320 | return DiagnosedSilenceableFailure::success(); |
| 321 | } |
| 322 | return emitDefaultSilenceableFailure(target); |
| 323 | } |
| 324 | |
| 325 | //===----------------------------------------------------------------------===// |
| 326 | // LoopPromoteIfOneIterationOp |
| 327 | //===----------------------------------------------------------------------===// |
| 328 | |
| 329 | DiagnosedSilenceableFailure transform::LoopPromoteIfOneIterationOp::applyToOne( |
| 330 | transform::TransformRewriter &rewriter, LoopLikeOpInterface target, |
| 331 | transform::ApplyToEachResultList &results, |
| 332 | transform::TransformState &state) { |
| 333 | (void)target.promoteIfSingleIteration(rewriter); |
| 334 | return DiagnosedSilenceableFailure::success(); |
| 335 | } |
| 336 | |
| 337 | void transform::LoopPromoteIfOneIterationOp::getEffects( |
| 338 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| 339 | consumesHandle(handles: getTargetMutable(), effects); |
| 340 | modifiesPayload(effects); |
| 341 | } |
| 342 | |
| 343 | //===----------------------------------------------------------------------===// |
| 344 | // LoopUnrollOp |
| 345 | //===----------------------------------------------------------------------===// |
| 346 | |
| 347 | DiagnosedSilenceableFailure |
| 348 | transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter, |
| 349 | Operation *op, |
| 350 | transform::ApplyToEachResultList &results, |
| 351 | transform::TransformState &state) { |
| 352 | LogicalResult result(failure()); |
| 353 | if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(Val: op)) |
| 354 | result = loopUnrollByFactor(forOp: scfFor, unrollFactor: getFactor()); |
| 355 | else if (AffineForOp affineFor = dyn_cast<AffineForOp>(Val: op)) |
| 356 | result = loopUnrollByFactor(forOp: affineFor, unrollFactor: getFactor()); |
| 357 | else |
| 358 | return emitSilenceableError() |
| 359 | << "failed to unroll, incorrect type of payload" ; |
| 360 | |
| 361 | if (failed(Result: result)) |
| 362 | return emitSilenceableError() << "failed to unroll" ; |
| 363 | |
| 364 | return DiagnosedSilenceableFailure::success(); |
| 365 | } |
| 366 | |
| 367 | //===----------------------------------------------------------------------===// |
| 368 | // LoopUnrollAndJamOp |
| 369 | //===----------------------------------------------------------------------===// |
| 370 | |
| 371 | DiagnosedSilenceableFailure transform::LoopUnrollAndJamOp::applyToOne( |
| 372 | transform::TransformRewriter &rewriter, Operation *op, |
| 373 | transform::ApplyToEachResultList &results, |
| 374 | transform::TransformState &state) { |
| 375 | LogicalResult result(failure()); |
| 376 | if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(Val: op)) |
| 377 | result = loopUnrollJamByFactor(forOp: scfFor, unrollFactor: getFactor()); |
| 378 | else if (AffineForOp affineFor = dyn_cast<AffineForOp>(Val: op)) |
| 379 | result = loopUnrollJamByFactor(forOp: affineFor, unrollJamFactor: getFactor()); |
| 380 | else |
| 381 | return emitSilenceableError() |
| 382 | << "failed to unroll and jam, incorrect type of payload" ; |
| 383 | |
| 384 | if (failed(Result: result)) |
| 385 | return emitSilenceableError() << "failed to unroll and jam" ; |
| 386 | |
| 387 | return DiagnosedSilenceableFailure::success(); |
| 388 | } |
| 389 | |
| 390 | //===----------------------------------------------------------------------===// |
| 391 | // LoopCoalesceOp |
| 392 | //===----------------------------------------------------------------------===// |
| 393 | |
| 394 | DiagnosedSilenceableFailure |
| 395 | transform::LoopCoalesceOp::applyToOne(transform::TransformRewriter &rewriter, |
| 396 | Operation *op, |
| 397 | transform::ApplyToEachResultList &results, |
| 398 | transform::TransformState &state) { |
| 399 | LogicalResult result(failure()); |
| 400 | if (scf::ForOp scfForOp = dyn_cast<scf::ForOp>(Val: op)) |
| 401 | result = coalescePerfectlyNestedSCFForLoops(op: scfForOp); |
| 402 | else if (AffineForOp affineForOp = dyn_cast<AffineForOp>(Val: op)) |
| 403 | result = coalescePerfectlyNestedAffineLoops(op: affineForOp); |
| 404 | |
| 405 | results.push_back(op); |
| 406 | if (failed(Result: result)) { |
| 407 | DiagnosedSilenceableFailure diag = emitSilenceableError() |
| 408 | << "failed to coalesce" ; |
| 409 | return diag; |
| 410 | } |
| 411 | return DiagnosedSilenceableFailure::success(); |
| 412 | } |
| 413 | |
| 414 | //===----------------------------------------------------------------------===// |
| 415 | // TakeAssumedBranchOp |
| 416 | //===----------------------------------------------------------------------===// |
| 417 | /// Replaces the given op with the contents of the given single-block region, |
| 418 | /// using the operands of the block terminator to replace operation results. |
| 419 | static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op, |
| 420 | Region ®ion) { |
| 421 | assert(llvm::hasSingleElement(region) && "expected single-region block" ); |
| 422 | Block *block = ®ion.front(); |
| 423 | Operation *terminator = block->getTerminator(); |
| 424 | ValueRange results = terminator->getOperands(); |
| 425 | rewriter.inlineBlockBefore(source: block, op, /*blockArgs=*/argValues: {}); |
| 426 | rewriter.replaceOp(op, newValues: results); |
| 427 | rewriter.eraseOp(op: terminator); |
| 428 | } |
| 429 | |
| 430 | DiagnosedSilenceableFailure transform::TakeAssumedBranchOp::applyToOne( |
| 431 | transform::TransformRewriter &rewriter, scf::IfOp ifOp, |
| 432 | transform::ApplyToEachResultList &results, |
| 433 | transform::TransformState &state) { |
| 434 | rewriter.setInsertionPoint(ifOp); |
| 435 | Region ®ion = |
| 436 | getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion(); |
| 437 | if (!llvm::hasSingleElement(C&: region)) { |
| 438 | return emitDefiniteFailure() |
| 439 | << "requires an scf.if op with a single-block " |
| 440 | << ((getTakeElseBranch()) ? "`else`" : "`then`" ) << " region" ; |
| 441 | } |
| 442 | replaceOpWithRegion(rewriter, op: ifOp, region); |
| 443 | return DiagnosedSilenceableFailure::success(); |
| 444 | } |
| 445 | |
| 446 | void transform::TakeAssumedBranchOp::getEffects( |
| 447 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| 448 | onlyReadsHandle(handles: getTargetMutable(), effects); |
| 449 | modifiesPayload(effects); |
| 450 | } |
| 451 | |
| 452 | //===----------------------------------------------------------------------===// |
| 453 | // LoopFuseSiblingOp |
| 454 | //===----------------------------------------------------------------------===// |
| 455 | |
| 456 | /// Check if `target` and `source` are siblings, in the context that `target` |
| 457 | /// is being fused into `source`. |
| 458 | /// |
| 459 | /// This is a simple check that just checks if both operations are in the same |
| 460 | /// block and some checks to ensure that the fused IR does not violate |
| 461 | /// dominance. |
| 462 | static DiagnosedSilenceableFailure isOpSibling(Operation *target, |
| 463 | Operation *source) { |
| 464 | // Check if both operations are same. |
| 465 | if (target == source) |
| 466 | return emitSilenceableFailure(op: source) |
| 467 | << "target and source need to be different loops" ; |
| 468 | |
| 469 | // Check if both operations are in the same block. |
| 470 | if (target->getBlock() != source->getBlock()) |
| 471 | return emitSilenceableFailure(op: source) |
| 472 | << "target and source are not in the same block" ; |
| 473 | |
| 474 | // Check if fusion will violate dominance. |
| 475 | DominanceInfo domInfo(source); |
| 476 | if (target->isBeforeInBlock(other: source)) { |
| 477 | // Since `target` is before `source`, all users of results of `target` |
| 478 | // need to be dominated by `source`. |
| 479 | for (Operation *user : target->getUsers()) { |
| 480 | if (!domInfo.properlyDominates(a: source, b: user, /*enclosingOpOk=*/false)) { |
| 481 | return emitSilenceableFailure(op: target) |
| 482 | << "user of results of target should be properly dominated by " |
| 483 | "source" ; |
| 484 | } |
| 485 | } |
| 486 | } else { |
| 487 | // Since `target` is after `source`, all values used by `target` need |
| 488 | // to dominate `source`. |
| 489 | |
| 490 | // Check if operands of `target` are dominated by `source`. |
| 491 | for (Value operand : target->getOperands()) { |
| 492 | Operation *operandOp = operand.getDefiningOp(); |
| 493 | // Operands without defining operations are block arguments. When `target` |
| 494 | // and `source` occur in the same block, these operands dominate `source`. |
| 495 | if (!operandOp) |
| 496 | continue; |
| 497 | |
| 498 | // Operand's defining operation should properly dominate `source`. |
| 499 | if (!domInfo.properlyDominates(a: operandOp, b: source, |
| 500 | /*enclosingOpOk=*/false)) |
| 501 | return emitSilenceableFailure(op: target) |
| 502 | << "operands of target should be properly dominated by source" ; |
| 503 | } |
| 504 | |
| 505 | // Check if values used by `target` are dominated by `source`. |
| 506 | bool failed = false; |
| 507 | OpOperand *failedValue = nullptr; |
| 508 | visitUsedValuesDefinedAbove(regions: target->getRegions(), callback: [&](OpOperand *operand) { |
| 509 | Operation *operandOp = operand->get().getDefiningOp(); |
| 510 | if (operandOp && !domInfo.properlyDominates(a: operandOp, b: source, |
| 511 | /*enclosingOpOk=*/false)) { |
| 512 | // `operand` is not an argument of an enclosing block and the defining |
| 513 | // op of `operand` is outside `target` but does not dominate `source`. |
| 514 | failed = true; |
| 515 | failedValue = operand; |
| 516 | } |
| 517 | }); |
| 518 | |
| 519 | if (failed) |
| 520 | return emitSilenceableFailure(op: failedValue->getOwner()) |
| 521 | << "values used inside regions of target should be properly " |
| 522 | "dominated by source" ; |
| 523 | } |
| 524 | |
| 525 | return DiagnosedSilenceableFailure::success(); |
| 526 | } |
| 527 | |
| 528 | /// Check if `target` scf.forall can be fused into `source` scf.forall. |
| 529 | /// |
| 530 | /// This simply checks if both loops have the same bounds, steps and mapping. |
| 531 | /// No attempt is made at checking that the side effects of `target` and |
| 532 | /// `source` are independent of each other. |
| 533 | static bool isForallWithIdenticalConfiguration(Operation *target, |
| 534 | Operation *source) { |
| 535 | auto targetOp = dyn_cast<scf::ForallOp>(Val: target); |
| 536 | auto sourceOp = dyn_cast<scf::ForallOp>(Val: source); |
| 537 | if (!targetOp || !sourceOp) |
| 538 | return false; |
| 539 | |
| 540 | return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() && |
| 541 | targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() && |
| 542 | targetOp.getMixedStep() == sourceOp.getMixedStep() && |
| 543 | targetOp.getMapping() == sourceOp.getMapping(); |
| 544 | } |
| 545 | |
| 546 | /// Check if `target` scf.for can be fused into `source` scf.for. |
| 547 | /// |
| 548 | /// This simply checks if both loops have the same bounds and steps. No attempt |
| 549 | /// is made at checking that the side effects of `target` and `source` are |
| 550 | /// independent of each other. |
| 551 | static bool isForWithIdenticalConfiguration(Operation *target, |
| 552 | Operation *source) { |
| 553 | auto targetOp = dyn_cast<scf::ForOp>(Val: target); |
| 554 | auto sourceOp = dyn_cast<scf::ForOp>(Val: source); |
| 555 | if (!targetOp || !sourceOp) |
| 556 | return false; |
| 557 | |
| 558 | return targetOp.getLowerBound() == sourceOp.getLowerBound() && |
| 559 | targetOp.getUpperBound() == sourceOp.getUpperBound() && |
| 560 | targetOp.getStep() == sourceOp.getStep(); |
| 561 | } |
| 562 | |
| 563 | DiagnosedSilenceableFailure |
| 564 | transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter, |
| 565 | transform::TransformResults &results, |
| 566 | transform::TransformState &state) { |
| 567 | auto targetOps = state.getPayloadOps(value: getTarget()); |
| 568 | auto sourceOps = state.getPayloadOps(value: getSource()); |
| 569 | |
| 570 | if (!llvm::hasSingleElement(C&: targetOps) || |
| 571 | !llvm::hasSingleElement(C&: sourceOps)) { |
| 572 | return emitDefiniteFailure() |
| 573 | << "requires exactly one target handle (got " |
| 574 | << llvm::range_size(Range&: targetOps) << ") and exactly one " |
| 575 | << "source handle (got " << llvm::range_size(Range&: sourceOps) << ")" ; |
| 576 | } |
| 577 | |
| 578 | Operation *target = *targetOps.begin(); |
| 579 | Operation *source = *sourceOps.begin(); |
| 580 | |
| 581 | // Check if the target and source are siblings. |
| 582 | DiagnosedSilenceableFailure diag = isOpSibling(target, source); |
| 583 | if (!diag.succeeded()) |
| 584 | return diag; |
| 585 | |
| 586 | Operation *fusedLoop; |
| 587 | /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall. |
| 588 | if (isForWithIdenticalConfiguration(target, source)) { |
| 589 | fusedLoop = fuseIndependentSiblingForLoops( |
| 590 | target: cast<scf::ForOp>(Val: target), source: cast<scf::ForOp>(Val: source), rewriter); |
| 591 | } else if (isForallWithIdenticalConfiguration(target, source)) { |
| 592 | fusedLoop = fuseIndependentSiblingForallLoops( |
| 593 | target: cast<scf::ForallOp>(Val: target), source: cast<scf::ForallOp>(Val: source), rewriter); |
| 594 | } else { |
| 595 | return emitSilenceableFailure(loc: target->getLoc()) |
| 596 | << "operations cannot be fused" ; |
| 597 | } |
| 598 | |
| 599 | assert(fusedLoop && "failed to fuse operations" ); |
| 600 | |
| 601 | results.set(value: cast<OpResult>(Val: getFusedLoop()), ops: {fusedLoop}); |
| 602 | return DiagnosedSilenceableFailure::success(); |
| 603 | } |
| 604 | |
| 605 | //===----------------------------------------------------------------------===// |
| 606 | // Transform op registration |
| 607 | //===----------------------------------------------------------------------===// |
| 608 | |
| 609 | namespace { |
| 610 | class SCFTransformDialectExtension |
| 611 | : public transform::TransformDialectExtension< |
| 612 | SCFTransformDialectExtension> { |
| 613 | public: |
| 614 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SCFTransformDialectExtension) |
| 615 | |
| 616 | using Base::Base; |
| 617 | |
| 618 | void init() { |
| 619 | declareGeneratedDialect<affine::AffineDialect>(); |
| 620 | declareGeneratedDialect<func::FuncDialect>(); |
| 621 | |
| 622 | registerTransformOps< |
| 623 | #define GET_OP_LIST |
| 624 | #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc" |
| 625 | >(); |
| 626 | } |
| 627 | }; |
| 628 | } // namespace |
| 629 | |
| 630 | #define GET_OP_CLASSES |
| 631 | #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc" |
| 632 | |
| 633 | void mlir::scf::registerTransformDialectExtension(DialectRegistry ®istry) { |
| 634 | registry.addExtensions<SCFTransformDialectExtension>(); |
| 635 | } |
| 636 | |