//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
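///
/// Illustrative example (hypothetical types, not taken from a test case):
/// casting a plain `memref<5xf32>` init buffer to an iter_arg type with a
/// fully dynamic layout, `memref<5xf32, strided<[?], offset: ?>>`, is a
/// compatible `memref.cast`. Casting to a type with a different *static*
/// layout (e.g., `strided<[2]>`) may not be cast-compatible, which is what
/// the TODO below refers to.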
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
  assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  // TODO: In case `type` has a layout map that is not the fully dynamic
  // one, we may not be able to cast the buffer. In that case, the loop
  // iter_arg's layout map must be changed (see uses of `castBuffer`).
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}

/// Helper function for loop bufferization. Return "true" if the given value
/// is guaranteed to not alias with an external tensor apart from values in
/// `exceptions`. A value is external if it is defined outside of the given
/// region or if it is an entry block argument of the region.
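///
/// For example (sketch): when resolving conflicts for an `scf.for`, a value
/// yielded from the loop body that aliases a tensor defined above the loop
/// makes this return "false"; aliases listed in `exceptions` (e.g., the
/// matching region iter_arg in the scf.for use below) are ignored.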
static bool doesNotAliasExternalValue(Value value, Region *region,
                                      ValueRange exceptions,
                                      const OneShotAnalysisState &state) {
  assert(llvm::hasSingleElement(region->getBlocks()) &&
         "expected region with single block");
  bool result = true;
  state.applyOnAliases(value, [&](Value alias) {
    if (llvm::is_contained(exceptions, alias))
      return;
    Region *aliasRegion = alias.getParentRegion();
    if (isa<BlockArgument>(alias) && !region->isProperAncestor(aliasRegion))
      result = false;
    if (isa<OpResult>(alias) && !region->isAncestor(aliasRegion))
      result = false;
  });
  return result;
}

/// Bufferization of scf.condition.
struct ConditionOpInterface
    : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
                                                    scf::ConditionOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Condition operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto conditionOp = cast<scf::ConditionOp>(op);
    auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());

    SmallVector<Value> newArgs;
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer =
            getBuffer(rewriter, value, options, state);
        if (failed(maybeBuffer))
          return failure();
        FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
            whileOp.getAfterArguments()[it.index()], options, state);
        if (failed(resultType))
          return failure();
        Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
        newArgs.push_back(buffer);
      } else {
        newArgs.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::ConditionOp>(
        rewriter, op, conditionOp.getCondition(), newArgs);
    return success();
  }
};

/// Return the unique scf.yield op. If there are multiple or no scf.yield ops,
/// return an empty op.
static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
  scf::YieldOp result;
  for (Block &block : executeRegionOp.getRegion()) {
    if (auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) {
      if (result)
        return {};
      result = yieldOp;
    }
  }
  return result;
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
          ExecuteRegionOpInterface, scf::ExecuteRegionOp> {

  static bool supportsUnstructuredControlFlow() { return true; }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    // TODO: scf.execute_region with multiple yields are not supported.
    if (!getUniqueYieldOp(executeRegionOp))
      return op->emitOpError("op without unique scf.yield is not supported");
    return success();
  }

  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    if (auto bbArg = dyn_cast<BlockArgument>(value))
      return getAliasingBranchOpOperands(op, bbArg, state);

    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto it = llvm::find(op->getOpResults(), value);
    assert(it != op->getOpResults().end() && "invalid value");
    size_t resultNum = std::distance(op->getOpResults().begin(), it);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    // Note: If there is no unique scf.yield op, `verifyAnalysis` will fail.
    if (!yieldOp)
      return {};
    return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    TypeRange newResultTypes(yieldOp.getResults());

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Bufferize every block.
    for (Block &block : newOp.getRegion())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                        options, state)))
        return failure();

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (isa<TensorType>(it.value())) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), it.value(),
            newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the
    // then/else branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), value));
    OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
    OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
    return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false},
            {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : ifOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options, state);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(ifOp);
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
    assert(value.getDefiningOp() == op && "invalid value");

    // Determine buffer types of the true/false branches.
    auto opResult = cast<OpResult>(value);
    auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
    auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
    BaseMemRefType thenBufferType, elseBufferType;
    if (isa<BaseMemRefType>(thenValue.getType())) {
      // True branch was already bufferized.
      thenBufferType = cast<BaseMemRefType>(thenValue.getType());
    } else {
      auto maybeBufferType = bufferization::getBufferType(
          thenValue, options, state, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      thenBufferType = *maybeBufferType;
    }
    if (isa<BaseMemRefType>(elseValue.getType())) {
      // False branch was already bufferized.
      elseBufferType = cast<BaseMemRefType>(elseValue.getType());
    } else {
      auto maybeBufferType = bufferization::getBufferType(
          elseValue, options, state, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      elseBufferType = *maybeBufferType;
    }

    // Best case: Both branches have the exact same buffer type.
    if (thenBufferType == elseBufferType)
      return thenBufferType;

    // Memory space mismatch.
    if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
      return op->emitError("inconsistent memory space on then/else branches");

    // Layout maps are different: Promote to fully dynamic layout map.
    return getMemRefTypeWithFullyDynamicLayout(
        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
  }
};

/// Bufferization of scf.index_switch. Replace with a new scf.index_switch that
/// yields memrefs.
struct IndexSwitchOpInterface
    : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
                                                    scf::IndexSwitchOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IndexSwitchOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. This is similar to IfOps.
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    int64_t resultNum = cast<OpResult>(value).getResultNumber();
    AliasingOpOperandList result;
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldOp =
          cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
      result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum),
                                        BufferRelation::Equivalent,
                                        /*isDefinite=*/false));
    }
    auto defaultYieldOp =
        cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());
    result.addAlias(AliasingOpOperand(&defaultYieldOp->getOpOperand(resultNum),
                                      BufferRelation::Equivalent,
                                      /*isDefinite=*/false));
    return result;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto switchOp = cast<scf::IndexSwitchOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : switchOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options, state);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(switchOp);
    auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>(
        switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(),
        switchOp.getCases().size());

    // Move over blocks.
    for (auto [src, dest] :
         llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
      rewriter.inlineRegionBefore(src, dest, dest.begin());
    rewriter.inlineRegionBefore(switchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion().begin());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    assert(value.getDefiningOp() == op && "invalid value");
    int64_t resultNum = cast<OpResult>(value).getResultNumber();

    // Helper function to get buffer type of a case.
    auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
      auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
      Value yieldedValue = yieldOp->getOperand(resultNum);
      if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
        return bufferType;
      auto maybeBufferType = bufferization::getBufferType(
          yieldedValue, options, state, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      return maybeBufferType;
    };

    // Compute buffer type of the default case.
    auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
    if (failed(maybeBufferType))
      return failure();
    BaseMemRefType bufferType = *maybeBufferType;

    // Compute buffer types of all other cases.
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
      if (failed(yieldedBufferType))
        return failure();

      // Best case: Both branches have the exact same buffer type.
      if (bufferType == *yieldedBufferType)
        continue;

      // Memory space mismatch.
      if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
        return op->emitError("inconsistent memory space on switch cases");

      // Layout maps are different: Promote to fully dynamic layout map.
      bufferType = getMemRefTypeWithFullyDynamicLayout(
          cast<TensorType>(value.getType()), bufferType.getMemorySpace());
    }

    return bufferType;
  }
};

/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (isa<TensorType>(it.value().getType()))
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!isa<TensorType>(bbArgs[i].getType()) ||
        !isa<TensorType>(yieldedValues[i].getType()))
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}

/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
           const BufferizationOptions &options, BufferizationState &state) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (isa<TensorType>(opOperand.get().getType())) {
      FailureOr<Value> resultBuffer =
          getBuffer(rewriter, opOperand.get(), options, state);
      if (failed(resultBuffer))
        return failure();
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}

/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
static SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     Block::BlockArgListType oldBbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(bbArgs)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      result.push_back(rewriter
                           .create<bufferization::ToTensorOp>(
                               val.getLoc(), oldBbArgs[idx].getType(), val)
                           .getResult());
    } else {
      result.push_back(val);
    }
  }
  return result;
}

/// Compute the bufferized type of a loop iter_arg. This type must be equal to
/// the bufferized type of the corresponding init_arg and the bufferized type
/// of the corresponding yielded value.
///
/// This function uses bufferization::getBufferType to compute the bufferized
/// type of the init_arg and of the yielded value. (The computation of the
/// bufferized yielded value type usually requires computing the bufferized
/// type of the iter_arg again; the implementation of getBufferType traces back
/// the use-def chain of the given value and computes a buffer type along the
/// way.) If both buffer types are equal, no casts are needed and the computed
/// buffer type can be used directly. Otherwise, the buffer types can only
/// differ in their layout map and a cast must be inserted.
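///
/// Illustrative example (hypothetical types): if the init_arg bufferizes to
/// `memref<8xf32>` but the yielded value bufferizes to
/// `memref<8xf32, strided<[?], offset: ?>>`, the iter_arg is assigned the
/// fully dynamic layout type `memref<8xf32, strided<[?], offset: ?>>`, and
/// the init buffer is later casted to that type (see `castBuffer`).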
static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
    Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
    const BufferizationOptions &options, const BufferizationState &state,
    SmallVector<Value> &invocationStack) {
  // Determine the buffer type of the init_arg.
  auto initArgBufferType =
      bufferization::getBufferType(initArg, options, state, invocationStack);
  if (failed(initArgBufferType))
    return failure();

  if (llvm::count(invocationStack, iterArg) >= 2) {
    // If the iter_arg is already twice on the invocation stack, just take the
    // type of the init_arg. This is to avoid infinite loops when calculating
    // the buffer type. This will most likely result in computing a memref type
    // with a fully dynamic layout map.

    // Note: For more precise layout map computation, a fixpoint iteration could
    // be done (i.e., re-computing the yielded buffer type until the bufferized
    // iter_arg type no longer changes). This current implementation immediately
    // switches to a fully dynamic layout map when a mismatch between bufferized
    // init_arg type and bufferized yield value type is detected.
    return *initArgBufferType;
  }

  // Compute the buffer type of the yielded value.
  BaseMemRefType yieldedValueBufferType;
  if (isa<BaseMemRefType>(yieldedValue.getType())) {
    // scf.yield was already bufferized.
    yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
  } else {
    // Note: This typically triggers a recursive call for the buffer type of
    // the iter_arg.
    auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
                                                        state, invocationStack);
    if (failed(maybeBufferType))
      return failure();
    yieldedValueBufferType = *maybeBufferType;
  }

  // If yielded type and init_arg type are the same, use that type directly.
  if (*initArgBufferType == yieldedValueBufferType)
    return yieldedValueBufferType;

  // If there is a mismatch between the yielded buffer type and the init_arg
  // buffer type, the buffer type must be promoted to a fully dynamic layout
  // map.
  auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType);
  auto iterTensorType = cast<TensorType>(iterArg.getType());
  auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType);
  if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
    return loopOp->emitOpError(
        "init_arg and yielded value bufferize to inconsistent memory spaces");
#ifndef NDEBUG
  if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
    assert(
        llvm::all_equal({yieldedRankedBufferType.getShape(),
                         cast<MemRefType>(initBufferType).getShape(),
                         cast<RankedTensorType>(iterTensorType).getShape()}) &&
        "expected same shape");
  }
#endif // NDEBUG
  return getMemRefTypeWithFullyDynamicLayout(
      iterTensorType, yieldedBufferType.getMemorySpace());
}

/// Return `true` if the given loop may have 0 iterations.
bool mayHaveZeroIterations(scf::ForOp forOp) {
  std::optional<int64_t> lb = getConstantIntValue(forOp.getLowerBound());
  std::optional<int64_t> ub = getConstantIntValue(forOp.getUpperBound());
  if (!lb.has_value() || !ub.has_value())
    return true;
  return *ub <= *lb;
}

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
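///
/// Rough sketch of the rewrite (the IR below is illustrative only, not taken
/// from a test case):
///
///   %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %init)
///       -> (tensor<?xf32>) { ... scf.yield %yielded : tensor<?xf32> }
///
/// becomes an scf.for whose iter_args are memrefs; inside the body, the
/// memref iter_args are wrapped in bufferization.to_tensor ops, and the old
/// tensor results are replaced via to_tensor ops as well.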
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);

    // If the loop has zero iterations, the results of the op are their
    // corresponding init_args, meaning that the init_args bufferize to a read.
    if (mayHaveZeroIterations(forOp))
      return true;

    // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
    // its matching bbArg may.
    return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    OpResult opResult = forOp.getTiedLoopResult(&opOperand);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
    return equivalentYield ? BufferRelation::Equivalent
                           : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult
  resolveConflicts(Operation *op, RewriterBase &rewriter,
                   const AnalysisState &analysisState,
                   const BufferizationState &bufferizationState) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
            rewriter, analysisState, bufferizationState)))
      return failure();

    if (analysisState.getOptions().copyBeforeWrite)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg (or with a
    // newly allocated buffer) and not with other buffers defined outside of the
    // loop. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand.
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, does it alias with something defined outside of
    // the loop?
    SmallVector<Value> yieldValues;
    for (const auto it : llvm::enumerate(yieldOp.getResults())) {
      // Note: `state` is guaranteed to be a `OneShotAnalysisState`, but this
      // type cannot be used in the signature of `resolveConflicts` because the
      // op interface is in the "IR" build unit and the `OneShotAnalysisState`
      // is defined in the "Transforms" build unit.
      if (!indices.contains(it.index()) ||
          doesNotAliasExternalValue(
              it.value(), &forOp.getRegion(),
              /*exceptions=*/forOp.getRegionIterArg(it.index()),
              static_cast<const OneShotAnalysisState &>(analysisState))) {
        yieldValues.push_back(it.value());
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, yieldOp.getLoc(), it.value(), analysisState.getOptions(),
          bufferizationState);
      if (failed(alloc))
        return failure();
      yieldValues.push_back(*alloc);
    }

    rewriter.modifyOpInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto forOp = cast<scf::ForOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    if (auto opResult = dyn_cast<OpResult>(value)) {
      // The type of an OpResult must match the corresponding iter_arg type.
      BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
      return bufferization::getBufferType(bbArg, options, state,
                                          invocationStack);
    }

    // Compute result/argument number.
    BlockArgument bbArg = cast<BlockArgument>(value);
    unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();

    // Compute the bufferized type.
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    Value yieldedValue = yieldOp.getOperand(resultNum);
    BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
    Value initArg = forOp.getInitArgs()[resultNum];
    return computeLoopRegionIterArgBufferType(
        op, iterArg, initArg, yieldedValue, options, state, invocationStack);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = forOp.getBody();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, forOp.getInitArgsMutable(), options, state);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value result = forOp->getResult(it.index());
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(result.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(result, options, state);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), castedInitArgs);
    newForOp->setAttrs(forOp->getAttrs());
    Block *loopBody = newForOp.getBody();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(),
                             forOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs + copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs-from-loops`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!isa<TensorType>(opResult.getType()))
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match. They may not even have the
    // same type. The number of OpResults and OpOperands can also differ.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    OpResult opResult = whileOp->getResult(idx);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::Unknown;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::Unknown;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult
  resolveConflicts(Operation *op, RewriterBase &rewriter,
                   const AnalysisState &analysisState,
                   const BufferizationState &bufferizationState) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
            rewriter, analysisState, bufferizationState)))
      return failure();

    if (analysisState.getOptions().copyBeforeWrite)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
    // (New buffer copies do not alias with any buffer.)
    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), analysisState);
    DenseSet<int64_t> equivalentYieldsAfter =
        getEquivalentBuffers(whileOp.getAfterArguments(),
                             whileOp.getYieldOp().getResults(), analysisState);

    // Update "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!isa<TensorType>(value.getType()) ||
          (equivalentYieldsAfter.contains(idx) &&
           equivalentYieldsBefore.contains(idx))) {
        beforeYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, conditionOp.getLoc(), value, analysisState.getOptions(),
          bufferizationState);
      if (failed(alloc))
        return failure();
      beforeYieldValues.push_back(*alloc);
    }
    rewriter.modifyOpInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp.getInitsMutable(), options, state);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value beforeArg = whileOp.getBeforeArguments()[it.index()];
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(beforeArg.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(beforeArg, options, state);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          if (!isa<TensorType>(bbArg.getType()))
            return bbArg.getType();
          // TODO: error handling
          return llvm::cast<Type>(
              *bufferization::getBufferType(bbArg, options, state));
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(castedInitArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(
        whileOp.getLoc(), argsTypesAfter, castedInitArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
                                          whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs =
        getBbArgReplacements(rewriter, newWhileOp.getBeforeArguments(),
                             whileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs =
        getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(),
                             whileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto whileOp = cast<scf::WhileOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    // Case 1: Block argument of the "before" region.
    if (auto bbArg = dyn_cast<BlockArgument>(value)) {
      if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
        Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
        auto yieldOp = whileOp.getYieldOp();
        Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
        return computeLoopRegionIterArgBufferType(
            op, bbArg, initArg, yieldedValue, options, state, invocationStack);
      }
    }

    // Case 2: OpResult of the loop or block argument of the "after" region.
    // The bufferized "after" bbArg type can be directly computed from the
    // bufferized "before" bbArg type.
    unsigned resultNum;
    if (auto opResult = dyn_cast<OpResult>(value)) {
      resultNum = opResult.getResultNumber();
    } else if (cast<BlockArgument>(value).getOwner()->getParent() ==
               &whileOp.getAfter()) {
      resultNum = cast<BlockArgument>(value).getArgNumber();
    } else {
      llvm_unreachable("invalid value");
    }
    Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
    if (!isa<TensorType>(conditionYieldedVal.getType())) {
      // scf.condition was already bufferized.
      return cast<BaseMemRefType>(conditionYieldedVal.getType());
    }
    return bufferization::getBufferType(conditionYieldedVal, options, state,
                                        invocationStack);
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs + copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs-from-loops`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Block *block = conditionOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Block *block = yieldOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

| 1133 | /// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so |
| 1134 | /// this is for analysis only. |
| 1135 | struct YieldOpInterface |
| 1136 | : public BufferizableOpInterface::ExternalModel<YieldOpInterface, |
| 1137 | scf::YieldOp> { |
| 1138 | bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, |
| 1139 | const AnalysisState &state) const { |
| 1140 | return true; |
| 1141 | } |
| 1142 | |
| 1143 | bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, |
| 1144 | const AnalysisState &state) const { |
| 1145 | return false; |
| 1146 | } |
| 1147 | |
| 1148 | AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, |
| 1149 | const AnalysisState &state) const { |
| 1150 | if (auto ifOp = dyn_cast<scf::IfOp>(op->getParentOp())) { |
| 1151 | return {{op->getParentOp()->getResult(idx: opOperand.getOperandNumber()), |
| 1152 | BufferRelation::Equivalent, /*isDefinite=*/false}}; |
| 1153 | } |
| 1154 | if (isa<scf::ExecuteRegionOp>(op->getParentOp())) |
| 1155 | return {{op->getParentOp()->getResult(idx: opOperand.getOperandNumber()), |
| 1156 | BufferRelation::Equivalent}}; |
| 1157 | return {}; |
| 1158 | } |
| 1159 | |
| 1160 | bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand, |
| 1161 | const AnalysisState &state) const { |
| 1162 | // Yield operands always bufferize inplace. Otherwise, an alloc + copy |
| 1163 | // may be generated inside the block. We should not return/yield allocations |
| 1164 | // when possible. |
| 1165 | return true; |
| 1166 | } |
| 1167 | |
| 1168 | LogicalResult bufferize(Operation *op, RewriterBase &rewriter, |
| 1169 | const BufferizationOptions &options, |
| 1170 | BufferizationState &state) const { |
| 1171 | auto yieldOp = cast<scf::YieldOp>(op); |
| 1172 | if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp, |
| 1173 | scf::WhileOp>(yieldOp->getParentOp())) |
| 1174 | return yieldOp->emitError("unsupported scf::YieldOp parent" ); |
| 1175 | |
| 1176 | SmallVector<Value> newResults; |
| 1177 | for (const auto &it : llvm::enumerate(yieldOp.getResults())) { |
| 1178 | Value value = it.value(); |
| 1179 | if (isa<TensorType>(value.getType())) { |
| 1180 | FailureOr<Value> maybeBuffer = |
| 1181 | getBuffer(rewriter, value, options, state); |
| 1182 | if (failed(maybeBuffer)) |
| 1183 | return failure(); |
| 1184 | Value buffer = *maybeBuffer; |
| 1185 | // We may have to cast the value before yielding it. |
| 1186 | if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>( |
| 1187 | yieldOp->getParentOp())) { |
| 1188 | FailureOr<BaseMemRefType> resultType = bufferization::getBufferType( |
| 1189 | yieldOp->getParentOp()->getResult(it.index()), options, state); |
| 1190 | if (failed(resultType)) |
| 1191 | return failure(); |
| 1192 | buffer = castBuffer(rewriter, buffer, *resultType); |
| 1193 | } else if (auto whileOp = |
| 1194 | dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) { |
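          // For scf.while, the yielded value flows back into the loop, so
          // cast to the buffer type of the corresponding "before" region
          // bbArg.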
| 1195 | FailureOr<BaseMemRefType> resultType = bufferization::getBufferType( |
| 1196 | whileOp.getBeforeArguments()[it.index()], options, state); |
| 1197 | if (failed(resultType)) |
| 1198 | return failure(); |
| 1199 | buffer = castBuffer(rewriter, buffer, *resultType); |
| 1200 | } |
| 1201 | newResults.push_back(buffer); |
| 1202 | } else { |
| 1203 | newResults.push_back(value); |
| 1204 | } |
| 1205 | } |
| 1206 | |
| 1207 | replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults); |
| 1208 | return success(); |
| 1209 | } |
| 1210 | }; |
| 1211 | |
/// Bufferization of ForallOp. This also bufferizes the terminator of the
/// region. There are op interfaces for the terminators (InParallelOp
/// and ParallelInsertSliceOp), but these are used only during analysis, not
/// during bufferization.
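///
/// Illustrative sketch (hypothetical IR, not taken verbatim from a test): an
/// op such as
///
///   %r = scf.forall (%i) in (%ub) shared_outs(%o = %t) -> (tensor<?xf32>) {
///     ...
///   }
///
/// is rewritten into an scf.forall without results whose body operates on the
/// buffer of %t: uses of the shared_out bbArg %o are rewired through a
/// bufferization.to_tensor of that buffer, and all uses of %r are replaced by
/// the buffer.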
| 1216 | struct ForallOpInterface |
| 1217 | : public BufferizableOpInterface::ExternalModel<ForallOpInterface, |
| 1218 | ForallOp> { |
| 1219 | bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, |
| 1220 | const AnalysisState &state) const { |
    // All tensor operands to `scf.forall` are `shared_outs`, and all shared
    // outs are assumed to be read by the loop. This does not account for the
    // case where the entire value is overwritten, but we stay conservative
    // here.
| 1225 | return true; |
| 1226 | } |
| 1227 | |
| 1228 | bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, |
| 1229 | const AnalysisState &state) const { |
    // Outputs of scf::ForallOp are always considered to be written to.
| 1231 | return true; |
| 1232 | } |
| 1233 | |
| 1234 | AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, |
| 1235 | const AnalysisState &state) const { |
| 1236 | auto forallOp = cast<ForallOp>(op); |
| 1237 | return { |
| 1238 | {{forallOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}}}; |
| 1239 | } |
| 1240 | |
| 1241 | bool isWritable(Operation *op, Value value, |
| 1242 | const AnalysisState &state) const { |
| 1243 | return true; |
| 1244 | } |
| 1245 | |
| 1246 | LogicalResult bufferize(Operation *op, RewriterBase &rewriter, |
| 1247 | const BufferizationOptions &options, |
| 1248 | BufferizationState &state) const { |
| 1249 | OpBuilder::InsertionGuard guard(rewriter); |
| 1250 | auto forallOp = cast<ForallOp>(op); |
| 1251 | int64_t rank = forallOp.getRank(); |
| 1252 | |
| 1253 | // Get buffers for all output operands. |
| 1254 | SmallVector<Value> buffers; |
| 1255 | for (Value out : forallOp.getOutputs()) { |
| 1256 | FailureOr<Value> buffer = getBuffer(rewriter, out, options, state); |
| 1257 | if (failed(buffer)) |
| 1258 | return failure(); |
| 1259 | buffers.push_back(*buffer); |
| 1260 | } |
| 1261 | |
| 1262 | // Use buffers instead of block arguments. |
| 1263 | rewriter.setInsertionPointToStart(forallOp.getBody()); |
| 1264 | for (const auto &it : llvm::zip( |
| 1265 | forallOp.getBody()->getArguments().drop_front(rank), buffers)) { |
| 1266 | BlockArgument bbArg = std::get<0>(it); |
| 1267 | Value buffer = std::get<1>(it); |
| 1268 | Value bufferAsTensor = rewriter.create<ToTensorOp>( |
| 1269 | forallOp.getLoc(), bbArg.getType(), buffer); |
| 1270 | bbArg.replaceAllUsesWith(bufferAsTensor); |
| 1271 | } |
| 1272 | |
| 1273 | // Create new ForallOp without any results and drop the automatically |
| 1274 | // introduced terminator. |
| 1275 | rewriter.setInsertionPoint(forallOp); |
    ForallOp newForallOp = rewriter.create<ForallOp>(
| 1278 | forallOp.getLoc(), forallOp.getMixedLowerBound(), |
| 1279 | forallOp.getMixedUpperBound(), forallOp.getMixedStep(), |
| 1280 | /*outputs=*/ValueRange(), forallOp.getMapping()); |
| 1281 | |
| 1282 | // Keep discardable attributes from the original op. |
| 1283 | newForallOp->setDiscardableAttrs(op->getDiscardableAttrDictionary()); |
| 1284 | |
    rewriter.eraseOp(newForallOp.getBody()->getTerminator());
| 1286 | |
| 1287 | // Move over block contents of the old op. |
| 1288 | SmallVector<Value> replacementBbArgs; |
| 1289 | replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(), |
| 1290 | newForallOp.getBody()->getArguments().end()); |
| 1291 | replacementBbArgs.append(forallOp.getOutputs().size(), Value()); |
    rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
                         replacementBbArgs);
| 1294 | |
| 1295 | // Remove the old op and replace all of its uses. |
    replaceOpWithBufferizedValues(rewriter, op, buffers);
| 1297 | |
| 1298 | return success(); |
| 1299 | } |
| 1300 | |
| 1301 | FailureOr<BaseMemRefType> |
| 1302 | getBufferType(Operation *op, Value value, const BufferizationOptions &options, |
| 1303 | const BufferizationState &state, |
| 1304 | SmallVector<Value> &invocationStack) const { |
| 1305 | auto forallOp = cast<ForallOp>(op); |
| 1306 | |
    if (auto bbArg = dyn_cast<BlockArgument>(value))
| 1308 | // A tensor block argument has the same bufferized type as the |
| 1309 | // corresponding output operand. |
      return bufferization::getBufferType(
          forallOp.getTiedOpOperand(bbArg)->get(), options, state,
          invocationStack);
| 1313 | |
| 1314 | // The bufferized result type is the same as the bufferized type of the |
| 1315 | // corresponding output operand. |
    return bufferization::getBufferType(
        forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()],
        options, state, invocationStack);
| 1319 | } |
| 1320 | |
| 1321 | bool isRepetitiveRegion(Operation *op, unsigned index) const { |
| 1322 | auto forallOp = cast<ForallOp>(op); |
| 1323 | |
    // The op is repetitive if any dimension executes its body more than once,
    // i.e., if lb + step < ub. Dimensions with non-constant bounds or step
    // are conservatively considered repetitive.
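    // E.g., lb = 0, ub = 10, step = 4: 0 + 4 < 10, so the body executes more
    // than once and the region is repetitive.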
| 1326 | for (auto [lb, ub, step] : |
| 1327 | llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(), |
| 1328 | forallOp.getMixedStep())) { |
| 1329 | std::optional<int64_t> lbConstant = getConstantIntValue(lb); |
| 1330 | if (!lbConstant) |
| 1331 | return true; |
| 1332 | |
| 1333 | std::optional<int64_t> ubConstant = getConstantIntValue(ub); |
| 1334 | if (!ubConstant) |
| 1335 | return true; |
| 1336 | |
| 1337 | std::optional<int64_t> stepConstant = getConstantIntValue(step); |
| 1338 | if (!stepConstant) |
| 1339 | return true; |
| 1340 | |
| 1341 | if (*lbConstant + *stepConstant < *ubConstant) |
| 1342 | return true; |
| 1343 | } |
| 1344 | return false; |
| 1345 | } |
| 1346 | |
| 1347 | bool isParallelRegion(Operation *op, unsigned index) const { |
| 1348 | return isRepetitiveRegion(op, index); |
| 1349 | } |
| 1350 | }; |
| 1351 | |
| 1352 | /// Nothing to do for InParallelOp. |
| 1353 | struct InParallelOpInterface |
| 1354 | : public BufferizableOpInterface::ExternalModel<InParallelOpInterface, |
| 1355 | InParallelOp> { |
| 1356 | LogicalResult bufferize(Operation *op, RewriterBase &b, |
| 1357 | const BufferizationOptions &options, |
| 1358 | BufferizationState &state) const { |
| 1359 | llvm_unreachable("op does not have any tensor OpOperands / OpResults" ); |
| 1360 | return failure(); |
| 1361 | } |
| 1362 | }; |
| 1363 | |
| 1364 | } // namespace |
| 1365 | } // namespace scf |
| 1366 | } // namespace mlir |
| 1367 | |
| 1368 | void mlir::scf::registerBufferizableOpInterfaceExternalModels( |
| 1369 | DialectRegistry ®istry) { |
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
| 1371 | ConditionOp::attachInterface<ConditionOpInterface>(*ctx); |
| 1372 | ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx); |
| 1373 | ForOp::attachInterface<ForOpInterface>(*ctx); |
| 1374 | IfOp::attachInterface<IfOpInterface>(*ctx); |
| 1375 | IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx); |
| 1376 | ForallOp::attachInterface<ForallOpInterface>(*ctx); |
| 1377 | InParallelOp::attachInterface<InParallelOpInterface>(*ctx); |
| 1378 | WhileOp::attachInterface<WhileOpInterface>(*ctx); |
| 1379 | YieldOp::attachInterface<YieldOpInterface>(*ctx); |
| 1380 | }); |
| 1381 | } |
| 1382 | |