| 1 | //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h" |
| 10 | |
| 11 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
| 12 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
| 13 | #include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h" |
| 14 | #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" |
| 15 | #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" |
| 16 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 17 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 18 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
| 19 | #include "mlir/IR/Dialect.h" |
| 20 | #include "mlir/IR/Operation.h" |
| 21 | #include "mlir/IR/PatternMatch.h" |
| 22 | |
| 23 | using namespace mlir; |
| 24 | using namespace mlir::bufferization; |
| 25 | using namespace mlir::scf; |
| 26 | |
| 27 | namespace mlir { |
| 28 | namespace scf { |
| 29 | namespace { |
| 30 | |
| 31 | /// Helper function for loop bufferization. Cast the given buffer to the given |
| 32 | /// memref type. |
| 33 | static Value castBuffer(OpBuilder &b, Value buffer, Type type) { |
| 34 | assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType" ); |
| 35 | assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType" ); |
| 36 | // If the buffer already has the correct type, no cast is needed. |
| 37 | if (buffer.getType() == type) |
| 38 | return buffer; |
| 39 | // TODO: In case `type` has a layout map that is not the fully dynamic |
| 40 | // one, we may not be able to cast the buffer. In that case, the loop |
| 41 | // iter_arg's layout map must be changed (see uses of `castBuffer`). |
| 42 | assert(memref::CastOp::areCastCompatible(buffer.getType(), type) && |
| 43 | "scf.while op bufferization: cast incompatible" ); |
| 44 | return b.create<memref::CastOp>(location: buffer.getLoc(), args&: type, args&: buffer).getResult(); |
| 45 | } |
| 46 | |
| 47 | /// Helper function for loop bufferization. Return "true" if the given value |
| 48 | /// is guaranteed to not alias with an external tensor apart from values in |
| 49 | /// `exceptions`. A value is external if it is defined outside of the given |
| 50 | /// region or if it is an entry block argument of the region. |
| 51 | static bool doesNotAliasExternalValue(Value value, Region *region, |
| 52 | ValueRange exceptions, |
| 53 | const OneShotAnalysisState &state) { |
| 54 | assert(llvm::hasSingleElement(region->getBlocks()) && |
| 55 | "expected region with single block" ); |
| 56 | bool result = true; |
| 57 | state.applyOnAliases(v: value, fun: [&](Value alias) { |
| 58 | if (llvm::is_contained(Range&: exceptions, Element: alias)) |
| 59 | return; |
| 60 | Region *aliasRegion = alias.getParentRegion(); |
| 61 | if (isa<BlockArgument>(Val: alias) && !region->isProperAncestor(other: aliasRegion)) |
| 62 | result = false; |
| 63 | if (isa<OpResult>(Val: alias) && !region->isAncestor(other: aliasRegion)) |
| 64 | result = false; |
| 65 | }); |
| 66 | return result; |
| 67 | } |
| 68 | |
/// Bufferization of scf.condition.
struct ConditionOpInterface
    : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
                                                    scf::ConditionOp> {
  // All forwarded operands are treated as reads: they are passed on to the
  // "after" region / results of the enclosing scf.while.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  // scf.condition does not write to memory; it only forwards values.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  // scf.condition has no results; result aliasing is modeled on scf.while.
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Condition operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield allocations
    // when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto conditionOp = cast<scf::ConditionOp>(Val: op);
    auto whileOp = cast<scf::WhileOp>(Val: conditionOp->getParentOp());

    // Replace every tensor operand with its buffer. The buffer must match the
    // bufferized type of the corresponding "after" block argument of the
    // parent scf.while, so a memref.cast may be inserted (see `castBuffer`).
    SmallVector<Value> newArgs;
    for (const auto &it : llvm::enumerate(First: conditionOp.getArgs())) {
      Value value = it.value();
      if (isa<TensorType>(Val: value.getType())) {
        FailureOr<Value> maybeBuffer =
            getBuffer(rewriter, value, options, state);
        if (failed(Result: maybeBuffer))
          return failure();
        FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
            value: whileOp.getAfterArguments()[it.index()], options, state);
        if (failed(Result: resultType))
          return failure();
        Value buffer = castBuffer(b&: rewriter, buffer: *maybeBuffer, type: *resultType);
        newArgs.push_back(Elt: buffer);
      } else {
        // Non-tensor operands are forwarded unchanged.
        newArgs.push_back(Elt: value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::ConditionOp>(
        rewriter, op, args: conditionOp.getCondition(), args&: newArgs);
    return success();
  }
};
| 126 | |
| 127 | /// Return the unique scf.yield op. If there are multiple or no scf.yield ops, |
| 128 | /// return an empty op. |
| 129 | static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) { |
| 130 | scf::YieldOp result; |
| 131 | for (Block &block : executeRegionOp.getRegion()) { |
| 132 | if (auto yieldOp = dyn_cast<scf::YieldOp>(Val: block.getTerminator())) { |
| 133 | if (result) |
| 134 | return {}; |
| 135 | result = yieldOp; |
| 136 | } |
| 137 | } |
| 138 | return result; |
| 139 | } |
| 140 | |
/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
          ExecuteRegionOpInterface, scf::ExecuteRegionOp> {

  // The region may contain multiple blocks with arbitrary branching.
  static bool supportsUnstructuredControlFlow() { return true; }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(Val: op);
    // TODO: scf.execute_region with multiple yields are not supported.
    if (!getUniqueYieldOp(executeRegionOp))
      return op->emitOpError(message: "op without unique scf.yield is not supported" );
    return success();
  }

  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // Block arguments are handled by the unstructured-control-flow base model.
    if (auto bbArg = dyn_cast<BlockArgument>(Val&: value))
      return getAliasingBranchOpOperands(op, bbArg, state);

    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(Val: op);
    auto it = llvm::find(Range: op->getOpResults(), Val: value);
    assert(it != op->getOpResults().end() && "invalid value" );
    size_t resultNum = std::distance(first: op->getOpResults().begin(), last: it);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    // Note: If there is no unique scf.yield op, `verifyAnalysis` will fail.
    if (!yieldOp)
      return {};
    return {{&yieldOp->getOpOperand(idx: resultNum), BufferRelation::Equivalent}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(Val: op);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    // The yield operands are already bufferized at this point, so their types
    // are the result types of the new op.
    TypeRange newResultTypes(yieldOp.getResults());

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(location: op->getLoc(), args&: newResultTypes);
    newOp.getRegion().takeBody(other&: executeRegionOp.getRegion());

    // Bufferize every block.
    for (Block &block : newOp.getRegion())
      if (failed(Result: bufferization::bufferizeBlockSignature(block: &block, rewriter,
                                                       options, state)))
        return failure();

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(First: executeRegionOp->getResultTypes())) {
      if (isa<TensorType>(Val: it.value())) {
        // Wrap bufferized results in to_tensor so remaining tensor users of
        // the old op can be replaced with a value of the original type.
        newResults.push_back(Elt: rewriter.create<bufferization::ToTensorOp>(
            location: executeRegionOp.getLoc(), args&: it.value(),
            args: newOp->getResult(idx: it.index())));
      } else {
        newResults.push_back(Elt: newOp->getResult(idx: it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(op: executeRegionOp, newValues: newResults);

    return success();
  }
};
| 221 | |
/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the then/else
    // branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(Val: op);
    size_t resultNum = std::distance(first: op->getOpResults().begin(),
                                     last: llvm::find(Range: op->getOpResults(), Val: value));
    // Only one branch executes at runtime, so neither alias is definite.
    OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(idx: resultNum);
    OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(idx: resultNum);
    return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false},
            {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(Val: op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : ifOp.getResults()) {
      if (!isa<TensorType>(Val: result.getType())) {
        // Non-tensor results keep their original type.
        newTypes.push_back(Elt: result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(value: result, options, state);
      if (failed(Result: bufferType))
        return failure();
      newTypes.push_back(Elt: *bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(ifOp);
    auto newIfOp =
        rewriter.create<scf::IfOp>(location: ifOp.getLoc(), args&: newTypes, args: ifOp.getCondition(),
                                   /*withElseRegion=*/args: true);

    // Move over then/else blocks.
    rewriter.mergeBlocks(source: ifOp.thenBlock(), dest: newIfOp.thenBlock());
    rewriter.mergeBlocks(source: ifOp.elseBlock(), dest: newIfOp.elseBlock());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, values: newIfOp->getResults());

    return success();
  }

  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto ifOp = cast<scf::IfOp>(Val: op);
    auto thenYieldOp = cast<scf::YieldOp>(Val: ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(Val: ifOp.elseBlock()->getTerminator());
    assert(value.getDefiningOp() == op && "invalid valid" );

    // Determine buffer types of the true/false branches.
    auto opResult = cast<OpResult>(Val&: value);
    auto thenValue = thenYieldOp.getOperand(i: opResult.getResultNumber());
    auto elseValue = elseYieldOp.getOperand(i: opResult.getResultNumber());
    BaseMemRefType thenBufferType, elseBufferType;
    if (isa<BaseMemRefType>(Val: thenValue.getType())) {
      // True branch was already bufferized.
      thenBufferType = cast<BaseMemRefType>(Val: thenValue.getType());
    } else {
      // Recursively compute the buffer type of the yielded tensor.
      auto maybeBufferType =
          bufferization::detail::asMemRefType(bufferType: bufferization::getBufferType(
              value: thenValue, options, state, invocationStack));
      if (failed(Result: maybeBufferType))
        return failure();
      thenBufferType = *maybeBufferType;
    }
    if (isa<BaseMemRefType>(Val: elseValue.getType())) {
      // False branch was already bufferized.
      elseBufferType = cast<BaseMemRefType>(Val: elseValue.getType());
    } else {
      auto maybeBufferType =
          bufferization::detail::asMemRefType(bufferType: bufferization::getBufferType(
              value: elseValue, options, state, invocationStack));
      if (failed(Result: maybeBufferType))
        return failure();
      elseBufferType = *maybeBufferType;
    }

    // Best case: Both branches have the exact same buffer type.
    if (thenBufferType == elseBufferType)
      return cast<BufferLikeType>(Val&: thenBufferType);

    // Memory space mismatch.
    if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
      return op->emitError(message: "inconsistent memory space on then/else branches" );

    // Layout maps are different: Promote to fully dynamic layout map.
    return cast<BufferLikeType>(Val: getMemRefTypeWithFullyDynamicLayout(
        tensorType: cast<TensorType>(Val: opResult.getType()), memorySpace: thenBufferType.getMemorySpace()));
  }
};
| 326 | |
/// Bufferization of scf.index_switch. Replace with a new scf.index_switch that
/// yields memrefs.
struct IndexSwitchOpInterface
    : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
                                                    scf::IndexSwitchOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IndexSwitchOps do not have tensor OpOperands. The yielded value can be
    // any SSA. This is similar to IfOps.
    auto switchOp = cast<scf::IndexSwitchOp>(Val: op);
    int64_t resultNum = cast<OpResult>(Val&: value).getResultNumber();
    AliasingOpOperandList result;
    // Each case region yields a potential alias of the result. Only one case
    // executes at runtime, so none of the aliases is definite.
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldOp =
          cast<scf::YieldOp>(Val: switchOp.getCaseBlock(idx: i).getTerminator());
      result.addAlias(alias: AliasingOpOperand(&yieldOp->getOpOperand(idx: resultNum),
                                        BufferRelation::Equivalent,
                                        /*isDefinite=*/false));
    }
    // The default region also yields a potential alias.
    auto defaultYieldOp =
        cast<scf::YieldOp>(Val: switchOp.getDefaultBlock().getTerminator());
    result.addAlias(alias: AliasingOpOperand(&defaultYieldOp->getOpOperand(idx: resultNum),
                                      BufferRelation::Equivalent,
                                      /*isDefinite=*/false));
    return result;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto switchOp = cast<scf::IndexSwitchOp>(Val: op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : switchOp.getResults()) {
      if (!isa<TensorType>(Val: result.getType())) {
        // Non-tensor results keep their original type.
        newTypes.push_back(Elt: result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(value: result, options, state);
      if (failed(Result: bufferType))
        return failure();
      newTypes.push_back(Elt: *bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(switchOp);
    auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>(
        location: switchOp.getLoc(), args&: newTypes, args: switchOp.getArg(), args: switchOp.getCases(),
        args: switchOp.getCases().size());

    // Move over blocks.
    for (auto [src, dest] :
         llvm::zip(t: switchOp.getCaseRegions(), u: newSwitchOp.getCaseRegions()))
      rewriter.inlineRegionBefore(region&: src, parent&: dest, before: dest.begin());
    rewriter.inlineRegionBefore(region&: switchOp.getDefaultRegion(),
                                parent&: newSwitchOp.getDefaultRegion(),
                                before: newSwitchOp.getDefaultRegion().begin());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, values: newSwitchOp->getResults());

    return success();
  }

  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto switchOp = cast<scf::IndexSwitchOp>(Val: op);
    assert(value.getDefiningOp() == op && "invalid value" );
    int64_t resultNum = cast<OpResult>(Val&: value).getResultNumber();

    // Helper function to get buffer type of a case.
    auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
      auto yieldOp = cast<scf::YieldOp>(Val: b.getTerminator());
      Value yieldedValue = yieldOp->getOperand(idx: resultNum);
      // If the case was already bufferized, take the type directly.
      if (auto bufferType = dyn_cast<BaseMemRefType>(Val: yieldedValue.getType()))
        return bufferType;
      auto maybeBufferType = bufferization::getBufferType(
          value: yieldedValue, options, state, invocationStack);
      return bufferization::detail::asMemRefType(bufferType: maybeBufferType);
    };

    // Compute buffer type of the default case.
    auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
    if (failed(Result: maybeBufferType))
      return failure();
    BaseMemRefType bufferType = *maybeBufferType;

    // Compute buffer types of all other cases.
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(idx: i));
      if (failed(Result: yieldedBufferType))
        return failure();

      // Best case: Both branches have the exact same buffer type.
      if (bufferType == *yieldedBufferType)
        continue;

      // Memory space mismatch.
      if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
        return op->emitError(message: "inconsistent memory space on switch cases" );

      // Layout maps are different: Promote to fully dynamic layout map.
      bufferType = getMemRefTypeWithFullyDynamicLayout(
          tensorType: cast<TensorType>(Val: value.getType()), memorySpace: bufferType.getMemorySpace());
    }

    return cast<BufferLikeType>(Val&: bufferType);
  }
};
| 441 | |
| 442 | /// Helper function for loop bufferization. Return the indices of all values |
| 443 | /// that have a tensor type. |
| 444 | static DenseSet<int64_t> getTensorIndices(ValueRange values) { |
| 445 | DenseSet<int64_t> result; |
| 446 | for (const auto &it : llvm::enumerate(First&: values)) |
| 447 | if (isa<TensorType>(Val: it.value().getType())) |
| 448 | result.insert(V: it.index()); |
| 449 | return result; |
| 450 | } |
| 451 | |
| 452 | /// Helper function for loop bufferization. Return the indices of all |
| 453 | /// bbArg/yielded value pairs who's buffer relation is "Equivalent". |
| 454 | DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs, |
| 455 | ValueRange yieldedValues, |
| 456 | const AnalysisState &state) { |
| 457 | unsigned int minSize = std::min(a: bbArgs.size(), b: yieldedValues.size()); |
| 458 | DenseSet<int64_t> result; |
| 459 | for (unsigned int i = 0; i < minSize; ++i) { |
| 460 | if (!isa<TensorType>(Val: bbArgs[i].getType()) || |
| 461 | !isa<TensorType>(Val: yieldedValues[i].getType())) |
| 462 | continue; |
| 463 | if (state.areEquivalentBufferizedValues(v1: bbArgs[i], v2: yieldedValues[i])) |
| 464 | result.insert(V: i); |
| 465 | } |
| 466 | return result; |
| 467 | } |
| 468 | |
| 469 | /// Helper function for loop bufferization. Return the bufferized values of the |
| 470 | /// given OpOperands. If an operand is not a tensor, return the original value. |
| 471 | static FailureOr<SmallVector<Value>> |
| 472 | getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands, |
| 473 | const BufferizationOptions &options, BufferizationState &state) { |
| 474 | SmallVector<Value> result; |
| 475 | for (OpOperand &opOperand : operands) { |
| 476 | if (isa<TensorType>(Val: opOperand.get().getType())) { |
| 477 | FailureOr<Value> resultBuffer = |
| 478 | getBuffer(rewriter, value: opOperand.get(), options, state); |
| 479 | if (failed(Result: resultBuffer)) |
| 480 | return failure(); |
| 481 | result.push_back(Elt: *resultBuffer); |
| 482 | } else { |
| 483 | result.push_back(Elt: opOperand.get()); |
| 484 | } |
| 485 | } |
| 486 | return result; |
| 487 | } |
| 488 | |
| 489 | /// Helper function for loop bufferization. Given a list of bbArgs of the new |
| 490 | /// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into |
| 491 | /// ToTensorOps, so that the block body can be moved over to the new op. |
| 492 | static SmallVector<Value> |
| 493 | getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs, |
| 494 | Block::BlockArgListType oldBbArgs, |
| 495 | const DenseSet<int64_t> &tensorIndices) { |
| 496 | SmallVector<Value> result; |
| 497 | for (const auto &it : llvm::enumerate(First&: bbArgs)) { |
| 498 | size_t idx = it.index(); |
| 499 | Value val = it.value(); |
| 500 | if (tensorIndices.contains(V: idx)) { |
| 501 | result.push_back(Elt: rewriter |
| 502 | .create<bufferization::ToTensorOp>( |
| 503 | location: val.getLoc(), args: oldBbArgs[idx].getType(), args&: val) |
| 504 | .getResult()); |
| 505 | } else { |
| 506 | result.push_back(Elt: val); |
| 507 | } |
| 508 | } |
| 509 | return result; |
| 510 | } |
| 511 | |
/// Compute the bufferized type of a loop iter_arg. This type must be equal to
/// the bufferized type of the corresponding init_arg and the bufferized type
/// of the corresponding yielded value.
///
/// This function uses bufferization::getBufferType to compute the bufferized
/// type of the init_arg and of the yielded value. (The computation of the
/// bufferized yielded value type usually requires computing the bufferized type
/// of the iter_arg again; the implementation of getBufferType traces back the
/// use-def chain of the given value and computes a buffer type along the way.)
/// If both buffer types are equal, no casts are needed the computed buffer type
/// can be used directly. Otherwise, the buffer types can only differ in their
/// layout map and a cast must be inserted.
static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
    Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
    const BufferizationOptions &options, const BufferizationState &state,
    SmallVector<Value> &invocationStack) {
  // Determine the buffer type of the init_arg.
  auto initArgBufferType =
      bufferization::getBufferType(value: initArg, options, state, invocationStack);
  if (failed(Result: initArgBufferType))
    return failure();

  // `invocationStack` tracks the values for which a buffer type is currently
  // being computed; seeing `iterArg` twice indicates a recursion cycle.
  if (llvm::count(Range&: invocationStack, Element: iterArg) >= 2) {
    // If the iter_arg is already twice on the invocation stack, just take the
    // type of the init_arg. This is to avoid infinite loops when calculating
    // the buffer type. This will most likely result in computing a memref type
    // with a fully dynamic layout map.

    // Note: For more precise layout map computation, a fixpoint iteration could
    // be done (i.e., re-computing the yielded buffer type until the bufferized
    // iter_arg type no longer changes). This current implementation immediately
    // switches to a fully dynamic layout map when a mismatch between bufferized
    // init_arg type and bufferized yield value type is detected.
    return *initArgBufferType;
  }

  // Compute the buffer type of the yielded value.
  BufferLikeType yieldedValueBufferType;
  if (isa<BaseMemRefType>(Val: yieldedValue.getType())) {
    // scf.yield was already bufferized.
    yieldedValueBufferType = cast<BufferLikeType>(Val: yieldedValue.getType());
  } else {
    // Note: This typically triggers a recursive call for the buffer type of
    // the iter_arg.
    auto maybeBufferType = bufferization::getBufferType(value: yieldedValue, options,
                                                        state, invocationStack);
    if (failed(Result: maybeBufferType))
      return failure();
    yieldedValueBufferType = *maybeBufferType;
  }

  // If yielded type and init_arg type are the same, use that type directly.
  if (*initArgBufferType == yieldedValueBufferType)
    return yieldedValueBufferType;

  // If there is a mismatch between the yielded buffer type and the init_arg
  // buffer type, the buffer type must be promoted to a fully dynamic layout
  // map.
  auto yieldedBufferType = cast<BaseMemRefType>(Val&: yieldedValueBufferType);
  auto iterTensorType = cast<TensorType>(Val: iterArg.getType());
  auto initBufferType = llvm::cast<BaseMemRefType>(Val&: *initArgBufferType);
  if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
    return loopOp->emitOpError(
        message: "init_arg and yielded value bufferize to inconsistent memory spaces" );
#ifndef NDEBUG
  // Debug-only sanity check: a layout-map mismatch must not come with a shape
  // mismatch; init buffer, yielded buffer, and iter tensor agree on shape.
  if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
    assert(
        llvm::all_equal({yieldedRankedBufferType.getShape(),
                         cast<MemRefType>(initBufferType).getShape(),
                         cast<RankedTensorType>(iterTensorType).getShape()}) &&
        "expected same shape" );
  }
#endif // NDEBUG
  return cast<BufferLikeType>(Val: getMemRefTypeWithFullyDynamicLayout(
      tensorType: iterTensorType, memorySpace: yieldedBufferType.getMemorySpace()));
}
| 588 | |
| 589 | /// Return `true` if the given loop may have 0 iterations. |
| 590 | bool mayHaveZeroIterations(scf::ForOp forOp) { |
| 591 | std::optional<int64_t> lb = getConstantIntValue(ofr: forOp.getLowerBound()); |
| 592 | std::optional<int64_t> ub = getConstantIntValue(ofr: forOp.getUpperBound()); |
| 593 | if (!lb.has_value() || !ub.has_value()) |
| 594 | return true; |
| 595 | return *ub <= *lb; |
| 596 | } |
| 597 | |
| 598 | /// Bufferization of scf.for. Replace with a new scf.for that operates on |
| 599 | /// memrefs. |
| 600 | struct ForOpInterface |
| 601 | : public BufferizableOpInterface::ExternalModel<ForOpInterface, |
| 602 | scf::ForOp> { |
| 603 | bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, |
| 604 | const AnalysisState &state) const { |
| 605 | auto forOp = cast<scf::ForOp>(Val: op); |
| 606 | |
| 607 | // If the loop has zero iterations, the results of the op are their |
| 608 | // corresponding init_args, meaning that the init_args bufferize to a read. |
| 609 | if (mayHaveZeroIterations(forOp)) |
| 610 | return true; |
| 611 | |
| 612 | // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of |
| 613 | // its matching bbArg may. |
| 614 | return state.isValueRead(value: forOp.getTiedLoopRegionIterArg(opOperand: &opOperand)); |
| 615 | } |
| 616 | |
| 617 | bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, |
| 618 | const AnalysisState &state) const { |
| 619 | // Tensor iter_args of scf::ForOps are always considered as a write. |
| 620 | return true; |
| 621 | } |
| 622 | |
| 623 | AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, |
| 624 | const AnalysisState &state) const { |
| 625 | auto forOp = cast<scf::ForOp>(Val: op); |
| 626 | OpResult opResult = forOp.getTiedLoopResult(opOperand: &opOperand); |
| 627 | BufferRelation relation = bufferRelation(op, opResult, state); |
| 628 | return {{opResult, relation, |
| 629 | /*isDefinite=*/relation == BufferRelation::Equivalent}}; |
| 630 | } |
| 631 | |
| 632 | BufferRelation bufferRelation(Operation *op, OpResult opResult, |
| 633 | const AnalysisState &state) const { |
| 634 | // ForOp results are equivalent to their corresponding init_args if the |
| 635 | // corresponding iter_args and yield values are equivalent. |
| 636 | auto forOp = cast<scf::ForOp>(Val: op); |
| 637 | BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult); |
| 638 | bool equivalentYield = state.areEquivalentBufferizedValues( |
| 639 | v1: bbArg, v2: forOp.getTiedLoopYieldedValue(bbArg)->get()); |
| 640 | return equivalentYield ? BufferRelation::Equivalent |
| 641 | : BufferRelation::Unknown; |
| 642 | } |
| 643 | |
| 644 | bool isWritable(Operation *op, Value value, |
| 645 | const AnalysisState &state) const { |
| 646 | // Interestingly, scf::ForOp's bbArg can **always** be viewed |
| 647 | // inplace from the perspective of ops nested under: |
| 648 | // 1. Either the matching iter operand is not bufferized inplace and an |
| 649 | // alloc + optional copy makes the bbArg itself inplaceable. |
| 650 | // 2. Or the matching iter operand is bufferized inplace and bbArg just |
| 651 | // bufferizes to that too. |
| 652 | return true; |
| 653 | } |
| 654 | |
  /// Resolve tensor conflicts for scf.for. In addition to the generic operand
  /// conflict resolution, replace yielded values that may alias a buffer
  /// defined outside of the loop with a fresh allocation (+ copy).
  LogicalResult
  resolveConflicts(Operation *op, RewriterBase &rewriter,
                   const AnalysisState &analysisState,
                   const BufferizationState &bufferizationState) const {
    // First resolve ordinary operand conflicts.
    auto bufferizableOp = cast<BufferizableOpInterface>(Val: op);
    if (failed(Result: bufferizableOp.resolveTensorOpOperandConflicts(
            rewriter, analysisState, bufferizationState)))
      return failure();

    // In copy-before-write mode, no additional aliasing-based copies are
    // needed.
    if (analysisState.getOptions().copyBeforeWrite)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg (or with a
    // newly allocated buffer) and not with other buffers defined outside of the
    // loop. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand.
    auto forOp = cast<scf::ForOp>(Val: op);
    auto yieldOp = cast<scf::YieldOp>(Val: forOp.getBody()->getTerminator());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(values: forOp.getInitArgs());
    // For every yielded value, does it alias with something defined outside of
    // the loop?
    SmallVector<Value> yieldValues;
    for (const auto it : llvm::enumerate(First: yieldOp.getResults())) {
      // Note: `state` is guaranteed to be a `OneShotAnalysisState`, but this
      // type cannot be used in the signature of `resolveConflicts` because the
      // op interface is in the "IR" build unit and the `OneShotAnalysisState`
      // is defined in the "Transforms" build unit.
      if (!indices.contains(V: it.index()) ||
          doesNotAliasExternalValue(
              value: it.value(), region: &forOp.getRegion(),
              /*exceptions=*/forOp.getRegionIterArg(index: it.index()),
              state: static_cast<const OneShotAnalysisState &>(analysisState))) {
        // Non-tensor value or value that provably does not alias anything
        // defined outside the loop: keep yielding it unchanged.
        yieldValues.push_back(Elt: it.value());
        continue;
      }
      // Potential external alias: yield a newly allocated (copied) tensor
      // instead, which by construction aliases nothing.
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          b&: rewriter, loc: yieldOp.getLoc(), shapedValue: it.value(), options: analysisState.getOptions(),
          state: bufferizationState);
      if (failed(Result: alloc))
        return failure();
      yieldValues.push_back(Elt: *alloc);
    }

    rewriter.modifyOpInPlace(
        root: yieldOp, callable: [&]() { yieldOp.getResultsMutable().assign(values: yieldValues); });
    return success();
  }
| 708 | |
| 709 | FailureOr<BufferLikeType> |
| 710 | getBufferType(Operation *op, Value value, const BufferizationOptions &options, |
| 711 | const BufferizationState &state, |
| 712 | SmallVector<Value> &invocationStack) const { |
| 713 | auto forOp = cast<scf::ForOp>(Val: op); |
| 714 | assert(getOwnerOfValue(value) == op && "invalid value" ); |
| 715 | assert(isa<TensorType>(value.getType()) && "expected tensor type" ); |
| 716 | |
| 717 | if (auto opResult = dyn_cast<OpResult>(Val&: value)) { |
| 718 | // The type of an OpResult must match the corresponding iter_arg type. |
| 719 | BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult); |
| 720 | return bufferization::getBufferType(value: bbArg, options, state, |
| 721 | invocationStack); |
| 722 | } |
| 723 | |
| 724 | // Compute result/argument number. |
| 725 | BlockArgument bbArg = cast<BlockArgument>(Val&: value); |
| 726 | unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber(); |
| 727 | |
| 728 | // Compute the bufferized type. |
| 729 | auto yieldOp = cast<scf::YieldOp>(Val: forOp.getBody()->getTerminator()); |
| 730 | Value yieldedValue = yieldOp.getOperand(i: resultNum); |
| 731 | BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum]; |
| 732 | Value initArg = forOp.getInitArgs()[resultNum]; |
| 733 | return computeLoopRegionIterArgBufferType( |
| 734 | loopOp: op, iterArg, initArg, yieldedValue, options, state, invocationStack); |
| 735 | } |
| 736 | |
  /// Rewrite the scf.for op to operate on memrefs: buffers replace the tensor
  /// init_args, the old body is moved into a new loop, and tensor bbArgs are
  /// reconstructed inside the body via to_tensor ops.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(Val: op);
    Block *oldLoopBody = forOp.getBody();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(values: forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, operands: forOp.getInitArgsMutable(), options, state);
    if (failed(Result: maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary. Each buffer must match the bufferized type
    // of the corresponding loop result.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(First&: initArgs)) {
      Value initArg = it.value();
      Value result = forOp->getResult(idx: it.index());
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(Val: result.getType())) {
        castedInitArgs.push_back(Elt: initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(value: result, options, state);
      if (failed(Result: targetType))
        return failure();
      castedInitArgs.push_back(Elt: castBuffer(b&: rewriter, buffer: initArg, type: *targetType));
    }

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        location: forOp.getLoc(), args: forOp.getLowerBound(), args: forOp.getUpperBound(),
        args: forOp.getStep(), args&: castedInitArgs);
    // Carry over all attributes of the original loop.
    newForOp->setAttrs(forOp->getAttrs());
    Block *loopBody = newForOp.getBody();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, bbArgs: newForOp.getRegionIterArgs(),
                             oldBbArgs: forOp.getRegionIterArgs(), tensorIndices: indices);
    // Prepend the induction variable so the replacement list matches the old
    // block's argument list.
    iterArgs.insert(I: iterArgs.begin(), Elt: newForOp.getInductionVar());

    // Move loop body to new loop.
    rewriter.mergeBlocks(source: oldLoopBody, dest: loopBody, argValues: iterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, values: newForOp->getResults());

    return success();
  }
| 793 | |
  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs-from-loops`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    // Nothing to verify if returning allocations from loops is allowed.
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto forOp = cast<scf::ForOp>(Val: op);
    auto yieldOp = cast<scf::YieldOp>(Val: forOp.getBody()->getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!isa<TensorType>(Val: opResult.getType()))
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
| 824 | }; |
| 825 | |
/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  /// Tensor iter_args of an scf.while are conservatively treated as read.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  /// Tensor iter_args of an scf.while are conservatively treated as written.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  /// The i-th init operand may alias only with the i-th loop result, and only
  /// when both exist and have matching types.
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(Val: op);
    unsigned int idx = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match. They may not even have the
    // same type. The number of OpResults and OpOperands can also differ.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    OpResult opResult = whileOp->getResult(idx);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  /// A result is equivalent to its init_arg iff the equivalence holds through
  /// both the "before" region (scf.condition) and the "after" region
  /// (scf.yield).
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(Val: op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::Unknown;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::Unknown;

    // Check equivalence through the "before" region terminator.
    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(v1: conditionBbArg, v2: conditionOperand);

    // Check equivalence through the "after" region terminator.
    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(i: resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(v1: bodyBbArg, v2: yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::Unknown;
  }

  /// scf.while bbArgs are always writable from the loop body's perspective.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  /// Resolve tensor conflicts for scf.while. Values passed to scf.condition
  /// that are not equivalent to their bbArgs (in either region) are replaced
  /// by fresh allocations to uphold the aliasing contract.
  LogicalResult
  resolveConflicts(Operation *op, RewriterBase &rewriter,
                   const AnalysisState &analysisState,
                   const BufferizationState &bufferizationState) const {
    // First resolve ordinary operand conflicts.
    auto bufferizableOp = cast<BufferizableOpInterface>(Val: op);
    if (failed(Result: bufferizableOp.resolveTensorOpOperandConflicts(
            rewriter, analysisState, bufferizationState)))
      return failure();

    // Nothing more to do in copy-before-write mode.
    if (analysisState.getOptions().copyBeforeWrite)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
    // (New buffer copies do not alias with any buffer.)
    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(Val: op);
    auto conditionOp = whileOp.getConditionOp();

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        bbArgs: whileOp.getBeforeArguments(), yieldedValues: conditionOp.getArgs(), state: analysisState);
    DenseSet<int64_t> equivalentYieldsAfter =
        getEquivalentBuffers(bbArgs: whileOp.getAfterArguments(),
                             yieldedValues: whileOp.getYieldOp().getResults(), state: analysisState);

    // Update "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      // Keep the value unchanged if it is not a tensor or if equivalence holds
      // in both regions.
      if (!isa<TensorType>(Val: value.getType()) ||
          (equivalentYieldsAfter.contains(V: idx) &&
           equivalentYieldsBefore.contains(V: idx))) {
        beforeYieldValues.push_back(Elt: value);
        continue;
      }
      // Otherwise, pass a fresh (non-aliasing) allocation to scf.condition.
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          b&: rewriter, loc: conditionOp.getLoc(), shapedValue: value, options: analysisState.getOptions(),
          state: bufferizationState);
      if (failed(Result: alloc))
        return failure();
      beforeYieldValues.push_back(Elt: *alloc);
    }
    rewriter.modifyOpInPlace(root: conditionOp, callable: [&]() {
      conditionOp.getArgsMutable().assign(values: beforeYieldValues);
    });

    return success();
  }

  /// Rewrite the scf.while op to operate on memrefs: buffers replace the
  /// tensor init operands, both region bodies are moved into a new loop, and
  /// tensor bbArgs are reconstructed via to_tensor ops.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto whileOp = cast<scf::WhileOp>(Val: op);

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(values: whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(values: whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, operands: whileOp.getInitsMutable(), options, state);
    if (failed(Result: maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary. Each buffer must match the bufferized type
    // of the corresponding "before" bbArg.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(First&: initArgs)) {
      Value initArg = it.value();
      Value beforeArg = whileOp.getBeforeArguments()[it.index()];
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(Val: beforeArg.getType())) {
        castedInitArgs.push_back(Elt: initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(value: beforeArg, options, state);
      if (failed(Result: targetType))
        return failure();
      castedInitArgs.push_back(Elt: castBuffer(b&: rewriter, buffer: initArg, type: *targetType));
    }

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        Range: llvm::map_range(C: whileOp.getAfterArguments(), F: [&](BlockArgument bbArg) {
          if (!isa<TensorType>(Val: bbArg.getType()))
            return bbArg.getType();
          // TODO: error handling
          return llvm::cast<Type>(
              Val: *bufferization::getBufferType(value: bbArg, options, state));
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(castedInitArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(
        location: whileOp.getLoc(), args&: argsTypesAfter, args&: castedInitArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
                                          whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(types: argsTypesBefore, locs: bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(types: argsTypesAfter, locs: bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs =
        getBbArgReplacements(rewriter, bbArgs: newWhileOp.getBeforeArguments(),
                             oldBbArgs: whileOp.getBeforeArguments(), tensorIndices: indicesBefore);
    rewriter.mergeBlocks(source: whileOp.getBeforeBody(), dest: newBeforeBody, argValues: newBeforeArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs =
        getBbArgReplacements(rewriter, bbArgs: newWhileOp.getAfterArguments(),
                             oldBbArgs: whileOp.getAfterArguments(), tensorIndices: indicesAfter);
    rewriter.mergeBlocks(source: whileOp.getAfterBody(), dest: newAfterBody, argValues: newAfterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, values: newWhileOp->getResults());

    return success();
  }

  /// Compute the bufferized type of an scf.while value. "Before" bbArgs are
  /// derived from the init/yielded values; results and "after" bbArgs take
  /// the type of the value passed to scf.condition at the same position.
  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto whileOp = cast<scf::WhileOp>(Val: op);
    assert(getOwnerOfValue(value) == op && "invalid value" );
    assert(isa<TensorType>(value.getType()) && "expected tensor type" );

    // Case 1: Block argument of the "before" region.
    if (auto bbArg = dyn_cast<BlockArgument>(Val&: value)) {
      if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
        Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
        auto yieldOp = whileOp.getYieldOp();
        Value yieldedValue = yieldOp.getOperand(i: bbArg.getArgNumber());
        return computeLoopRegionIterArgBufferType(
            loopOp: op, iterArg: bbArg, initArg, yieldedValue, options, state, invocationStack);
      }
    }

    // Case 2: OpResult of the loop or block argument of the "after" region.
    // The bufferized "after" bbArg type can be directly computed from the
    // bufferized "before" bbArg type.
    unsigned resultNum;
    if (auto opResult = dyn_cast<OpResult>(Val&: value)) {
      resultNum = opResult.getResultNumber();
    } else if (cast<BlockArgument>(Val&: value).getOwner()->getParent() ==
               &whileOp.getAfter()) {
      resultNum = cast<BlockArgument>(Val&: value).getArgNumber();
    } else {
      llvm_unreachable("invalid value" );
    }
    Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
    if (!isa<TensorType>(Val: conditionYieldedVal.getType())) {
      // scf.condition was already bufferized.
      return cast<BufferLikeType>(Val: conditionYieldedVal.getType());
    }
    return bufferization::getBufferType(value: conditionYieldedVal, options, state,
                                        invocationStack);
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs-from-loops`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(Val: op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    // Nothing to verify if returning allocations from loops is allowed.
    if (options.allowReturnAllocsFromLoops)
      return success();

    // Check the "before" region terminator (scf.condition).
    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(First: conditionOp.getArgs())) {
      Block *block = conditionOp->getBlock();
      if (!isa<TensorType>(Val: it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(v1: it.value(),
                                               v2: block->getArgument(i: it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    // Check the "after" region terminator (scf.yield).
    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(First: yieldOp.getResults())) {
      Block *block = yieldOp->getBlock();
      if (!isa<TensorType>(Val: it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(v1: it.value(),
                                               v2: block->getArgument(i: it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};
| 1131 | |
/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  /// Yielded tensors are treated as read (their contents are forwarded).
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  /// scf.yield itself never writes to a buffer.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  /// For scf.if / scf.execute_region parents, the i-th yield operand aliases
  /// the parent's i-th result. (Loop parents compute aliasing themselves.)
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    if (auto ifOp = dyn_cast<scf::IfOp>(Val: op->getParentOp())) {
      // Not definite: the other branch may yield a different buffer.
      return {{op->getParentOp()->getResult(idx: opOperand.getOperandNumber()),
               BufferRelation::Equivalent, /*isDefinite=*/false}};
    }
    if (isa<scf::ExecuteRegionOp>(Val: op->getParentOp()))
      return {{op->getParentOp()->getResult(idx: opOperand.getOperandNumber()),
               BufferRelation::Equivalent}};
    return {};
  }

  /// Yield operands must bufferize in place to avoid yielding allocations.
  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield allocations
    // when possible.
    return true;
  }

  /// Replace the scf.yield with one that forwards buffers, casting each buffer
  /// to the bufferized type expected by the parent op when necessary.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto yieldOp = cast<scf::YieldOp>(Val: op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
             scf::WhileOp>(Val: yieldOp->getParentOp()))
      return yieldOp->emitError(message: "unsupported scf::YieldOp parent" );

    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(First: yieldOp.getResults())) {
      Value value = it.value();
      if (isa<TensorType>(Val: value.getType())) {
        FailureOr<Value> maybeBuffer =
            getBuffer(rewriter, value, options, state);
        if (failed(Result: maybeBuffer))
          return failure();
        Value buffer = *maybeBuffer;
        // We may have to cast the value before yielding it.
        if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
                Val: yieldOp->getParentOp())) {
          // These parents expect the bufferized type of the tied result.
          FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
              value: yieldOp->getParentOp()->getResult(idx: it.index()), options, state);
          if (failed(Result: resultType))
            return failure();
          buffer = castBuffer(b&: rewriter, buffer, type: *resultType);
        } else if (auto whileOp =
                       dyn_cast<scf::WhileOp>(Val: yieldOp->getParentOp())) {
          // scf.while: the yield feeds the "before" region bbArgs.
          FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
              value: whileOp.getBeforeArguments()[it.index()], options, state);
          if (failed(Result: resultType))
            return failure();
          buffer = castBuffer(b&: rewriter, buffer, type: *resultType);
        }
        newResults.push_back(Elt: buffer);
      } else {
        // Non-tensor operands are forwarded unchanged.
        newResults.push_back(Elt: value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, args&: newResults);
    return success();
  }
};
| 1210 | |
| 1211 | /// Bufferization of ForallOp. This also bufferizes the terminator of the |
| 1212 | /// region. There are op interfaces for the terminators (InParallelOp |
| 1213 | /// and ParallelInsertSliceOp), but these are only used during analysis. Not |
| 1214 | /// for bufferization. |
| 1215 | struct ForallOpInterface |
| 1216 | : public BufferizableOpInterface::ExternalModel<ForallOpInterface, |
| 1217 | ForallOp> { |
| 1218 | bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, |
| 1219 | const AnalysisState &state) const { |
| 1220 | // All tensor operands to `scf.forall` are `shared_outs` and all |
| 1221 | // shared outs are assumed to be read by the loop. This does not |
| 1222 | // account for the case where the entire value is over-written, |
| 1223 | // but being conservative here. |
| 1224 | return true; |
| 1225 | } |
| 1226 | |
| 1227 | bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, |
| 1228 | const AnalysisState &state) const { |
| 1229 | // Outputs of scf::ForallOps are always considered as a write. |
| 1230 | return true; |
| 1231 | } |
| 1232 | |
| 1233 | AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, |
| 1234 | const AnalysisState &state) const { |
| 1235 | auto forallOp = cast<ForallOp>(Val: op); |
| 1236 | return { |
| 1237 | {{forallOp.getTiedOpResult(opOperand: &opOperand), BufferRelation::Equivalent}}}; |
| 1238 | } |
| 1239 | |
| 1240 | bool isWritable(Operation *op, Value value, |
| 1241 | const AnalysisState &state) const { |
| 1242 | return true; |
| 1243 | } |
| 1244 | |
| 1245 | LogicalResult bufferize(Operation *op, RewriterBase &rewriter, |
| 1246 | const BufferizationOptions &options, |
| 1247 | BufferizationState &state) const { |
| 1248 | OpBuilder::InsertionGuard guard(rewriter); |
| 1249 | auto forallOp = cast<ForallOp>(Val: op); |
| 1250 | int64_t rank = forallOp.getRank(); |
| 1251 | |
| 1252 | // Get buffers for all output operands. |
| 1253 | SmallVector<Value> buffers; |
| 1254 | for (Value out : forallOp.getOutputs()) { |
| 1255 | FailureOr<Value> buffer = getBuffer(rewriter, value: out, options, state); |
| 1256 | if (failed(Result: buffer)) |
| 1257 | return failure(); |
| 1258 | buffers.push_back(Elt: *buffer); |
| 1259 | } |
| 1260 | |
| 1261 | // Use buffers instead of block arguments. |
| 1262 | rewriter.setInsertionPointToStart(forallOp.getBody()); |
| 1263 | for (const auto &it : llvm::zip( |
| 1264 | t: forallOp.getBody()->getArguments().drop_front(N: rank), u&: buffers)) { |
| 1265 | BlockArgument bbArg = std::get<0>(t: it); |
| 1266 | Value buffer = std::get<1>(t: it); |
| 1267 | Value bufferAsTensor = rewriter.create<ToTensorOp>( |
| 1268 | location: forallOp.getLoc(), args: bbArg.getType(), args&: buffer); |
| 1269 | bbArg.replaceAllUsesWith(newValue: bufferAsTensor); |
| 1270 | } |
| 1271 | |
| 1272 | // Create new ForallOp without any results and drop the automatically |
| 1273 | // introduced terminator. |
| 1274 | rewriter.setInsertionPoint(forallOp); |
| 1275 | ForallOp newForallOp; |
| 1276 | newForallOp = rewriter.create<ForallOp>( |
| 1277 | location: forallOp.getLoc(), args: forallOp.getMixedLowerBound(), |
| 1278 | args: forallOp.getMixedUpperBound(), args: forallOp.getMixedStep(), |
| 1279 | /*outputs=*/args: ValueRange(), args: forallOp.getMapping()); |
| 1280 | |
| 1281 | // Keep discardable attributes from the original op. |
| 1282 | newForallOp->setDiscardableAttrs(op->getDiscardableAttrDictionary()); |
| 1283 | |
| 1284 | rewriter.eraseOp(op: newForallOp.getBody()->getTerminator()); |
| 1285 | |
| 1286 | // Move over block contents of the old op. |
| 1287 | SmallVector<Value> replacementBbArgs; |
| 1288 | replacementBbArgs.append(in_start: newForallOp.getBody()->getArguments().begin(), |
| 1289 | in_end: newForallOp.getBody()->getArguments().end()); |
| 1290 | replacementBbArgs.append(NumInputs: forallOp.getOutputs().size(), Elt: Value()); |
| 1291 | rewriter.mergeBlocks(source: forallOp.getBody(), dest: newForallOp.getBody(), |
| 1292 | argValues: replacementBbArgs); |
| 1293 | |
| 1294 | // Remove the old op and replace all of its uses. |
| 1295 | replaceOpWithBufferizedValues(rewriter, op, values: buffers); |
| 1296 | |
| 1297 | return success(); |
| 1298 | } |
| 1299 | |
| 1300 | FailureOr<BufferLikeType> |
| 1301 | getBufferType(Operation *op, Value value, const BufferizationOptions &options, |
| 1302 | const BufferizationState &state, |
| 1303 | SmallVector<Value> &invocationStack) const { |
| 1304 | auto forallOp = cast<ForallOp>(Val: op); |
| 1305 | |
| 1306 | if (auto bbArg = dyn_cast<BlockArgument>(Val&: value)) |
| 1307 | // A tensor block argument has the same bufferized type as the |
| 1308 | // corresponding output operand. |
| 1309 | return bufferization::getBufferType( |
| 1310 | value: forallOp.getTiedOpOperand(bbArg)->get(), options, state, |
| 1311 | invocationStack); |
| 1312 | |
| 1313 | // The bufferized result type is the same as the bufferized type of the |
| 1314 | // corresponding output operand. |
| 1315 | return bufferization::getBufferType( |
| 1316 | value: forallOp.getOutputs()[cast<OpResult>(Val&: value).getResultNumber()], options, |
| 1317 | state, invocationStack); |
| 1318 | } |
| 1319 | |
| 1320 | bool isRepetitiveRegion(Operation *op, unsigned index) const { |
| 1321 | auto forallOp = cast<ForallOp>(Val: op); |
| 1322 | |
| 1323 | // This op is repetitive if it has 1 or more steps. |
| 1324 | // If the control variables are dynamic, it is also considered so. |
| 1325 | for (auto [lb, ub, step] : |
| 1326 | llvm::zip(t: forallOp.getMixedLowerBound(), u: forallOp.getMixedUpperBound(), |
| 1327 | args: forallOp.getMixedStep())) { |
| 1328 | std::optional<int64_t> lbConstant = getConstantIntValue(ofr: lb); |
| 1329 | if (!lbConstant) |
| 1330 | return true; |
| 1331 | |
| 1332 | std::optional<int64_t> ubConstant = getConstantIntValue(ofr: ub); |
| 1333 | if (!ubConstant) |
| 1334 | return true; |
| 1335 | |
| 1336 | std::optional<int64_t> stepConstant = getConstantIntValue(ofr: step); |
| 1337 | if (!stepConstant) |
| 1338 | return true; |
| 1339 | |
| 1340 | if (*lbConstant + *stepConstant < *ubConstant) |
| 1341 | return true; |
| 1342 | } |
| 1343 | return false; |
| 1344 | } |
| 1345 | |
// An scf.forall region is parallel exactly when it is repetitive: every
// repetition of the body may execute concurrently.
bool isParallelRegion(Operation *op, unsigned index) const {
  return isRepetitiveRegion(op, index);
}
| 1349 | }; |
| 1350 | |
/// Nothing to do for InParallelOp: the terminator itself has no tensor
/// OpOperands/OpResults, so `bufferize` is never invoked on it (the
/// enclosing ForallOp's bufferization handles the region contents).
struct InParallelOpInterface
    : public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
                                                    InParallelOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    // Unreachable by construction; the dead `return` keeps compilers that
    // warn about a missing return value happy.
    llvm_unreachable("op does not have any tensor OpOperands / OpResults" );
    return failure();
  }
};
| 1362 | |
| 1363 | } // namespace |
| 1364 | } // namespace scf |
| 1365 | } // namespace mlir |
| 1366 | |
| 1367 | void mlir::scf::registerBufferizableOpInterfaceExternalModels( |
| 1368 | DialectRegistry ®istry) { |
| 1369 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, scf::SCFDialect *dialect) { |
| 1370 | ConditionOp::attachInterface<ConditionOpInterface>(context&: *ctx); |
| 1371 | ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(context&: *ctx); |
| 1372 | ForOp::attachInterface<ForOpInterface>(context&: *ctx); |
| 1373 | IfOp::attachInterface<IfOpInterface>(context&: *ctx); |
| 1374 | IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(context&: *ctx); |
| 1375 | ForallOp::attachInterface<ForallOpInterface>(context&: *ctx); |
| 1376 | InParallelOp::attachInterface<InParallelOpInterface>(context&: *ctx); |
| 1377 | WhileOp::attachInterface<WhileOpInterface>(context&: *ctx); |
| 1378 | YieldOp::attachInterface<YieldOpInterface>(context&: *ctx); |
| 1379 | }); |
| 1380 | } |
| 1381 | |