//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
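/// E.g., a buffer of type memref<?xf32> may be cast to
/// memref<?xf32, strided<[?], offset: ?>> when the target type uses a fully
/// dynamic layout map.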
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
  assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  // TODO: In case `type` has a layout map that is not the fully dynamic
  // one, we may not be able to cast the buffer. In that case, the loop
  // iter_arg's layout map must be changed (see uses of `castBuffer`).
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}

/// Helper function for loop bufferization. Return "true" if the given value
/// is guaranteed to not alias with an external tensor apart from values in
/// `exceptions`. A value is external if it is defined outside of the given
/// region or if it is an entry block argument of the region.
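/// E.g., a value yielded from a loop body that aliases a tensor defined above
/// the loop (and not listed in `exceptions`) aliases an external value.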
static bool doesNotAliasExternalValue(Value value, Region *region,
                                      ValueRange exceptions,
                                      const OneShotAnalysisState &state) {
  assert(region->getBlocks().size() == 1 &&
         "expected region with single block");
  bool result = true;
  state.applyOnAliases(value, [&](Value alias) {
    if (llvm::is_contained(exceptions, alias))
      return;
    Region *aliasRegion = alias.getParentRegion();
    if (isa<BlockArgument>(alias) && !region->isProperAncestor(aliasRegion))
      result = false;
    if (isa<OpResult>(alias) && !region->isAncestor(aliasRegion))
      result = false;
  });
  return result;
}

/// Bufferization of scf.condition.
struct ConditionOpInterface
    : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
                                                    scf::ConditionOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Condition operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto conditionOp = cast<scf::ConditionOp>(op);
    auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());

    SmallVector<Value> newArgs;
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
        if (failed(maybeBuffer))
          return failure();
        FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
            whileOp.getAfterArguments()[it.index()], options);
        if (failed(resultType))
          return failure();
        Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
        newArgs.push_back(buffer);
      } else {
        newArgs.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::ConditionOp>(
        rewriter, op, conditionOp.getCondition(), newArgs);
    return success();
  }
};

/// Return the unique scf.yield op. If there are multiple or no scf.yield ops,
/// return an empty op.
static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
  scf::YieldOp result;
  for (Block &block : executeRegionOp.getRegion()) {
    if (auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) {
      if (result)
        return {};
      result = yieldOp;
    }
  }
  return result;
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
          ExecuteRegionOpInterface, scf::ExecuteRegionOp> {

  static bool supportsUnstructuredControlFlow() { return true; }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    // TODO: scf.execute_region with multiple yields are not supported.
    if (!getUniqueYieldOp(executeRegionOp))
      return op->emitOpError("op without unique scf.yield is not supported");
    return success();
  }

  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    if (auto bbArg = dyn_cast<BlockArgument>(value))
      return getAliasingBranchOpOperands(op, bbArg, state);

    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto it = llvm::find(op->getOpResults(), value);
    assert(it != op->getOpResults().end() && "invalid value");
    size_t resultNum = std::distance(op->getOpResults().begin(), it);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    // Note: If there is no unique scf.yield op, `verifyAnalysis` will fail.
    if (!yieldOp)
      return {};
    return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    TypeRange newResultTypes(yieldOp.getResults());

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Bufferize every block.
    for (Block &block : newOp.getRegion())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                        options)))
        return failure();

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (isa<TensorType>(it.value())) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the
    // then/else branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), value));
    OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
    OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
    return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false},
            {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : ifOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(ifOp);
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
    assert(value.getDefiningOp() == op && "invalid value");

    // Determine buffer types of the true/false branches.
    auto opResult = cast<OpResult>(value);
    auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
    auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
    BaseMemRefType thenBufferType, elseBufferType;
    if (isa<BaseMemRefType>(thenValue.getType())) {
      // True branch was already bufferized.
      thenBufferType = cast<BaseMemRefType>(thenValue.getType());
    } else {
      auto maybeBufferType =
          bufferization::getBufferType(thenValue, options, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      thenBufferType = *maybeBufferType;
    }
    if (isa<BaseMemRefType>(elseValue.getType())) {
      // False branch was already bufferized.
      elseBufferType = cast<BaseMemRefType>(elseValue.getType());
    } else {
      auto maybeBufferType =
          bufferization::getBufferType(elseValue, options, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      elseBufferType = *maybeBufferType;
    }

    // Best case: Both branches have the exact same buffer type.
    if (thenBufferType == elseBufferType)
      return thenBufferType;

    // Memory space mismatch.
    if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
      return op->emitError("inconsistent memory space on then/else branches");

    // Layout maps are different: Promote to fully dynamic layout map.
    return getMemRefTypeWithFullyDynamicLayout(
        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
  }
};

/// Bufferization of scf.index_switch. Replace with a new scf.index_switch that
/// yields memrefs.
struct IndexSwitchOpInterface
    : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
                                                    scf::IndexSwitchOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IndexSwitchOps do not have tensor OpOperands. The yielded value can be
    // any SSA. This is similar to IfOps.
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    int64_t resultNum = cast<OpResult>(value).getResultNumber();
    AliasingOpOperandList result;
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldOp =
          cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
      result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum),
                                        BufferRelation::Equivalent,
                                        /*isDefinite=*/false));
    }
    auto defaultYieldOp =
        cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());
    result.addAlias(AliasingOpOperand(&defaultYieldOp->getOpOperand(resultNum),
                                      BufferRelation::Equivalent,
                                      /*isDefinite=*/false));
    return result;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto switchOp = cast<scf::IndexSwitchOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : switchOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(switchOp);
    auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>(
        switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(),
        switchOp.getCases().size());

    // Move over blocks.
    for (auto [src, dest] :
         llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
      rewriter.inlineRegionBefore(src, dest, dest.begin());
    rewriter.inlineRegionBefore(switchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion().begin());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    assert(value.getDefiningOp() == op && "invalid value");
    int64_t resultNum = cast<OpResult>(value).getResultNumber();

    // Helper function to get buffer type of a case.
    SmallVector<BaseMemRefType> yieldedTypes;
    auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
      auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
      Value yieldedValue = yieldOp->getOperand(resultNum);
      if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
        return bufferType;
      auto maybeBufferType =
          bufferization::getBufferType(yieldedValue, options, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      return maybeBufferType;
    };

    // Compute buffer type of the default case.
    auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
    if (failed(maybeBufferType))
      return failure();
    BaseMemRefType bufferType = *maybeBufferType;

    // Compute buffer types of all other cases.
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
      if (failed(yieldedBufferType))
        return failure();

      // Best case: Both branches have the exact same buffer type.
      if (bufferType == *yieldedBufferType)
        continue;

      // Memory space mismatch.
      if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
        return op->emitError("inconsistent memory space on switch cases");

      // Layout maps are different: Promote to fully dynamic layout map.
      bufferType = getMemRefTypeWithFullyDynamicLayout(
          cast<TensorType>(value.getType()), bufferType.getMemorySpace());
    }

    return bufferType;
  }
};

/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (isa<TensorType>(it.value().getType()))
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!isa<TensorType>(bbArgs[i].getType()) ||
        !isa<TensorType>(yieldedValues[i].getType()))
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}

/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
           const BufferizationOptions &options) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (isa<TensorType>(opOperand.get().getType())) {
      FailureOr<Value> resultBuffer =
          getBuffer(rewriter, opOperand.get(), options);
      if (failed(resultBuffer))
        return failure();
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}

/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
static SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(bbArgs)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      result.push_back(
          rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val)
              .getResult());
    } else {
      result.push_back(val);
    }
  }
  return result;
}

/// Compute the bufferized type of a loop iter_arg. This type must be equal to
/// the bufferized type of the corresponding init_arg and the bufferized type
/// of the corresponding yielded value.
///
/// This function uses bufferization::getBufferType to compute the bufferized
/// type of the init_arg and of the yielded value. (The computation of the
/// bufferized yielded value type usually requires computing the bufferized
/// type of the iter_arg again; the implementation of getBufferType traces back
/// the use-def chain of the given value and computes a buffer type along the
/// way.) If both buffer types are equal, no casts are needed and the computed
/// buffer type can be used directly. Otherwise, the buffer types can only
/// differ in their layout map and a cast must be inserted.
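/// E.g., if the init_arg bufferizes to memref<5xf32> but the yielded value
/// bufferizes to memref<5xf32, strided<[?], offset: ?>>, the iter_arg is
/// promoted to a memref type with a fully dynamic layout map.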
static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
    Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
    const BufferizationOptions &options, SmallVector<Value> &invocationStack) {
  // Determine the buffer type of the init_arg.
  auto initArgBufferType =
      bufferization::getBufferType(initArg, options, invocationStack);
  if (failed(initArgBufferType))
    return failure();

  if (llvm::count(invocationStack, iterArg) >= 2) {
    // If the iter_arg is already twice on the invocation stack, just take the
    // type of the init_arg. This is to avoid infinite loops when calculating
    // the buffer type. This will most likely result in computing a memref type
    // with a fully dynamic layout map.

    // Note: For more precise layout map computation, a fixpoint iteration could
    // be done (i.e., re-computing the yielded buffer type until the bufferized
    // iter_arg type no longer changes). This current implementation immediately
    // switches to a fully dynamic layout map when a mismatch between bufferized
    // init_arg type and bufferized yield value type is detected.
    return *initArgBufferType;
  }

  // Compute the buffer type of the yielded value.
  BaseMemRefType yieldedValueBufferType;
  if (isa<BaseMemRefType>(yieldedValue.getType())) {
    // scf.yield was already bufferized.
    yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
  } else {
    // Note: This typically triggers a recursive call for the buffer type of
    // the iter_arg.
    auto maybeBufferType =
        bufferization::getBufferType(yieldedValue, options, invocationStack);
    if (failed(maybeBufferType))
      return failure();
    yieldedValueBufferType = *maybeBufferType;
  }

  // If yielded type and init_arg type are the same, use that type directly.
  if (*initArgBufferType == yieldedValueBufferType)
    return yieldedValueBufferType;

  // If there is a mismatch between the yielded buffer type and the init_arg
  // buffer type, the buffer type must be promoted to a fully dynamic layout
  // map.
  auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType);
  auto iterTensorType = cast<TensorType>(iterArg.getType());
  auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType);
  if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
    return loopOp->emitOpError(
        "init_arg and yielded value bufferize to inconsistent memory spaces");
#ifndef NDEBUG
  if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
    assert(
        llvm::all_equal({yieldedRankedBufferType.getShape(),
                         cast<MemRefType>(initBufferType).getShape(),
                         cast<RankedTensorType>(iterTensorType).getShape()}) &&
        "expected same shape");
  }
#endif // NDEBUG
  return getMemRefTypeWithFullyDynamicLayout(
      iterTensorType, yieldedBufferType.getMemorySpace());
}

/// Return `true` if the given loop may have 0 iterations.
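/// The trip count is unknown if the lower or upper bound is not a constant;
/// in that case the loop is conservatively assumed to possibly run zero times.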
bool mayHaveZeroIterations(scf::ForOp forOp) {
  std::optional<int64_t> lb = getConstantIntValue(forOp.getLowerBound());
  std::optional<int64_t> ub = getConstantIntValue(forOp.getUpperBound());
  if (!lb.has_value() || !ub.has_value())
    return true;
  return *ub <= *lb;
}

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);

    // If the loop has zero iterations, the results of the op are their
    // corresponding init_args, meaning that the init_args bufferize to a read.
    if (mayHaveZeroIterations(forOp))
      return true;

    // scf::ForOp alone doesn't bufferize to a memory read; one of the uses of
    // its matching bbArg may.
    return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    OpResult opResult = forOp.getTiedLoopResult(&opOperand);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
    return equivalentYield ? BufferRelation::Equivalent
                           : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg (or with a
    // newly allocated buffer) and not with other buffers defined outside of the
    // loop. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand.
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, does it alias with something defined outside of
    // the loop?
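    // If so, a new tensor allocation is yielded instead, so that the loop
    // result does not alias any buffer defined outside of the loop.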
    SmallVector<Value> yieldValues;
    for (const auto it : llvm::enumerate(yieldOp.getResults())) {
      // Note: `state` is guaranteed to be a `OneShotAnalysisState`, but this
      // type cannot be used in the signature of `resolveConflicts` because the
      // op interface is in the "IR" build unit and the `OneShotAnalysisState`
      // is defined in the "Transforms" build unit.
      if (!indices.contains(it.index()) ||
          doesNotAliasExternalValue(
              it.value(), &forOp.getRegion(),
              /*exceptions=*/forOp.getRegionIterArg(it.index()),
              static_cast<const OneShotAnalysisState &>(state))) {
        yieldValues.push_back(it.value());
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, yieldOp.getLoc(), it.value(), state.getOptions());
      if (failed(alloc))
        return failure();
      yieldValues.push_back(*alloc);
    }

    rewriter.modifyOpInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto forOp = cast<scf::ForOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    if (auto opResult = dyn_cast<OpResult>(value)) {
      // The type of an OpResult must match the corresponding iter_arg type.
      BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
      return bufferization::getBufferType(bbArg, options, invocationStack);
    }

    // Compute result/argument number.
    BlockArgument bbArg = cast<BlockArgument>(value);
    unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();

    // Compute the bufferized type.
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    Value yieldedValue = yieldOp.getOperand(resultNum);
    BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
    Value initArg = forOp.getInitArgs()[resultNum];
    return computeLoopRegionIterArgBufferType(
        op, iterArg, initArg, yieldedValue, options, invocationStack);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = forOp.getBody();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, forOp.getInitArgsMutable(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value result = forOp->getResult(it.index());
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(result.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(result, options);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), castedInitArgs);
    newForOp->setAttrs(forOp->getAttrs());
    Block *loopBody = newForOp.getBody();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!isa<TensorType>(opResult.getType()))
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match. They may not even have the
    // same type. The number of OpResults and OpOperands can also differ.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    OpResult opResult = whileOp->getResult(idx);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::Unknown;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::Unknown;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
    // (New buffer copies do not alias with any buffer.)
    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);

    // Update "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!isa<TensorType>(value.getType()) ||
          (equivalentYieldsAfter.contains(idx) &&
           equivalentYieldsBefore.contains(idx))) {
        beforeYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, conditionOp.getLoc(), value, state.getOptions());
      if (failed(alloc))
        return failure();
      beforeYieldValues.push_back(*alloc);
    }
    rewriter.modifyOpInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto whileOp = cast<scf::WhileOp>(op);

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp.getInitsMutable(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value beforeArg = whileOp.getBeforeArguments()[it.index()];
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(beforeArg.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(beforeArg, options);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          if (!isa<TensorType>(bbArg.getType()))
            return bbArg.getType();
          // TODO: error handling
          return llvm::cast<Type>(
              *bufferization::getBufferType(bbArg, options));
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(castedInitArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(
        whileOp.getLoc(), argsTypesAfter, castedInitArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
                                          whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
        rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs = getBbArgReplacements(
        rewriter, newWhileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto whileOp = cast<scf::WhileOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    // Case 1: Block argument of the "before" region.
    if (auto bbArg = dyn_cast<BlockArgument>(value)) {
      if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
        Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
        auto yieldOp = whileOp.getYieldOp();
        Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
        return computeLoopRegionIterArgBufferType(
            op, bbArg, initArg, yieldedValue, options, invocationStack);
      }
    }

    // Case 2: OpResult of the loop or block argument of the "after" region.
    // The bufferized "after" bbArg type can be directly computed from the
    // bufferized "before" bbArg type.
    unsigned resultNum;
    if (auto opResult = dyn_cast<OpResult>(value)) {
      resultNum = opResult.getResultNumber();
    } else if (cast<BlockArgument>(value).getOwner()->getParent() ==
               &whileOp.getAfter()) {
      resultNum = cast<BlockArgument>(value).getArgNumber();
    } else {
      llvm_unreachable("invalid value");
    }
    Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
    if (!isa<TensorType>(conditionYieldedVal.getType())) {
      // scf.condition was already bufferized.
      return cast<BaseMemRefType>(conditionYieldedVal.getType());
    }
    return bufferization::getBufferType(conditionYieldedVal, options,
                                        invocationStack);
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Block *block = conditionOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Block *block = yieldOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    if (auto ifOp = dyn_cast<scf::IfOp>(op->getParentOp())) {
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent, /*isDefinite=*/false}};
    }
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent}};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
             scf::WhileOp>(yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");

    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
        if (failed(maybeBuffer))
          return failure();
        Value buffer = *maybeBuffer;
        // We may have to cast the value before yielding it.
        if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
                yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
              yieldOp->getParentOp()->getResult(it.index()), options);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        } else if (auto whileOp =
                       dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
              whileOp.getBeforeArguments()[it.index()], options);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        }
        newResults.push_back(buffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
    return success();
  }
};

/// Return `true` if the given loop may have 0 iterations.
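/// An scf.forall loop may have 0 iterations if any of its dimensions may have
/// a zero (or unknown) trip count.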
bool mayHaveZeroIterations(scf::ForallOp forallOp) {
  for (auto [lb, ub] : llvm::zip(forallOp.getMixedLowerBound(),
                                 forallOp.getMixedUpperBound())) {
    std::optional<int64_t> lbConst = getConstantIntValue(lb);
    std::optional<int64_t> ubConst = getConstantIntValue(ub);
    if (!lbConst.has_value() || !ubConst.has_value() || *lbConst >= *ubConst)
      return true;
  }
  return false;
}

/// Bufferization of ForallOp. This also bufferizes the terminator of the
/// region. There are op interfaces for the terminators (InParallelOp
/// and ParallelInsertSliceOp), but these are only used during analysis. Not
/// for bufferization.
struct ForallOpInterface
    : public BufferizableOpInterface::ExternalModel<ForallOpInterface,
                                                    ForallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forallOp = cast<ForallOp>(op);

    // If the loop has zero iterations, the results of the op are their
    // corresponding shared_outs, meaning that the shared_outs bufferize to a
    // read.
    if (mayHaveZeroIterations(forallOp))
      return true;

    // scf::ForallOp alone doesn't bufferize to a memory read; one of the
    // uses of its matching bbArg may.
    return state.isValueRead(forallOp.getTiedBlockArgument(&opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Outputs of scf::ForallOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forallOp = cast<ForallOp>(op);
    return {
        {{forallOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}}};
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard guard(rewriter);
    auto forallOp = cast<ForallOp>(op);
    int64_t rank = forallOp.getRank();

    // Get buffers for all output operands.
    SmallVector<Value> buffers;
    for (Value out : forallOp.getOutputs()) {
      FailureOr<Value> buffer = getBuffer(rewriter, out, options);
      if (failed(buffer))
        return failure();
      buffers.push_back(*buffer);
    }

    // Use buffers instead of block arguments.
    rewriter.setInsertionPointToStart(forallOp.getBody());
    for (const auto &it : llvm::zip(
             forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
      BlockArgument bbArg = std::get<0>(it);
      Value buffer = std::get<1>(it);
      Value bufferAsTensor =
          rewriter.create<ToTensorOp>(forallOp.getLoc(), buffer);
      bbArg.replaceAllUsesWith(bufferAsTensor);
    }

    // Create new ForallOp without any results and drop the automatically
    // introduced terminator.
    rewriter.setInsertionPoint(forallOp);
    ForallOp newForallOp;
    newForallOp = rewriter.create<ForallOp>(
        forallOp.getLoc(), forallOp.getMixedLowerBound(),
        forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
        /*outputs=*/ValueRange(), forallOp.getMapping());

    rewriter.eraseOp(newForallOp.getBody()->getTerminator());

    // Move over block contents of the old op.
    SmallVector<Value> replacementBbArgs;
    replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(),
                             newForallOp.getBody()->getArguments().end());
    replacementBbArgs.append(forallOp.getOutputs().size(), Value());
    rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
                         replacementBbArgs);

    // Remove the old op and replace all of its uses.
    replaceOpWithBufferizedValues(rewriter, op, buffers);

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto forallOp = cast<ForallOp>(op);

    if (auto bbArg = dyn_cast<BlockArgument>(value))
      // A tensor block argument has the same bufferized type as the
      // corresponding output operand.
      return bufferization::getBufferType(
          forallOp.getTiedOpOperand(bbArg)->get(), options, invocationStack);

    // The bufferized result type is the same as the bufferized type of the
    // corresponding output operand.
    return bufferization::getBufferType(
        forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
        invocationStack);
  }

  bool isRepetitiveRegion(Operation *op, unsigned index) const {
    auto forallOp = cast<ForallOp>(op);

    // This op is repetitive if any of its dimensions may run for more than
    // one iteration. If the bounds or the step are dynamic, it is
    // conservatively considered repetitive as well.
    for (auto [lb, ub, step] :
         llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
                   forallOp.getMixedStep())) {
      std::optional<int64_t> lbConstant = getConstantIntValue(lb);
      if (!lbConstant)
        return true;

      std::optional<int64_t> ubConstant = getConstantIntValue(ub);
      if (!ubConstant)
        return true;

      std::optional<int64_t> stepConstant = getConstantIntValue(step);
      if (!stepConstant)
        return true;

      if (*lbConstant + *stepConstant < *ubConstant)
        return true;
    }
    return false;
  }

  bool isParallelRegion(Operation *op, unsigned index) const {
    return isRepetitiveRegion(op, index);
  }
};

/// Nothing to do for InParallelOp.
struct InParallelOpInterface
    : public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
                                                    InParallelOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options) const {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
    return failure();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx);
    ForallOp::attachInterface<ForallOpInterface>(*ctx);
    InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}
