//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
  assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  // TODO: In case `type` has a layout map that is not the fully dynamic
  // one, we may not be able to cast the buffer. In that case, the loop
  // iter_arg's layout map must be changed (see uses of `castBuffer`).
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}
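// Illustrative example (sketch, types chosen for exposition): a typical
// compatible cast produced here turns a buffer of type `memref<5xf32>` into
// `memref<5xf32, strided<[?], offset: ?>>`, i.e., the same shaped memref with
// a fully dynamic layout, so that it matches the bufferized iter_arg type.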

/// Helper function for loop bufferization. Return "true" if the given value
/// is guaranteed to not alias with an external tensor apart from values in
/// `exceptions`. A value is external if it is defined outside of the given
/// region or if it is an entry block argument of the region.
static bool doesNotAliasExternalValue(Value value, Region *region,
                                      ValueRange exceptions,
                                      const OneShotAnalysisState &state) {
  assert(llvm::hasSingleElement(region->getBlocks()) &&
         "expected region with single block");
  bool result = true;
  state.applyOnAliases(value, [&](Value alias) {
    if (llvm::is_contained(exceptions, alias))
      return;
    Region *aliasRegion = alias.getParentRegion();
    if (isa<BlockArgument>(alias) && !region->isProperAncestor(aliasRegion))
      result = false;
    if (isa<OpResult>(alias) && !region->isAncestor(aliasRegion))
      result = false;
  });
  return result;
}

/// Bufferization of scf.condition.
struct ConditionOpInterface
    : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
                                                    scf::ConditionOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Condition operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield allocations
    // when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto conditionOp = cast<scf::ConditionOp>(op);
    auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());

    SmallVector<Value> newArgs;
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer =
            getBuffer(rewriter, value, options, state);
        if (failed(maybeBuffer))
          return failure();
        FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
            whileOp.getAfterArguments()[it.index()], options, state);
        if (failed(resultType))
          return failure();
        Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
        newArgs.push_back(buffer);
      } else {
        newArgs.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::ConditionOp>(
        rewriter, op, conditionOp.getCondition(), newArgs);
    return success();
  }
};

/// Return the unique scf.yield op. If there are multiple or no scf.yield ops,
/// return an empty op.
static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
  scf::YieldOp result;
  for (Block &block : executeRegionOp.getRegion()) {
    if (auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) {
      if (result)
        return {};
      result = yieldOp;
    }
  }
  return result;
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
          ExecuteRegionOpInterface, scf::ExecuteRegionOp> {

  static bool supportsUnstructuredControlFlow() { return true; }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    // TODO: scf.execute_region with multiple yields are not supported.
    if (!getUniqueYieldOp(executeRegionOp))
      return op->emitOpError("op without unique scf.yield is not supported");
    return success();
  }

  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    if (auto bbArg = dyn_cast<BlockArgument>(value))
      return getAliasingBranchOpOperands(op, bbArg, state);

    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto it = llvm::find(op->getOpResults(), value);
    assert(it != op->getOpResults().end() && "invalid value");
    size_t resultNum = std::distance(op->getOpResults().begin(), it);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    // Note: If there is no unique scf.yield op, `verifyAnalysis` will fail.
    if (!yieldOp)
      return {};
    return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    TypeRange newResultTypes(yieldOp.getResults());

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Bufferize every block.
    for (Block &block : newOp.getRegion())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                        options, state)))
        return failure();

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (isa<TensorType>(it.value())) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), it.value(),
            newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }
};
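// Illustrative example (sketch, types chosen for exposition): an op such as
//   %0 = scf.execute_region -> tensor<?xf32> {
//     ...
//     scf.yield %t : tensor<?xf32>
//   }
// is recreated with the result types taken from its unique scf.yield, the
// region body is moved over, every block signature is bufferized, and each
// tensor result of the old op is replaced by a bufferization.to_tensor of the
// corresponding new result.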

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the then/else
    // branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), value));
    OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
    OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
    return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false},
            {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : ifOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options, state);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(ifOp);
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
    assert(value.getDefiningOp() == op && "invalid value");

    // Determine buffer types of the true/false branches.
    auto opResult = cast<OpResult>(value);
    auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
    auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
    BaseMemRefType thenBufferType, elseBufferType;
    if (isa<BaseMemRefType>(thenValue.getType())) {
      // True branch was already bufferized.
      thenBufferType = cast<BaseMemRefType>(thenValue.getType());
    } else {
      auto maybeBufferType = bufferization::getBufferType(
          thenValue, options, state, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      thenBufferType = *maybeBufferType;
    }
    if (isa<BaseMemRefType>(elseValue.getType())) {
      // False branch was already bufferized.
      elseBufferType = cast<BaseMemRefType>(elseValue.getType());
    } else {
      auto maybeBufferType = bufferization::getBufferType(
          elseValue, options, state, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      elseBufferType = *maybeBufferType;
    }

    // Best case: Both branches have the exact same buffer type.
    if (thenBufferType == elseBufferType)
      return thenBufferType;

    // Memory space mismatch.
    if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
      return op->emitError("inconsistent memory space on then/else branches");

    // Layout maps are different: Promote to fully dynamic layout map.
    return getMemRefTypeWithFullyDynamicLayout(
        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
  }
};
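// Illustrative example (sketch) of the layout reconciliation implemented in
// IfOpInterface::getBufferType above: if the "then" branch yields a buffer of
// type `memref<5xf32>` while the "else" branch yields
// `memref<5xf32, strided<[1], offset: 2>>`, the scf.if result is given the
// fully dynamic layout `memref<5xf32, strided<[?], offset: ?>>`, and the
// scf.yield bufferization casts both yielded buffers to that type.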

/// Bufferization of scf.index_switch. Replace with a new scf.index_switch that
/// yields memrefs.
struct IndexSwitchOpInterface
    : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
                                                    scf::IndexSwitchOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IndexSwitchOps do not have tensor OpOperands. The yielded value can be
    // any SSA. This is similar to IfOps.
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    int64_t resultNum = cast<OpResult>(value).getResultNumber();
    AliasingOpOperandList result;
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldOp =
          cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
      result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum),
                                        BufferRelation::Equivalent,
                                        /*isDefinite=*/false));
    }
    auto defaultYieldOp =
        cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());
    result.addAlias(AliasingOpOperand(&defaultYieldOp->getOpOperand(resultNum),
                                      BufferRelation::Equivalent,
                                      /*isDefinite=*/false));
    return result;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto switchOp = cast<scf::IndexSwitchOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : switchOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options, state);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(switchOp);
    auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>(
        switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(),
        switchOp.getCases().size());

    // Move over blocks.
    for (auto [src, dest] :
         llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
      rewriter.inlineRegionBefore(src, dest, dest.begin());
    rewriter.inlineRegionBefore(switchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion().begin());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    assert(value.getDefiningOp() == op && "invalid value");
    int64_t resultNum = cast<OpResult>(value).getResultNumber();

    // Helper function to get buffer type of a case.
    auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
      auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
      Value yieldedValue = yieldOp->getOperand(resultNum);
      if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
        return bufferType;
      auto maybeBufferType = bufferization::getBufferType(
          yieldedValue, options, state, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      return maybeBufferType;
    };

    // Compute buffer type of the default case.
    auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
    if (failed(maybeBufferType))
      return failure();
    BaseMemRefType bufferType = *maybeBufferType;

    // Compute buffer types of all other cases.
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
      if (failed(yieldedBufferType))
        return failure();

      // Best case: Both branches have the exact same buffer type.
      if (bufferType == *yieldedBufferType)
        continue;

      // Memory space mismatch.
      if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
        return op->emitError("inconsistent memory space on switch cases");

      // Layout maps are different: Promote to fully dynamic layout map.
      bufferType = getMemRefTypeWithFullyDynamicLayout(
          cast<TensorType>(value.getType()), bufferType.getMemorySpace());
    }

    return bufferType;
  }
};

/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (isa<TensorType>(it.value().getType()))
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!isa<TensorType>(bbArgs[i].getType()) ||
        !isa<TensorType>(yieldedValues[i].getType()))
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}

/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
           const BufferizationOptions &options, BufferizationState &state) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (isa<TensorType>(opOperand.get().getType())) {
      FailureOr<Value> resultBuffer =
          getBuffer(rewriter, opOperand.get(), options, state);
      if (failed(resultBuffer))
        return failure();
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}

/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
static SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     Block::BlockArgListType oldBbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(bbArgs)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      result.push_back(rewriter
                           .create<bufferization::ToTensorOp>(
                               val.getLoc(), oldBbArgs[idx].getType(), val)
                           .getResult());
    } else {
      result.push_back(val);
    }
  }
  return result;
}

/// Compute the bufferized type of a loop iter_arg. This type must be equal to
/// the bufferized type of the corresponding init_arg and the bufferized type
/// of the corresponding yielded value.
///
/// This function uses bufferization::getBufferType to compute the bufferized
/// type of the init_arg and of the yielded value. (The computation of the
/// bufferized yielded value type usually requires computing the bufferized
/// type of the iter_arg again; the implementation of getBufferType traces back
/// the use-def chain of the given value and computes a buffer type along the
/// way.) If both buffer types are equal, no casts are needed and the computed
/// buffer type can be used directly. Otherwise, the buffer types can only
/// differ in their layout map and a cast must be inserted.
static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
    Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
    const BufferizationOptions &options, const BufferizationState &state,
    SmallVector<Value> &invocationStack) {
  // Determine the buffer type of the init_arg.
  auto initArgBufferType =
      bufferization::getBufferType(initArg, options, state, invocationStack);
  if (failed(initArgBufferType))
    return failure();

  if (llvm::count(invocationStack, iterArg) >= 2) {
    // If the iter_arg is already twice on the invocation stack, just take the
    // type of the init_arg. This is to avoid infinite loops when calculating
    // the buffer type. This will most likely result in computing a memref type
    // with a fully dynamic layout map.

    // Note: For more precise layout map computation, a fixpoint iteration could
    // be done (i.e., re-computing the yielded buffer type until the bufferized
    // iter_arg type no longer changes). This current implementation immediately
    // switches to a fully dynamic layout map when a mismatch between bufferized
    // init_arg type and bufferized yield value type is detected.
    return *initArgBufferType;
  }

  // Compute the buffer type of the yielded value.
  BaseMemRefType yieldedValueBufferType;
  if (isa<BaseMemRefType>(yieldedValue.getType())) {
    // scf.yield was already bufferized.
    yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
  } else {
    // Note: This typically triggers a recursive call for the buffer type of
    // the iter_arg.
    auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
                                                        state, invocationStack);
    if (failed(maybeBufferType))
      return failure();
    yieldedValueBufferType = *maybeBufferType;
  }

  // If yielded type and init_arg type are the same, use that type directly.
  if (*initArgBufferType == yieldedValueBufferType)
    return yieldedValueBufferType;

  // If there is a mismatch between the yielded buffer type and the init_arg
  // buffer type, the buffer type must be promoted to a fully dynamic layout
  // map.
  auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType);
  auto iterTensorType = cast<TensorType>(iterArg.getType());
  auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType);
  if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
    return loopOp->emitOpError(
        "init_arg and yielded value bufferize to inconsistent memory spaces");
#ifndef NDEBUG
  if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
    assert(
        llvm::all_equal({yieldedRankedBufferType.getShape(),
                         cast<MemRefType>(initBufferType).getShape(),
                         cast<RankedTensorType>(iterTensorType).getShape()}) &&
        "expected same shape");
  }
#endif // NDEBUG
  return getMemRefTypeWithFullyDynamicLayout(
      iterTensorType, yieldedBufferType.getMemorySpace());
}
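// Illustrative example (sketch): if an init_arg bufferizes to `memref<5xf32>`
// (identity layout) but the corresponding yielded value bufferizes to
// `memref<5xf32, strided<[?], offset: ?>>` (e.g., because the loop body yields
// a subview), the iter_arg is assigned the fully dynamic layout type and the
// init_arg buffer is cast to it before entering the loop (see `castBuffer`).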

/// Return `true` if the given loop may have 0 iterations.
bool mayHaveZeroIterations(scf::ForOp forOp) {
  std::optional<int64_t> lb = getConstantIntValue(forOp.getLowerBound());
  std::optional<int64_t> ub = getConstantIntValue(forOp.getUpperBound());
  if (!lb.has_value() || !ub.has_value())
    return true;
  return *ub <= *lb;
}

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);

    // If the loop has zero iterations, the results of the op are their
    // corresponding init_args, meaning that the init_args bufferize to a read.
    if (mayHaveZeroIterations(forOp))
      return true;

    // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
    // its matching bbArg may.
    return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    OpResult opResult = forOp.getTiedLoopResult(&opOperand);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
    return equivalentYield ? BufferRelation::Equivalent
                           : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult
  resolveConflicts(Operation *op, RewriterBase &rewriter,
                   const AnalysisState &analysisState,
                   const BufferizationState &bufferizationState) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
            rewriter, analysisState, bufferizationState)))
      return failure();

    if (analysisState.getOptions().copyBeforeWrite)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg (or with a
    // newly allocated buffer) and not with other buffers defined outside of the
    // loop. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand.
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, does it alias with something defined outside of
    // the loop?
    SmallVector<Value> yieldValues;
    for (const auto it : llvm::enumerate(yieldOp.getResults())) {
      // Note: `state` is guaranteed to be a `OneShotAnalysisState`, but this
      // type cannot be used in the signature of `resolveConflicts` because the
      // op interface is in the "IR" build unit and the `OneShotAnalysisState`
      // is defined in the "Transforms" build unit.
      if (!indices.contains(it.index()) ||
          doesNotAliasExternalValue(
              it.value(), &forOp.getRegion(),
              /*exceptions=*/forOp.getRegionIterArg(it.index()),
              static_cast<const OneShotAnalysisState &>(analysisState))) {
        yieldValues.push_back(it.value());
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, yieldOp.getLoc(), it.value(), analysisState.getOptions(),
          bufferizationState);
      if (failed(alloc))
        return failure();
      yieldValues.push_back(*alloc);
    }

    rewriter.modifyOpInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto forOp = cast<scf::ForOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    if (auto opResult = dyn_cast<OpResult>(value)) {
      // The type of an OpResult must match the corresponding iter_arg type.
      BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
      return bufferization::getBufferType(bbArg, options, state,
                                          invocationStack);
    }

    // Compute result/argument number.
    BlockArgument bbArg = cast<BlockArgument>(value);
    unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();

    // Compute the bufferized type.
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    Value yieldedValue = yieldOp.getOperand(resultNum);
    BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
    Value initArg = forOp.getInitArgs()[resultNum];
    return computeLoopRegionIterArgBufferType(
        op, iterArg, initArg, yieldedValue, options, state, invocationStack);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = forOp.getBody();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, forOp.getInitArgsMutable(), options, state);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value result = forOp->getResult(it.index());
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(result.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(result, options, state);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), castedInitArgs);
    newForOp->setAttrs(forOp->getAttrs());
    Block *loopBody = newForOp.getBody();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(),
                             forOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs + copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!isa<TensorType>(opResult.getType()))
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};
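// Illustrative example (sketch) of the scf.for rewrite implemented above: a
// tensor loop such as
//   %0 = scf.for %iv = %lb to %ub step %c1 iter_args(%t = %A) -> (tensor<?xf32>) {
//     ...
//     scf.yield %t2 : tensor<?xf32>
//   }
// is recreated with memref iter_args (the bufferized init_args, cast to the
// iter_arg buffer type if needed); the old body is moved over and its tensor
// bbArg uses are reconnected through bufferization.to_tensor ops inserted at
// the top of the new body.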

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match. They may not even have the
    // same type. The number of OpResults and OpOperands can also differ.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    OpResult opResult = whileOp->getResult(idx);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::Unknown;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::Unknown;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult
  resolveConflicts(Operation *op, RewriterBase &rewriter,
                   const AnalysisState &analysisState,
                   const BufferizationState &bufferizationState) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
            rewriter, analysisState, bufferizationState)))
      return failure();

    if (analysisState.getOptions().copyBeforeWrite)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
    // (New buffer copies do not alias with any buffer.)
    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), analysisState);
    DenseSet<int64_t> equivalentYieldsAfter =
        getEquivalentBuffers(whileOp.getAfterArguments(),
                             whileOp.getYieldOp().getResults(), analysisState);

    // Update "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!isa<TensorType>(value.getType()) ||
          (equivalentYieldsAfter.contains(idx) &&
           equivalentYieldsBefore.contains(idx))) {
        beforeYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, conditionOp.getLoc(), value, analysisState.getOptions(),
          bufferizationState);
      if (failed(alloc))
        return failure();
      beforeYieldValues.push_back(*alloc);
    }
    rewriter.modifyOpInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp.getInitsMutable(), options, state);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value beforeArg = whileOp.getBeforeArguments()[it.index()];
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(beforeArg.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(beforeArg, options, state);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          if (!isa<TensorType>(bbArg.getType()))
            return bbArg.getType();
          // TODO: error handling
          return llvm::cast<Type>(
              *bufferization::getBufferType(bbArg, options, state));
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(castedInitArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(
        whileOp.getLoc(), argsTypesAfter, castedInitArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
                                          whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs =
        getBbArgReplacements(rewriter, newWhileOp.getBeforeArguments(),
                             whileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs =
        getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(),
                             whileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto whileOp = cast<scf::WhileOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    // Case 1: Block argument of the "before" region.
    if (auto bbArg = dyn_cast<BlockArgument>(value)) {
      if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
        Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
        auto yieldOp = whileOp.getYieldOp();
        Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
        return computeLoopRegionIterArgBufferType(
            op, bbArg, initArg, yieldedValue, options, state, invocationStack);
      }
    }

    // Case 2: OpResult of the loop or block argument of the "after" region.
    // The bufferized "after" bbArg type can be directly computed from the
    // bufferized "before" bbArg type.
    unsigned resultNum;
    if (auto opResult = dyn_cast<OpResult>(value)) {
      resultNum = opResult.getResultNumber();
    } else if (cast<BlockArgument>(value).getOwner()->getParent() ==
               &whileOp.getAfter()) {
      resultNum = cast<BlockArgument>(value).getArgNumber();
    } else {
      llvm_unreachable("invalid value");
    }
    Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
    if (!isa<TensorType>(conditionYieldedVal.getType())) {
      // scf.condition was already bufferized.
      return cast<BaseMemRefType>(conditionYieldedVal.getType());
    }
    return bufferization::getBufferType(conditionYieldedVal, options, state,
                                        invocationStack);
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs + copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Block *block = conditionOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Block *block = yieldOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};
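// Illustrative note (sketch): unlike scf.for, the "before" and "after" regions
// of scf.while may carry different argument lists, so the rewrite above
// computes two separate index sets (indicesBefore/indicesAfter) and takes the
// new result types from the bufferized "after" bbArgs rather than from the
// init_args. For example, a while loop whose scf.condition forwards only a
// subset of the "before" bbArgs produces results typed after that subset.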

/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    if (auto ifOp = dyn_cast<scf::IfOp>(op->getParentOp())) {
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent, /*isDefinite=*/false}};
    }
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent}};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield allocations
    // when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
             scf::WhileOp>(yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");

    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer =
            getBuffer(rewriter, value, options, state);
        if (failed(maybeBuffer))
          return failure();
        Value buffer = *maybeBuffer;
        // We may have to cast the value before yielding it.
        if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
                yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
              yieldOp->getParentOp()->getResult(it.index()), options, state);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        } else if (auto whileOp =
                       dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
              whileOp.getBeforeArguments()[it.index()], options, state);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        }
        newResults.push_back(buffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
    return success();
  }
};

/// Bufferization of ForallOp. This also bufferizes the terminator of the
/// region. There are op interfaces for the terminators (InParallelOp
/// and ParallelInsertSliceOp), but these are only used during analysis. Not
/// for bufferization.
struct ForallOpInterface
    : public BufferizableOpInterface::ExternalModel<ForallOpInterface,
                                                    ForallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // All tensor operands to `scf.forall` are `shared_outs` and all
    // shared outs are assumed to be read by the loop. This does not
    // account for the case where the entire value is over-written,
    // but being conservative here.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Outputs of scf::ForallOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forallOp = cast<ForallOp>(op);
    return {
        {{forallOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}}};
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard guard(rewriter);
    auto forallOp = cast<ForallOp>(op);
    int64_t rank = forallOp.getRank();

    // Get buffers for all output operands.
    SmallVector<Value> buffers;
    for (Value out : forallOp.getOutputs()) {
      FailureOr<Value> buffer = getBuffer(rewriter, out, options, state);
      if (failed(buffer))
        return failure();
      buffers.push_back(*buffer);
    }

    // Use buffers instead of block arguments.
    rewriter.setInsertionPointToStart(forallOp.getBody());
    for (const auto &it : llvm::zip(
             forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
      BlockArgument bbArg = std::get<0>(it);
      Value buffer = std::get<1>(it);
      Value bufferAsTensor = rewriter.create<ToTensorOp>(
          forallOp.getLoc(), bbArg.getType(), buffer);
      bbArg.replaceAllUsesWith(bufferAsTensor);
    }

    // Create new ForallOp without any results and drop the automatically
    // introduced terminator.
    rewriter.setInsertionPoint(forallOp);
    ForallOp newForallOp;
    newForallOp = rewriter.create<ForallOp>(
        forallOp.getLoc(), forallOp.getMixedLowerBound(),
        forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
        /*outputs=*/ValueRange(), forallOp.getMapping());

    // Keep discardable attributes from the original op.
    newForallOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());

    rewriter.eraseOp(newForallOp.getBody()->getTerminator());

    // Move over block contents of the old op.
    SmallVector<Value> replacementBbArgs;
    replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(),
                             newForallOp.getBody()->getArguments().end());
    replacementBbArgs.append(forallOp.getOutputs().size(), Value());
    rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
                         replacementBbArgs);

    // Remove the old op and replace all of its uses.
    replaceOpWithBufferizedValues(rewriter, op, buffers);

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto forallOp = cast<ForallOp>(op);

    if (auto bbArg = dyn_cast<BlockArgument>(value))
      // A tensor block argument has the same bufferized type as the
      // corresponding output operand.
      return bufferization::getBufferType(
          forallOp.getTiedOpOperand(bbArg)->get(), options, state,
          invocationStack);

    // The bufferized result type is the same as the bufferized type of the
    // corresponding output operand.
    return bufferization::getBufferType(
        forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
        state, invocationStack);
  }

  bool isRepetitiveRegion(Operation *op, unsigned index) const {
    auto forallOp = cast<ForallOp>(op);

    // This op is repetitive if it has 1 or more steps.
    // If the control variables are dynamic, it is also considered so.
    for (auto [lb, ub, step] :
         llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
                   forallOp.getMixedStep())) {
      std::optional<int64_t> lbConstant = getConstantIntValue(lb);
      if (!lbConstant)
        return true;

      std::optional<int64_t> ubConstant = getConstantIntValue(ub);
      if (!ubConstant)
        return true;

      std::optional<int64_t> stepConstant = getConstantIntValue(step);
      if (!stepConstant)
        return true;

      if (*lbConstant + *stepConstant < *ubConstant)
        return true;
    }
    return false;
  }

  bool isParallelRegion(Operation *op, unsigned index) const {
    return isRepetitiveRegion(op, index);
  }
};
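// Illustrative example (sketch, types chosen for exposition): for
//   %0 = scf.forall (%i) in (%n) shared_outs(%o = %A) -> (tensor<?xf32>) {
//     ...
//   }
// the rewrite above queries the buffer of %A, rewires all uses of %o to a
// bufferization.to_tensor of that buffer, recreates the scf.forall without
// shared_outs (and therefore without results), and replaces the old results
// with the output buffers directly.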

/// Nothing to do for InParallelOp.
struct InParallelOpInterface
    : public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
                                                    InParallelOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
    return failure();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx);
    ForallOp::attachInterface<ForallOpInterface>(*ctx);
    InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}
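// Typical usage (sketch, not part of this file): clients populate a
// DialectRegistry with the SCF dialect and these external models before
// creating the MLIRContext, e.g.
//   DialectRegistry registry;
//   registry.insert<scf::SCFDialect>();
//   scf::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext ctx(registry);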