1//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
10
11#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
14#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
15#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17#include "mlir/Dialect/SCF/IR/SCF.h"
18#include "mlir/Dialect/Utils/StaticValueUtils.h"
19#include "mlir/IR/Dialect.h"
20#include "mlir/IR/Operation.h"
21#include "mlir/IR/PatternMatch.h"
22
23using namespace mlir;
24using namespace mlir::bufferization;
25using namespace mlir::scf;
26
27namespace mlir {
28namespace scf {
29namespace {
30
31/// Helper function for loop bufferization. Cast the given buffer to the given
32/// memref type.
33static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
34 assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
35 assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
36 // If the buffer already has the correct type, no cast is needed.
37 if (buffer.getType() == type)
38 return buffer;
39 // TODO: In case `type` has a layout map that is not the fully dynamic
40 // one, we may not be able to cast the buffer. In that case, the loop
41 // iter_arg's layout map must be changed (see uses of `castBuffer`).
42 assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
43 "scf.while op bufferization: cast incompatible");
44 return b.create<memref::CastOp>(location: buffer.getLoc(), args&: type, args&: buffer).getResult();
45}
46
/// Helper function for loop bufferization. Return "true" if the given value
/// is guaranteed to not alias with an external tensor apart from values in
/// `exceptions`. A value is external if it is defined outside of the given
/// region or if it is an entry block argument of the region.
///
/// Note: `region` must have exactly one block.
static bool doesNotAliasExternalValue(Value value, Region *region,
                                      ValueRange exceptions,
                                      const OneShotAnalysisState &state) {
  assert(llvm::hasSingleElement(region->getBlocks()) &&
         "expected region with single block");
  bool result = true;
  // Visit every alias of `value`; a single external alias (that is not in
  // `exceptions`) disproves the property.
  state.applyOnAliases(v: value, fun: [&](Value alias) {
    if (llvm::is_contained(Range&: exceptions, Element: alias))
      return;
    Region *aliasRegion = alias.getParentRegion();
    // A block argument is external unless it belongs to a region strictly
    // nested inside `region` (entry bbArgs of `region` itself count as
    // external, hence "proper" ancestor).
    if (isa<BlockArgument>(Val: alias) && !region->isProperAncestor(other: aliasRegion))
      result = false;
    // An op result is external unless it is defined within `region`
    // (including `region` itself).
    if (isa<OpResult>(Val: alias) && !region->isAncestor(other: aliasRegion))
      result = false;
  });
  return result;
}
68
/// Bufferization of scf.condition.
struct ConditionOpInterface
    : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
                                                    scf::ConditionOp> {
  /// All forwarded operands are considered read by scf.condition.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  /// scf.condition only forwards its operands; it never writes to them.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  /// No aliasing values are reported here; result aliasing is modeled on the
  /// enclosing scf.while op instead.
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Condition operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield allocations
    // when possible.
    return true;
  }

  /// Replace scf.condition with a new scf.condition that forwards memrefs.
  /// Each tensor operand is replaced with its buffer, cast to the buffer type
  /// of the corresponding "after" block argument of the parent scf.while op.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto conditionOp = cast<scf::ConditionOp>(Val: op);
    auto whileOp = cast<scf::WhileOp>(Val: conditionOp->getParentOp());

    SmallVector<Value> newArgs;
    for (const auto &it : llvm::enumerate(First: conditionOp.getArgs())) {
      Value value = it.value();
      if (isa<TensorType>(Val: value.getType())) {
        FailureOr<Value> maybeBuffer =
            getBuffer(rewriter, value, options, state);
        if (failed(Result: maybeBuffer))
          return failure();
        // The "after" region bbArg determines the buffer type that the parent
        // scf.while op expects; cast the buffer if necessary.
        FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
            value: whileOp.getAfterArguments()[it.index()], options, state);
        if (failed(Result: resultType))
          return failure();
        Value buffer = castBuffer(b&: rewriter, buffer: *maybeBuffer, type: *resultType);
        newArgs.push_back(Elt: buffer);
      } else {
        // Non-tensor operands are forwarded unchanged.
        newArgs.push_back(Elt: value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::ConditionOp>(
        rewriter, op, args: conditionOp.getCondition(), args&: newArgs);
    return success();
  }
};
126
127/// Return the unique scf.yield op. If there are multiple or no scf.yield ops,
128/// return an empty op.
129static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
130 scf::YieldOp result;
131 for (Block &block : executeRegionOp.getRegion()) {
132 if (auto yieldOp = dyn_cast<scf::YieldOp>(Val: block.getTerminator())) {
133 if (result)
134 return {};
135 result = yieldOp;
136 }
137 }
138 return result;
139}
140
/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
          ExecuteRegionOpInterface, scf::ExecuteRegionOp> {

  static bool supportsUnstructuredControlFlow() { return true; }

  /// Values defined inside an scf.execute_region are always considered
  /// writable.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  /// Reject scf.execute_region ops that do not have exactly one scf.yield
  /// terminator; `bufferize` and `getAliasingOpOperands` rely on it.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(Val: op);
    // TODO: scf.execute_region with multiple yields are not supported.
    if (!getUniqueYieldOp(executeRegionOp))
      return op->emitOpError(message: "op without unique scf.yield is not supported");
    return success();
  }

  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // Block arguments are handled by the unstructured-control-flow base model.
    if (auto bbArg = dyn_cast<BlockArgument>(Val&: value))
      return getAliasingBranchOpOperands(op, bbArg, state);

    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(Val: op);
    auto it = llvm::find(Range: op->getOpResults(), Val: value);
    assert(it != op->getOpResults().end() && "invalid value");
    size_t resultNum = std::distance(first: op->getOpResults().begin(), last: it);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    // Note: If there is no unique scf.yield op, `verifyAnalysis` will fail.
    if (!yieldOp)
      return {};
    return {{&yieldOp->getOpOperand(idx: resultNum), BufferRelation::Equivalent}};
  }

  /// Replace the op with a new scf.execute_region whose results are buffers.
  /// The body is moved over as-is; only block signatures are bufferized here.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(Val: op);
    // Note: `yieldOp` is non-empty here; ops without a unique scf.yield are
    // rejected by `verifyAnalysis` before bufferization.
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    TypeRange newResultTypes(yieldOp.getResults());

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(location: op->getLoc(), args&: newResultTypes);
    newOp.getRegion().takeBody(other&: executeRegionOp.getRegion());

    // Bufferize every block.
    for (Block &block : newOp.getRegion())
      if (failed(Result: bufferization::bufferizeBlockSignature(block: &block, rewriter,
                                                        options, state)))
        return failure();

    // Update all uses of the old op. Tensor results are wrapped in
    // ToTensorOps so that not-yet-bufferized users keep seeing tensors.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(First: executeRegionOp->getResultTypes())) {
      if (isa<TensorType>(Val: it.value())) {
        newResults.push_back(Elt: rewriter.create<bufferization::ToTensorOp>(
            location: executeRegionOp.getLoc(), args&: it.value(),
            args: newOp->getResult(idx: it.index())));
      } else {
        newResults.push_back(Elt: newOp->getResult(idx: it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(op: executeRegionOp, newValues: newResults);

    return success();
  }
};
221
222/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
223struct IfOpInterface
224 : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
225 AliasingOpOperandList
226 getAliasingOpOperands(Operation *op, Value value,
227 const AnalysisState &state) const {
228 // IfOps do not have tensor OpOperands. The yielded value can be any SSA
229 // value that is in scope. To allow for use-def chain traversal through
230 // IfOps in the analysis, both corresponding yield values from the then/else
231 // branches are considered to be aliasing with the result.
232 auto ifOp = cast<scf::IfOp>(Val: op);
233 size_t resultNum = std::distance(first: op->getOpResults().begin(),
234 last: llvm::find(Range: op->getOpResults(), Val: value));
235 OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(idx: resultNum);
236 OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(idx: resultNum);
237 return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false},
238 {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}};
239 }
240
241 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
242 const BufferizationOptions &options,
243 BufferizationState &state) const {
244 OpBuilder::InsertionGuard g(rewriter);
245 auto ifOp = cast<scf::IfOp>(Val: op);
246
247 // Compute bufferized result types.
248 SmallVector<Type> newTypes;
249 for (Value result : ifOp.getResults()) {
250 if (!isa<TensorType>(Val: result.getType())) {
251 newTypes.push_back(Elt: result.getType());
252 continue;
253 }
254 auto bufferType = bufferization::getBufferType(value: result, options, state);
255 if (failed(Result: bufferType))
256 return failure();
257 newTypes.push_back(Elt: *bufferType);
258 }
259
260 // Create new op.
261 rewriter.setInsertionPoint(ifOp);
262 auto newIfOp =
263 rewriter.create<scf::IfOp>(location: ifOp.getLoc(), args&: newTypes, args: ifOp.getCondition(),
264 /*withElseRegion=*/args: true);
265
266 // Move over then/else blocks.
267 rewriter.mergeBlocks(source: ifOp.thenBlock(), dest: newIfOp.thenBlock());
268 rewriter.mergeBlocks(source: ifOp.elseBlock(), dest: newIfOp.elseBlock());
269
270 // Replace op results.
271 replaceOpWithBufferizedValues(rewriter, op, values: newIfOp->getResults());
272
273 return success();
274 }
275
276 FailureOr<BufferLikeType>
277 getBufferType(Operation *op, Value value, const BufferizationOptions &options,
278 const BufferizationState &state,
279 SmallVector<Value> &invocationStack) const {
280 auto ifOp = cast<scf::IfOp>(Val: op);
281 auto thenYieldOp = cast<scf::YieldOp>(Val: ifOp.thenBlock()->getTerminator());
282 auto elseYieldOp = cast<scf::YieldOp>(Val: ifOp.elseBlock()->getTerminator());
283 assert(value.getDefiningOp() == op && "invalid valid");
284
285 // Determine buffer types of the true/false branches.
286 auto opResult = cast<OpResult>(Val&: value);
287 auto thenValue = thenYieldOp.getOperand(i: opResult.getResultNumber());
288 auto elseValue = elseYieldOp.getOperand(i: opResult.getResultNumber());
289 BaseMemRefType thenBufferType, elseBufferType;
290 if (isa<BaseMemRefType>(Val: thenValue.getType())) {
291 // True branch was already bufferized.
292 thenBufferType = cast<BaseMemRefType>(Val: thenValue.getType());
293 } else {
294 auto maybeBufferType =
295 bufferization::detail::asMemRefType(bufferType: bufferization::getBufferType(
296 value: thenValue, options, state, invocationStack));
297 if (failed(Result: maybeBufferType))
298 return failure();
299 thenBufferType = *maybeBufferType;
300 }
301 if (isa<BaseMemRefType>(Val: elseValue.getType())) {
302 // False branch was already bufferized.
303 elseBufferType = cast<BaseMemRefType>(Val: elseValue.getType());
304 } else {
305 auto maybeBufferType =
306 bufferization::detail::asMemRefType(bufferType: bufferization::getBufferType(
307 value: elseValue, options, state, invocationStack));
308 if (failed(Result: maybeBufferType))
309 return failure();
310 elseBufferType = *maybeBufferType;
311 }
312
313 // Best case: Both branches have the exact same buffer type.
314 if (thenBufferType == elseBufferType)
315 return cast<BufferLikeType>(Val&: thenBufferType);
316
317 // Memory space mismatch.
318 if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
319 return op->emitError(message: "inconsistent memory space on then/else branches");
320
321 // Layout maps are different: Promote to fully dynamic layout map.
322 return cast<BufferLikeType>(Val: getMemRefTypeWithFullyDynamicLayout(
323 tensorType: cast<TensorType>(Val: opResult.getType()), memorySpace: thenBufferType.getMemorySpace()));
324 }
325};
326
/// Bufferization of scf.index_switch. Replace with a new scf.index_switch that
/// yields memrefs.
struct IndexSwitchOpInterface
    : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
                                                    scf::IndexSwitchOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IndexSwitchOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. This is similar to IfOps.
    auto switchOp = cast<scf::IndexSwitchOp>(Val: op);
    int64_t resultNum = cast<OpResult>(Val&: value).getResultNumber();
    AliasingOpOperandList result;
    // Each case's yield aliases with the result. `isDefinite=false` because
    // only one case region runs.
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldOp =
          cast<scf::YieldOp>(Val: switchOp.getCaseBlock(idx: i).getTerminator());
      result.addAlias(alias: AliasingOpOperand(&yieldOp->getOpOperand(idx: resultNum),
                                        BufferRelation::Equivalent,
                                        /*isDefinite=*/false));
    }
    // The default region's yield aliases with the result as well.
    auto defaultYieldOp =
        cast<scf::YieldOp>(Val: switchOp.getDefaultBlock().getTerminator());
    result.addAlias(alias: AliasingOpOperand(&defaultYieldOp->getOpOperand(idx: resultNum),
                                      BufferRelation::Equivalent,
                                      /*isDefinite=*/false));
    return result;
  }

  /// Replace the op with a new scf.index_switch whose tensor results are
  /// turned into buffer results; all case/default regions are moved over.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto switchOp = cast<scf::IndexSwitchOp>(Val: op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : switchOp.getResults()) {
      if (!isa<TensorType>(Val: result.getType())) {
        newTypes.push_back(Elt: result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(value: result, options, state);
      if (failed(Result: bufferType))
        return failure();
      newTypes.push_back(Elt: *bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(switchOp);
    auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>(
        location: switchOp.getLoc(), args&: newTypes, args: switchOp.getArg(), args: switchOp.getCases(),
        args: switchOp.getCases().size());

    // Move over blocks.
    for (auto [src, dest] :
         llvm::zip(t: switchOp.getCaseRegions(), u: newSwitchOp.getCaseRegions()))
      rewriter.inlineRegionBefore(region&: src, parent&: dest, before: dest.begin());
    rewriter.inlineRegionBefore(region&: switchOp.getDefaultRegion(),
                                parent&: newSwitchOp.getDefaultRegion(),
                                before: newSwitchOp.getDefaultRegion().begin());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, values: newSwitchOp->getResults());

    return success();
  }

  /// Compute the buffer type of an scf.index_switch result: the common buffer
  /// type of all yielded values, promoted to a fully dynamic layout map if the
  /// cases disagree on the layout.
  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto switchOp = cast<scf::IndexSwitchOp>(Val: op);
    assert(value.getDefiningOp() == op && "invalid value");
    int64_t resultNum = cast<OpResult>(Val&: value).getResultNumber();

    // Helper function to get buffer type of a case. If the yielded value was
    // already bufferized, its type is used directly.
    auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
      auto yieldOp = cast<scf::YieldOp>(Val: b.getTerminator());
      Value yieldedValue = yieldOp->getOperand(idx: resultNum);
      if (auto bufferType = dyn_cast<BaseMemRefType>(Val: yieldedValue.getType()))
        return bufferType;
      auto maybeBufferType = bufferization::getBufferType(
          value: yieldedValue, options, state, invocationStack);
      return bufferization::detail::asMemRefType(bufferType: maybeBufferType);
    };

    // Compute buffer type of the default case.
    auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
    if (failed(Result: maybeBufferType))
      return failure();
    BaseMemRefType bufferType = *maybeBufferType;

    // Compute buffer types of all other cases.
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(idx: i));
      if (failed(Result: yieldedBufferType))
        return failure();

      // Best case: This case has the exact same buffer type as the one
      // accumulated so far.
      if (bufferType == *yieldedBufferType)
        continue;

      // Memory space mismatch.
      if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
        return op->emitError(message: "inconsistent memory space on switch cases");

      // Layout maps are different: Promote to fully dynamic layout map.
      bufferType = getMemRefTypeWithFullyDynamicLayout(
          tensorType: cast<TensorType>(Val: value.getType()), memorySpace: bufferType.getMemorySpace());
    }

    return cast<BufferLikeType>(Val&: bufferType);
  }
};
441
442/// Helper function for loop bufferization. Return the indices of all values
443/// that have a tensor type.
444static DenseSet<int64_t> getTensorIndices(ValueRange values) {
445 DenseSet<int64_t> result;
446 for (const auto &it : llvm::enumerate(First&: values))
447 if (isa<TensorType>(Val: it.value().getType()))
448 result.insert(V: it.index());
449 return result;
450}
451
452/// Helper function for loop bufferization. Return the indices of all
453/// bbArg/yielded value pairs who's buffer relation is "Equivalent".
454DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
455 ValueRange yieldedValues,
456 const AnalysisState &state) {
457 unsigned int minSize = std::min(a: bbArgs.size(), b: yieldedValues.size());
458 DenseSet<int64_t> result;
459 for (unsigned int i = 0; i < minSize; ++i) {
460 if (!isa<TensorType>(Val: bbArgs[i].getType()) ||
461 !isa<TensorType>(Val: yieldedValues[i].getType()))
462 continue;
463 if (state.areEquivalentBufferizedValues(v1: bbArgs[i], v2: yieldedValues[i]))
464 result.insert(V: i);
465 }
466 return result;
467}
468
469/// Helper function for loop bufferization. Return the bufferized values of the
470/// given OpOperands. If an operand is not a tensor, return the original value.
471static FailureOr<SmallVector<Value>>
472getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
473 const BufferizationOptions &options, BufferizationState &state) {
474 SmallVector<Value> result;
475 for (OpOperand &opOperand : operands) {
476 if (isa<TensorType>(Val: opOperand.get().getType())) {
477 FailureOr<Value> resultBuffer =
478 getBuffer(rewriter, value: opOperand.get(), options, state);
479 if (failed(Result: resultBuffer))
480 return failure();
481 result.push_back(Elt: *resultBuffer);
482 } else {
483 result.push_back(Elt: opOperand.get());
484 }
485 }
486 return result;
487}
488
489/// Helper function for loop bufferization. Given a list of bbArgs of the new
490/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
491/// ToTensorOps, so that the block body can be moved over to the new op.
492static SmallVector<Value>
493getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
494 Block::BlockArgListType oldBbArgs,
495 const DenseSet<int64_t> &tensorIndices) {
496 SmallVector<Value> result;
497 for (const auto &it : llvm::enumerate(First&: bbArgs)) {
498 size_t idx = it.index();
499 Value val = it.value();
500 if (tensorIndices.contains(V: idx)) {
501 result.push_back(Elt: rewriter
502 .create<bufferization::ToTensorOp>(
503 location: val.getLoc(), args: oldBbArgs[idx].getType(), args&: val)
504 .getResult());
505 } else {
506 result.push_back(Elt: val);
507 }
508 }
509 return result;
510}
511
/// Compute the bufferized type of a loop iter_arg. This type must be equal to
/// the bufferized type of the corresponding init_arg and the bufferized type
/// of the corresponding yielded value.
///
/// This function uses bufferization::getBufferType to compute the bufferized
/// type of the init_arg and of the yielded value. (The computation of the
/// bufferized yielded value type usually requires computing the bufferized type
/// of the iter_arg again; the implementation of getBufferType traces back the
/// use-def chain of the given value and computes a buffer type along the way.)
/// If both buffer types are equal, no casts are needed the computed buffer type
/// can be used directly. Otherwise, the buffer types can only differ in their
/// layout map and a cast must be inserted.
static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
    Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
    const BufferizationOptions &options, const BufferizationState &state,
    SmallVector<Value> &invocationStack) {
  // Determine the buffer type of the init_arg.
  auto initArgBufferType =
      bufferization::getBufferType(value: initArg, options, state, invocationStack);
  if (failed(Result: initArgBufferType))
    return failure();

  if (llvm::count(Range&: invocationStack, Element: iterArg) >= 2) {
    // If the iter_arg is already twice on the invocation stack, just take the
    // type of the init_arg. This is to avoid infinite loops when calculating
    // the buffer type. This will most likely result in computing a memref type
    // with a fully dynamic layout map.

    // Note: For more precise layout map computation, a fixpoint iteration could
    // be done (i.e., re-computing the yielded buffer type until the bufferized
    // iter_arg type no longer changes). This current implementation immediately
    // switches to a fully dynamic layout map when a mismatch between bufferized
    // init_arg type and bufferized yield value type is detected.
    return *initArgBufferType;
  }

  // Compute the buffer type of the yielded value.
  BufferLikeType yieldedValueBufferType;
  if (isa<BaseMemRefType>(Val: yieldedValue.getType())) {
    // scf.yield was already bufferized.
    yieldedValueBufferType = cast<BufferLikeType>(Val: yieldedValue.getType());
  } else {
    // Note: This typically triggers a recursive call for the buffer type of
    // the iter_arg.
    auto maybeBufferType = bufferization::getBufferType(value: yieldedValue, options,
                                                        state, invocationStack);
    if (failed(Result: maybeBufferType))
      return failure();
    yieldedValueBufferType = *maybeBufferType;
  }

  // If yielded type and init_arg type are the same, use that type directly.
  if (*initArgBufferType == yieldedValueBufferType)
    return yieldedValueBufferType;

  // If there is a mismatch between the yielded buffer type and the init_arg
  // buffer type, the buffer type must be promoted to a fully dynamic layout
  // map.
  auto yieldedBufferType = cast<BaseMemRefType>(Val&: yieldedValueBufferType);
  auto iterTensorType = cast<TensorType>(Val: iterArg.getType());
  auto initBufferType = llvm::cast<BaseMemRefType>(Val&: *initArgBufferType);
  // Memory spaces cannot be reconciled by a cast; this is a hard error.
  if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
    return loopOp->emitOpError(
        message: "init_arg and yielded value bufferize to inconsistent memory spaces");
#ifndef NDEBUG
  // Sanity check (debug builds only): init_arg, yielded value and iter_arg
  // must all agree on the shape; only the layout map may differ.
  if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
    assert(
        llvm::all_equal({yieldedRankedBufferType.getShape(),
                         cast<MemRefType>(initBufferType).getShape(),
                         cast<RankedTensorType>(iterTensorType).getShape()}) &&
        "expected same shape");
  }
#endif // NDEBUG
  return cast<BufferLikeType>(Val: getMemRefTypeWithFullyDynamicLayout(
      tensorType: iterTensorType, memorySpace: yieldedBufferType.getMemorySpace()));
}
588
589/// Return `true` if the given loop may have 0 iterations.
590bool mayHaveZeroIterations(scf::ForOp forOp) {
591 std::optional<int64_t> lb = getConstantIntValue(ofr: forOp.getLowerBound());
592 std::optional<int64_t> ub = getConstantIntValue(ofr: forOp.getUpperBound());
593 if (!lb.has_value() || !ub.has_value())
594 return true;
595 return *ub <= *lb;
596}
597
598/// Bufferization of scf.for. Replace with a new scf.for that operates on
599/// memrefs.
600struct ForOpInterface
601 : public BufferizableOpInterface::ExternalModel<ForOpInterface,
602 scf::ForOp> {
  /// An init_arg of scf.for bufferizes to a memory read if the loop may run
  /// zero times (the result then forwards the init_arg) or if the tied region
  /// iter_arg is read inside the loop body.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(Val: op);

    // If the loop has zero iterations, the results of the op are their
    // corresponding init_args, meaning that the init_args bufferize to a read.
    if (mayHaveZeroIterations(forOp))
      return true;

    // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
    // its matching bbArg may.
    return state.isValueRead(value: forOp.getTiedLoopRegionIterArg(opOperand: &opOperand));
  }
616
  /// Conservatively treat every tensor init_arg as written to by the loop.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }
622
  /// Each init_arg aliases only with the loop result at the same position.
  /// The relation is definite only when it is "Equivalent" (see
  /// `bufferRelation`).
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(Val: op);
    OpResult opResult = forOp.getTiedLoopResult(opOperand: &opOperand);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }
631
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(Val: op);
    BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
    bool equivalentYield = state.areEquivalentBufferizedValues(
        v1: bbArg, v2: forOp.getTiedLoopYieldedValue(bbArg)->get());
    // Otherwise, the relation cannot be determined more precisely.
    return equivalentYield ? BufferRelation::Equivalent
                           : BufferRelation::Unknown;
  }
643
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }
654
  /// Resolve conflicts for scf.for: in addition to the default operand
  /// conflict resolution, insert a buffer copy (allocate_tensor) for every
  /// yielded value that may alias a tensor defined outside of the loop, so
  /// that the loop does not yield a buffer it does not own.
  LogicalResult
  resolveConflicts(Operation *op, RewriterBase &rewriter,
                   const AnalysisState &analysisState,
                   const BufferizationState &bufferizationState) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(Val: op);
    if (failed(Result: bufferizableOp.resolveTensorOpOperandConflicts(
            rewriter, analysisState, bufferizationState)))
      return failure();

    // With copy-before-write, aliasing yields are harmless: every write copies
    // first, so no extra allocations are needed here.
    if (analysisState.getOptions().copyBeforeWrite)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg (or with a
    // newly allocated buffer) and not with other buffers defined outside of the
    // loop. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand.
    auto forOp = cast<scf::ForOp>(Val: op);
    auto yieldOp = cast<scf::YieldOp>(Val: forOp.getBody()->getTerminator());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(values: forOp.getInitArgs());
    // For every yielded value, does it alias with something defined outside of
    // the loop?
    SmallVector<Value> yieldValues;
    for (const auto it : llvm::enumerate(First: yieldOp.getResults())) {
      // Note: `state` is guaranteed to be a `OneShotAnalysisState`, but this
      // type cannot be used in the signature of `resolveConflicts` because the
      // op interface is in the "IR" build unit and the `OneShotAnalysisState`
      // is defined in the "Transforms" build unit.
      if (!indices.contains(V: it.index()) ||
          doesNotAliasExternalValue(
              value: it.value(), region: &forOp.getRegion(),
              /*exceptions=*/forOp.getRegionIterArg(index: it.index()),
              state: static_cast<const OneShotAnalysisState &>(analysisState))) {
        // Safe to yield as-is: non-tensor value or provably loop-local.
        yieldValues.push_back(Elt: it.value());
        continue;
      }
      // Potential external alias: yield a copy into a fresh allocation.
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          b&: rewriter, loc: yieldOp.getLoc(), shapedValue: it.value(), options: analysisState.getOptions(),
          state: bufferizationState);
      if (failed(Result: alloc))
        return failure();
      yieldValues.push_back(Elt: *alloc);
    }

    rewriter.modifyOpInPlace(
        root: yieldOp, callable: [&]() { yieldOp.getResultsMutable().assign(values: yieldValues); });
    return success();
  }
708
  /// Compute the buffer type of an scf.for result or iter_arg. Results take
  /// the type of the tied iter_arg; iter_arg types are derived from the
  /// init_arg and the yielded value (see
  /// `computeLoopRegionIterArgBufferType`).
  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto forOp = cast<scf::ForOp>(Val: op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    if (auto opResult = dyn_cast<OpResult>(Val&: value)) {
      // The type of an OpResult must match the corresponding iter_arg type.
      BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
      return bufferization::getBufferType(value: bbArg, options, state,
                                          invocationStack);
    }

    // Compute result/argument number.
    BlockArgument bbArg = cast<BlockArgument>(Val&: value);
    unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();

    // Compute the bufferized type.
    auto yieldOp = cast<scf::YieldOp>(Val: forOp.getBody()->getTerminator());
    Value yieldedValue = yieldOp.getOperand(i: resultNum);
    BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
    Value initArg = forOp.getInitArgs()[resultNum];
    return computeLoopRegionIterArgBufferType(
        loopOp: op, iterArg, initArg, yieldedValue, options, state, invocationStack);
  }
736
  /// Bufferize scf.for: create a new scf.for whose iter_args are memrefs,
  /// wrap its bbArgs in to_tensor ops for the (still tensorial) body, move
  /// the old body over, and replace the old loop's results.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(Val: op);
    Block *oldLoopBody = forOp.getBody();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(values: forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, operands: forOp.getInitArgsMutable(), options, state);
    if (failed(Result: maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary: each buffer must match the bufferized type
    // of the corresponding loop result (see `getBufferType` above).
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(First&: initArgs)) {
      Value initArg = it.value();
      Value result = forOp->getResult(idx: it.index());
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(Val: result.getType())) {
        castedInitArgs.push_back(Elt: initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(value: result, options, state);
      if (failed(Result: targetType))
        return failure();
      castedInitArgs.push_back(Elt: castBuffer(b&: rewriter, buffer: initArg, type: *targetType));
    }

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        location: forOp.getLoc(), args: forOp.getLowerBound(), args: forOp.getUpperBound(),
        args: forOp.getStep(), args&: castedInitArgs);
    newForOp->setAttrs(forOp->getAttrs());
    Block *loopBody = newForOp.getBody();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, bbArgs: newForOp.getRegionIterArgs(),
                             oldBbArgs: forOp.getRegionIterArgs(), tensorIndices: indices);
    // The induction variable is the first bbArg of the old body; prepend it so
    // the replacement list lines up with the old block's argument list.
    iterArgs.insert(I: iterArgs.begin(), Elt: newForOp.getInductionVar());

    // Move loop body to new loop.
    rewriter.mergeBlocks(source: oldLoopBody, dest: loopBody, argValues: iterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, values: newForOp->getResults());

    return success();
  }
793
  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs-from-loops`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    // Nothing to verify when returning allocs from loops is allowed.
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto forOp = cast<scf::ForOp>(Val: op);
    auto yieldOp = cast<scf::YieldOp>(Val: forOp.getBody()->getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      // Only tensor results participate in bufferization.
      if (!isa<TensorType>(Val: opResult.getType()))
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
824};
825
/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  /// An init_arg may alias only with the OpResult at the same index, and only
  /// if the types line up; `bufferRelation` decides how strong the aliasing
  /// guarantee is.
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(Val: op);
    unsigned int idx = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match. They may not even have the
    // same type. The number of OpResults and OpOperands can also differ.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    OpResult opResult = whileOp->getResult(idx);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(Val: op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::Unknown;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::Unknown;

    // "Before" region: is the value passed to scf.condition equivalent to the
    // corresponding "before" bbArg?
    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(v1: conditionBbArg, v2: conditionOperand);

    // "After" region: is the yielded value equivalent to the corresponding
    // "after" bbArg?
    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(i: resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(v1: bodyBbArg, v2: yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  /// Pre-bufferization fixup: for every scf.condition operand that is not
  /// equivalent to its bbArg (in both regions), yield a fresh buffer copy so
  /// that the aliasing contract promised by `getAliasingValues` holds.
  LogicalResult
  resolveConflicts(Operation *op, RewriterBase &rewriter,
                   const AnalysisState &analysisState,
                   const BufferizationState &bufferizationState) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(Val: op);
    if (failed(Result: bufferizableOp.resolveTensorOpOperandConflicts(
            rewriter, analysisState, bufferizationState)))
      return failure();

    // In copy-before-write mode, no conflict resolution is needed here.
    if (analysisState.getOptions().copyBeforeWrite)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
    // (New buffer copies do not alias with any buffer.)
    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(Val: op);
    auto conditionOp = whileOp.getConditionOp();

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        bbArgs: whileOp.getBeforeArguments(), yieldedValues: conditionOp.getArgs(), state: analysisState);
    DenseSet<int64_t> equivalentYieldsAfter =
        getEquivalentBuffers(bbArgs: whileOp.getAfterArguments(),
                             yieldedValues: whileOp.getYieldOp().getResults(), state: analysisState);

    // Update "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      // Non-tensors and fully equivalent values pass through unchanged.
      if (!isa<TensorType>(Val: value.getType()) ||
          (equivalentYieldsAfter.contains(V: idx) &&
           equivalentYieldsBefore.contains(V: idx))) {
        beforeYieldValues.push_back(Elt: value);
        continue;
      }
      // Otherwise, yield a new (non-aliasing) buffer copy.
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          b&: rewriter, loc: conditionOp.getLoc(), shapedValue: value, options: analysisState.getOptions(),
          state: bufferizationState);
      if (failed(Result: alloc))
        return failure();
      beforeYieldValues.push_back(Elt: *alloc);
    }
    rewriter.modifyOpInPlace(root: conditionOp, callable: [&]() {
      conditionOp.getArgsMutable().assign(values: beforeYieldValues);
    });

    return success();
  }

  /// Bufferize scf.while: build a new scf.while on memrefs, re-create both
  /// regions, wrap the new bbArgs in to_tensor ops, and move the old blocks
  /// over.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto whileOp = cast<scf::WhileOp>(Val: op);

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(values: whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(values: whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, operands: whileOp.getInitsMutable(), options, state);
    if (failed(Result: maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary: each buffer must match the bufferized type
    // of the corresponding "before" bbArg.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(First&: initArgs)) {
      Value initArg = it.value();
      Value beforeArg = whileOp.getBeforeArguments()[it.index()];
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(Val: beforeArg.getType())) {
        castedInitArgs.push_back(Elt: initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(value: beforeArg, options, state);
      if (failed(Result: targetType))
        return failure();
      castedInitArgs.push_back(Elt: castBuffer(b&: rewriter, buffer: initArg, type: *targetType));
    }

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        Range: llvm::map_range(C: whileOp.getAfterArguments(), F: [&](BlockArgument bbArg) {
          if (!isa<TensorType>(Val: bbArg.getType()))
            return bbArg.getType();
          // TODO: error handling
          return llvm::cast<Type>(
              Val: *bufferization::getBufferType(value: bbArg, options, state));
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(castedInitArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(
        location: whileOp.getLoc(), args&: argsTypesAfter, args&: castedInitArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
                                          whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(types: argsTypesBefore, locs: bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(types: argsTypesAfter, locs: bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs =
        getBbArgReplacements(rewriter, bbArgs: newWhileOp.getBeforeArguments(),
                             oldBbArgs: whileOp.getBeforeArguments(), tensorIndices: indicesBefore);
    rewriter.mergeBlocks(source: whileOp.getBeforeBody(), dest: newBeforeBody, argValues: newBeforeArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs =
        getBbArgReplacements(rewriter, bbArgs: newWhileOp.getAfterArguments(),
                             oldBbArgs: whileOp.getAfterArguments(), tensorIndices: indicesAfter);
    rewriter.mergeBlocks(source: whileOp.getAfterBody(), dest: newAfterBody, argValues: newAfterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, values: newWhileOp->getResults());

    return success();
  }

  /// Compute the bufferized type of an scf.while value: a "before" bbArg is
  /// derived from the init_arg and the yielded value; OpResults and "after"
  /// bbArgs take the type of the value passed to scf.condition at the same
  /// index.
  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto whileOp = cast<scf::WhileOp>(Val: op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    // Case 1: Block argument of the "before" region.
    if (auto bbArg = dyn_cast<BlockArgument>(Val&: value)) {
      if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
        Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
        auto yieldOp = whileOp.getYieldOp();
        Value yieldedValue = yieldOp.getOperand(i: bbArg.getArgNumber());
        return computeLoopRegionIterArgBufferType(
            loopOp: op, iterArg: bbArg, initArg, yieldedValue, options, state, invocationStack);
      }
    }

    // Case 2: OpResult of the loop or block argument of the "after" region.
    // The bufferized "after" bbArg type can be directly computed from the
    // bufferized "before" bbArg type.
    unsigned resultNum;
    if (auto opResult = dyn_cast<OpResult>(Val&: value)) {
      resultNum = opResult.getResultNumber();
    } else if (cast<BlockArgument>(Val&: value).getOwner()->getParent() ==
               &whileOp.getAfter()) {
      resultNum = cast<BlockArgument>(Val&: value).getArgNumber();
    } else {
      llvm_unreachable("invalid value");
    }
    Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
    if (!isa<TensorType>(Val: conditionYieldedVal.getType())) {
      // scf.condition was already bufferized.
      return cast<BufferLikeType>(Val: conditionYieldedVal.getType());
    }
    return bufferization::getBufferType(value: conditionYieldedVal, options, state,
                                        invocationStack);
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs-from-loops`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(Val: op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    // Nothing to verify when returning allocs from loops is allowed.
    if (options.allowReturnAllocsFromLoops)
      return success();

    // Check the "before" region: scf.condition args vs. "before" bbArgs.
    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(First: conditionOp.getArgs())) {
      Block *block = conditionOp->getBlock();
      if (!isa<TensorType>(Val: it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(v1: it.value(),
                                               v2: block->getArgument(i: it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    // Check the "after" region: scf.yield operands vs. "after" bbArgs.
    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(First: yieldOp.getResults())) {
      Block *block = yieldOp->getBlock();
      if (!isa<TensorType>(Val: it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(v1: it.value(),
                                               v2: block->getArgument(i: it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};
1131
/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Yielded tensors are treated as read by the parent op.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Yielding a value does not write to memory.
    return false;
  }

  /// Yield operands alias with the parent op's result at the same index (only
  /// modeled for scf.if and scf.execute_region here; loop aliasing is modeled
  /// on the loop ops themselves).
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                     const AnalysisState &state) const {
    if (auto ifOp = dyn_cast<scf::IfOp>(Val: op->getParentOp())) {
      // Not definite: the other branch may yield a different buffer.
      return {{op->getParentOp()->getResult(idx: opOperand.getOperandNumber()),
               BufferRelation::Equivalent, /*isDefinite=*/false}};
    }
    if (isa<scf::ExecuteRegionOp>(Val: op->getParentOp()))
      return {{op->getParentOp()->getResult(idx: opOperand.getOperandNumber()),
               BufferRelation::Equivalent}};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield allocations
    // when possible.
    return true;
  }

  /// Replace the scf.yield with one that yields buffers, casting each buffer
  /// to the bufferized type expected by the parent op.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto yieldOp = cast<scf::YieldOp>(Val: op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
             scf::WhileOp>(Val: yieldOp->getParentOp()))
      return yieldOp->emitError(message: "unsupported scf::YieldOp parent");

    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(First: yieldOp.getResults())) {
      Value value = it.value();
      if (isa<TensorType>(Val: value.getType())) {
        FailureOr<Value> maybeBuffer =
            getBuffer(rewriter, value, options, state);
        if (failed(Result: maybeBuffer))
          return failure();
        Value buffer = *maybeBuffer;
        // We may have to cast the value before yielding it.
        if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
                Val: yieldOp->getParentOp())) {
          // Cast to the bufferized type of the parent's result.
          FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
              value: yieldOp->getParentOp()->getResult(idx: it.index()), options, state);
          if (failed(Result: resultType))
            return failure();
          buffer = castBuffer(b&: rewriter, buffer, type: *resultType);
        } else if (auto whileOp =
                       dyn_cast<scf::WhileOp>(Val: yieldOp->getParentOp())) {
          // scf.while: cast to the bufferized type of the "before" bbArg that
          // will receive this yielded value.
          FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
              value: whileOp.getBeforeArguments()[it.index()], options, state);
          if (failed(Result: resultType))
            return failure();
          buffer = castBuffer(b&: rewriter, buffer, type: *resultType);
        }
        newResults.push_back(Elt: buffer);
      } else {
        // Non-tensor operands are passed through unchanged.
        newResults.push_back(Elt: value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, args&: newResults);
    return success();
  }
};
1210
/// Bufferization of ForallOp. This also bufferizes the terminator of the
/// region. There are op interfaces for the terminators (InParallelOp
/// and ParallelInsertSliceOp), but these are only used during analysis. Not
/// for bufferization.
struct ForallOpInterface
    : public BufferizableOpInterface::ExternalModel<ForallOpInterface,
                                                    ForallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // All tensor operands to `scf.forall` are `shared_outs` and all
    // shared outs are assumed to be read by the loop. This does not
    // account for the case where the entire value is over-written,
    // but being conservative here.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Outputs of scf::ForallOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    // Each shared_out aliases exactly its tied OpResult.
    auto forallOp = cast<ForallOp>(Val: op);
    return {
        {{forallOp.getTiedOpResult(opOperand: &opOperand), BufferRelation::Equivalent}}};
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  /// Bufferize scf.forall: bufferize the shared_outs, replace the output
  /// bbArgs with to_tensor wrappers of those buffers, then re-create the loop
  /// without results (buffers are written in place).
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard guard(rewriter);
    auto forallOp = cast<ForallOp>(Val: op);
    int64_t rank = forallOp.getRank();

    // Get buffers for all output operands.
    SmallVector<Value> buffers;
    for (Value out : forallOp.getOutputs()) {
      FailureOr<Value> buffer = getBuffer(rewriter, value: out, options, state);
      if (failed(Result: buffer))
        return failure();
      buffers.push_back(Elt: *buffer);
    }

    // Use buffers instead of block arguments. The first `rank` bbArgs are the
    // induction variables; the remaining ones are the shared_outs.
    rewriter.setInsertionPointToStart(forallOp.getBody());
    for (const auto &it : llvm::zip(
             t: forallOp.getBody()->getArguments().drop_front(N: rank), u&: buffers)) {
      BlockArgument bbArg = std::get<0>(t: it);
      Value buffer = std::get<1>(t: it);
      Value bufferAsTensor = rewriter.create<ToTensorOp>(
          location: forallOp.getLoc(), args: bbArg.getType(), args&: buffer);
      bbArg.replaceAllUsesWith(newValue: bufferAsTensor);
    }

    // Create new ForallOp without any results and drop the automatically
    // introduced terminator.
    rewriter.setInsertionPoint(forallOp);
    ForallOp newForallOp;
    newForallOp = rewriter.create<ForallOp>(
        location: forallOp.getLoc(), args: forallOp.getMixedLowerBound(),
        args: forallOp.getMixedUpperBound(), args: forallOp.getMixedStep(),
        /*outputs=*/args: ValueRange(), args: forallOp.getMapping());

    // Keep discardable attributes from the original op.
    newForallOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());

    rewriter.eraseOp(op: newForallOp.getBody()->getTerminator());

    // Move over block contents of the old op. The old shared_out bbArgs map to
    // nothing (Value()) — all their uses were replaced above.
    SmallVector<Value> replacementBbArgs;
    replacementBbArgs.append(in_start: newForallOp.getBody()->getArguments().begin(),
                             in_end: newForallOp.getBody()->getArguments().end());
    replacementBbArgs.append(NumInputs: forallOp.getOutputs().size(), Elt: Value());
    rewriter.mergeBlocks(source: forallOp.getBody(), dest: newForallOp.getBody(),
                         argValues: replacementBbArgs);

    // Remove the old op and replace all of its uses.
    replaceOpWithBufferizedValues(rewriter, op, values: buffers);

    return success();
  }

  /// Both bbArgs and OpResults bufferize to the type of the tied output
  /// operand.
  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto forallOp = cast<ForallOp>(Val: op);

    if (auto bbArg = dyn_cast<BlockArgument>(Val&: value))
      // A tensor block argument has the same bufferized type as the
      // corresponding output operand.
      return bufferization::getBufferType(
          value: forallOp.getTiedOpOperand(bbArg)->get(), options, state,
          invocationStack);

    // The bufferized result type is the same as the bufferized type of the
    // corresponding output operand.
    return bufferization::getBufferType(
        value: forallOp.getOutputs()[cast<OpResult>(Val&: value).getResultNumber()], options,
        state, invocationStack);
  }

  /// The region is repetitive if the loop executes more than once, i.e. if
  /// any dimension has at least two iterations (lb + step < ub). Dynamic
  /// bounds are conservatively treated as repetitive.
  bool isRepetitiveRegion(Operation *op, unsigned index) const {
    auto forallOp = cast<ForallOp>(Val: op);

    // This op is repetitive if it has 1 or more steps.
    // If the control variables are dynamic, it is also considered so.
    for (auto [lb, ub, step] :
         llvm::zip(t: forallOp.getMixedLowerBound(), u: forallOp.getMixedUpperBound(),
                   args: forallOp.getMixedStep())) {
      std::optional<int64_t> lbConstant = getConstantIntValue(ofr: lb);
      if (!lbConstant)
        return true;

      std::optional<int64_t> ubConstant = getConstantIntValue(ofr: ub);
      if (!ubConstant)
        return true;

      std::optional<int64_t> stepConstant = getConstantIntValue(ofr: step);
      if (!stepConstant)
        return true;

      // A second iteration exists iff lb + step is still below ub.
      if (*lbConstant + *stepConstant < *ubConstant)
        return true;
    }
    return false;
  }

  bool isParallelRegion(Operation *op, unsigned index) const {
    // scf.forall iterations are parallel whenever the region is repetitive.
    return isRepetitiveRegion(op, index);
  }
};
1350
1351/// Nothing to do for InParallelOp.
1352struct InParallelOpInterface
1353 : public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
1354 InParallelOp> {
1355 LogicalResult bufferize(Operation *op, RewriterBase &b,
1356 const BufferizationOptions &options,
1357 BufferizationState &state) const {
1358 llvm_unreachable("op does not have any tensor OpOperands / OpResults");
1359 return failure();
1360 }
1361};
1362
1363} // namespace
1364} // namespace scf
1365} // namespace mlir
1366
1367void mlir::scf::registerBufferizableOpInterfaceExternalModels(
1368 DialectRegistry &registry) {
1369 registry.addExtension(extensionFn: +[](MLIRContext *ctx, scf::SCFDialect *dialect) {
1370 ConditionOp::attachInterface<ConditionOpInterface>(context&: *ctx);
1371 ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(context&: *ctx);
1372 ForOp::attachInterface<ForOpInterface>(context&: *ctx);
1373 IfOp::attachInterface<IfOpInterface>(context&: *ctx);
1374 IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(context&: *ctx);
1375 ForallOp::attachInterface<ForallOpInterface>(context&: *ctx);
1376 InParallelOp::attachInterface<InParallelOpInterface>(context&: *ctx);
1377 WhileOp::attachInterface<WhileOpInterface>(context&: *ctx);
1378 YieldOp::attachInterface<YieldOpInterface>(context&: *ctx);
1379 });
1380}
1381

source code of mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp