1//===- BufferizableOpInterface.cpp - Bufferizable Ops ---=----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
10#include "mlir/Dialect/Arith/Utils/Utils.h"
11#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12#include "mlir/Dialect/Func/IR/FuncOps.h"
13#include "mlir/Dialect/MemRef/IR/MemRef.h"
14#include "mlir/Dialect/Tensor/IR/Tensor.h"
15#include "mlir/IR/AsmState.h"
16#include "mlir/IR/Operation.h"
17#include "mlir/IR/TypeUtilities.h"
18#include "mlir/IR/Value.h"
19#include "mlir/Interfaces/ControlFlowInterfaces.h"
20#include "llvm/ADT/ScopeExit.h"
21#include "llvm/Support/Debug.h"
22
23//===----------------------------------------------------------------------===//
24// BufferizableOpInterface
25//===----------------------------------------------------------------------===//
26
27namespace mlir {
28namespace bufferization {
29
30#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"
31
32} // namespace bufferization
33} // namespace mlir
34
35MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::AnalysisState)
36
37#define DEBUG_TYPE "bufferizable-op-interface"
38#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
39#define LDBG(X) LLVM_DEBUG(DBGS() << (X))
40
41using namespace mlir;
42using namespace bufferization;
43
44static bool isRepetitiveRegion(Region *region,
45 const BufferizationOptions &options) {
46 Operation *op = region->getParentOp();
47 if (auto bufferizableOp = options.dynCastBufferizableOp(op))
48 if (bufferizableOp.isRepetitiveRegion(index: region->getRegionNumber()))
49 return true;
50 return false;
51}
52
53Region *AnalysisState::getEnclosingRepetitiveRegion(
54 Operation *op, const BufferizationOptions &options) {
55 if (!op->getBlock())
56 return nullptr;
57 if (auto iter = enclosingRepetitiveRegionCache.find_as(Val: op);
58 iter != enclosingRepetitiveRegionCache.end())
59 return iter->second;
60 return enclosingRepetitiveRegionCache[op] =
61 getEnclosingRepetitiveRegion(block: op->getBlock(), options);
62}
63
/// Return the closest enclosing repetitive region of `value` (or "nullptr" if
/// there is none). The result is memoized; every region visited during the
/// upward walk is also cached, since it maps to the same repetitive region.
Region *AnalysisState::getEnclosingRepetitiveRegion(
    Value value, const BufferizationOptions &options) {
  // Serve repeated queries from the cache.
  if (auto iter = enclosingRepetitiveRegionCache.find_as(Val: value);
      iter != enclosingRepetitiveRegionCache.end())
    return iter->second;

  Region *region = value.getParentRegion();
  // Collect all visited regions since we only know the repetitive region we
  // want to map it to later on.
  SmallVector<Region *> visitedRegions;
  while (region) {
    visitedRegions.push_back(Elt: region);
    if (isRepetitiveRegion(region, options))
      break;
    region = region->getParentRegion();
  }
  // Cache the result for the value and all regions on the path. Note: If no
  // repetitive region was found, `region` is "nullptr" here.
  enclosingRepetitiveRegionCache[value] = region;
  for (Region *r : visitedRegions)
    enclosingRepetitiveRegionCache[r] = region;
  return region;
}
85
86Region *AnalysisState::getEnclosingRepetitiveRegion(
87 Block *block, const BufferizationOptions &options) {
88 if (auto iter = enclosingRepetitiveRegionCache.find_as(Val: block);
89 iter != enclosingRepetitiveRegionCache.end())
90 return iter->second;
91
92 Region *region = block->getParent();
93 Operation *op = nullptr;
94 // Collect all visited regions since we only know the repetitive region we
95 // want to map it to later on
96 SmallVector<Region *> visitedRegions;
97 do {
98 op = region->getParentOp();
99 if (isRepetitiveRegion(region, options))
100 break;
101 } while ((region = op->getParentRegion()));
102
103 enclosingRepetitiveRegionCache[block] = region;
104 for (Region *r : visitedRegions)
105 enclosingRepetitiveRegionCache[r] = region;
106 return region;
107}
108
109bool AnalysisState::insideMutuallyExclusiveRegions(Operation *op0,
110 Operation *op1) {
111 auto key = std::make_pair(x&: op0, y&: op1);
112 if (auto iter = insideMutuallyExclusiveRegionsCache.find(Val: key);
113 iter != insideMutuallyExclusiveRegionsCache.end())
114 return iter->second;
115 bool result = ::mlir::insideMutuallyExclusiveRegions(a: op0, b: op1);
116 // Populate results for both orderings of the ops.
117 insideMutuallyExclusiveRegionsCache[key] = result;
118 insideMutuallyExclusiveRegionsCache[std::make_pair(x&: op1, y&: op0)] = result;
119 return result;
120}
121
/// Drop all memoized query results. Must be called whenever the IR is
/// modified in a way that could invalidate cached region or mutual-exclusion
/// information.
void AnalysisState::resetCache() {
  enclosingRepetitiveRegionCache.clear();
  insideMutuallyExclusiveRegionsCache.clear();
}
126
/// Return the symbol table collection that is maintained for the duration of
/// bufferization (avoids re-building symbol tables for repeated lookups).
SymbolTableCollection &BufferizationState::getSymbolTables() {
  return symbolTables;
}
130
131Region *bufferization::getNextEnclosingRepetitiveRegion(
132 Region *region, const BufferizationOptions &options) {
133 assert(isRepetitiveRegion(region, options) && "expected repetitive region");
134 while ((region = region->getParentRegion())) {
135 if (isRepetitiveRegion(region, options))
136 break;
137 }
138 return region;
139}
140
141Region *bufferization::getParallelRegion(Region *region,
142 const BufferizationOptions &options) {
143 while (region) {
144 auto bufferizableOp = options.dynCastBufferizableOp(op: region->getParentOp());
145 if (bufferizableOp &&
146 bufferizableOp.isParallelRegion(index: region->getRegionNumber())) {
147 assert(isRepetitiveRegion(region, options) &&
148 "expected that all parallel regions are also repetitive regions");
149 return region;
150 }
151 region = region->getParentRegion();
152 }
153 return nullptr;
154}
155
156Operation *bufferization::getOwnerOfValue(Value value) {
157 if (auto opResult = llvm::dyn_cast<OpResult>(Val&: value))
158 return opResult.getDefiningOp();
159 return llvm::cast<BlockArgument>(Val&: value).getOwner()->getParentOp();
160}
161
/// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
/// shaped value is copied. Otherwise, a tensor with undefined contents is
/// allocated.
///
/// `shapedValue` must be a ranked tensor or a (ranked) memref; a memref is
/// first wrapped in a ToTensorOp. Unranked tensors/memrefs are rejected with
/// an error. If `copy` is not set, the allocation's memory space is taken
/// from the buffer type of the tensor, falling back to
/// `options.defaultMemorySpaceFn`.
FailureOr<Value> bufferization::allocateTensorForShapedValue(
    OpBuilder &b, Location loc, Value shapedValue,
    const BufferizationOptions &options, const BufferizationState &state,
    bool copy) {
  // Normalize the input to a ranked tensor value.
  Value tensor;
  if (llvm::isa<RankedTensorType>(Val: shapedValue.getType())) {
    tensor = shapedValue;
  } else if (llvm::isa<MemRefType>(Val: shapedValue.getType())) {
    // Memrefs are bridged back into tensor land via to_tensor.
    tensor = b.create<ToTensorOp>(
        location: loc, args: memref::getTensorTypeFromMemRefType(type: shapedValue.getType()),
        args&: shapedValue);
  } else if (llvm::isa<UnrankedTensorType>(Val: shapedValue.getType()) ||
             llvm::isa<UnrankedMemRefType>(Val: shapedValue.getType())) {
    return getOwnerOfValue(value: shapedValue)
        ->emitError(message: "copying of unranked tensors is not implemented");
  } else {
    llvm_unreachable("expected RankedTensorType or MemRefType");
  }
  RankedTensorType tensorType = llvm::cast<RankedTensorType>(Val: tensor.getType());
  SmallVector<Value> dynamicSizes;
  if (!copy) {
    // Compute the dynamic part of the shape.
    // First try to query the shape via ReifyRankedShapedTypeOpInterface.
    bool reifiedShapes = false;
    if (llvm::isa<RankedTensorType>(Val: shapedValue.getType()) &&
        llvm::isa<OpResult>(Val: shapedValue)) {
      ReifiedRankedShapedTypeDims resultDims;
      if (succeeded(
              Result: reifyResultShapes(b, op: shapedValue.getDefiningOp(), reifiedReturnShapes&: resultDims))) {
        reifiedShapes = true;
        // Pick the reified shape of the result that `shapedValue` refers to.
        auto &shape =
            resultDims[llvm::cast<OpResult>(Val&: shapedValue).getResultNumber()];
        // Only dynamic dimensions need explicit size operands.
        for (const auto &dim : enumerate(First: tensorType.getShape())) {
          if (ShapedType::isDynamic(dValue: dim.value())) {
            dynamicSizes.push_back(
                Elt: getValueOrCreateConstantIndexOp(b, loc, ofr: shape[dim.index()]));
          }
        }
      }
    }

    // If the shape could not be reified, create DimOps.
    if (!reifiedShapes)
      populateDynamicDimSizes(b, loc, shapedValue: tensor, dynamicDims&: dynamicSizes);
  }

  // Create AllocTensorOp.
  auto allocTensorOp = b.create<AllocTensorOp>(location: loc, args&: tensorType, args&: dynamicSizes,
                                               args: copy ? tensor : Value());

  // Add 'memory_space' attribute. Not needed if 'copy' operand is specified:
  // in that case, the memory space is inferred from the copied tensor.
  if (copy)
    return allocTensorOp.getResult();
  auto copyBufferType =
      detail::asMemRefType(bufferType: getBufferType(value: tensor, options, state));
  if (failed(Result: copyBufferType))
    return failure();
  std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
  if (!memorySpace)
    memorySpace = options.defaultMemorySpaceFn(tensorType);
  if (memorySpace.has_value())
    allocTensorOp.setMemorySpaceAttr(memorySpace.value());
  return allocTensorOp.getResult();
}
229
/// Insert tensor copies (alloc_tensor ops) for all out-of-place tensor
/// OpOperands of this op, so that subsequent bufferization cannot create
/// conflicting in-place memory writes. Depending on the aliasing structure,
/// either the OpOperand itself is copied (before the op) or its single
/// aliasing OpResult is copied (after the op).
LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
    RewriterBase &rewriter, const AnalysisState &analysisState,
    const BufferizationState &bufferizationState) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = getOperation();
  SmallVector<OpOperand *> outOfPlaceOpOperands;
  DenseSet<OpOperand *> copiedOpOperands;
  SmallVector<Value> outOfPlaceValues;
  DenseSet<Value> copiedOpValues;

  // Find all out-of-place OpOperands.
  for (OpOperand &opOperand : op->getOpOperands()) {
    Type operandType = opOperand.get().getType();
    // Only tensor operands can be out-of-place.
    if (!llvm::isa<TensorType>(Val: operandType))
      continue;
    if (analysisState.isInPlace(opOperand))
      continue;
    // Unranked tensors cannot be copied (no alloc_tensor support).
    if (llvm::isa<UnrankedTensorType>(Val: operandType))
      return op->emitError(message: "copying of unranked tensors is not implemented");

    AliasingValueList aliasingValues =
        analysisState.getAliasingValues(opOperand);
    if (aliasingValues.getNumAliases() == 1 &&
        isa<OpResult>(Val: aliasingValues.getAliases()[0].value) &&
        !analysisState.bufferizesToMemoryWrite(opOperand) &&
        analysisState
                .getAliasingOpOperands(value: aliasingValues.getAliases()[0].value)
                .getNumAliases() == 1 &&
        !isa<UnrankedTensorType>(
            Val: aliasingValues.getAliases()[0].value.getType())) {
      // The op itself does not write but may create exactly one alias. Instead
      // of copying the OpOperand, copy the OpResult. The OpResult can sometimes
      // be smaller than the OpOperand (e.g., in the case of an extract_slice,
      // where the result is usually a smaller part of the source). Do not apply
      // this optimization if the OpResult is an unranked tensor (because those
      // cannot be copied at the moment).
      Value value = aliasingValues.getAliases()[0].value;
      outOfPlaceValues.push_back(Elt: value);
      // Record whether an actual copy is required (vs. a fresh allocation).
      if (!analysisState.canOmitTensorCopy(opOperand))
        copiedOpValues.insert(V: value);
    } else {
      // In all other cases, make a copy of the OpOperand.
      outOfPlaceOpOperands.push_back(Elt: &opOperand);
      if (!analysisState.canOmitTensorCopy(opOperand))
        copiedOpOperands.insert(V: &opOperand);
    }
  }

  // Insert copies of OpOperands (before the op).
  rewriter.setInsertionPoint(op);
  for (OpOperand *opOperand : outOfPlaceOpOperands) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        b&: rewriter, loc: op->getLoc(), shapedValue: opOperand->get(), options: analysisState.getOptions(),
        state: bufferizationState, copy: copiedOpOperands.contains(V: opOperand));
    if (failed(Result: copy))
      return failure();
    rewriter.modifyOpInPlace(root: op, callable: [&]() { opOperand->set(*copy); });
  }

  // Insert copies of Values (after the op) and redirect their uses.
  rewriter.setInsertionPointAfter(op);
  for (Value value : outOfPlaceValues) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        b&: rewriter, loc: op->getLoc(), shapedValue: value, options: analysisState.getOptions(),
        state: bufferizationState, copy: copiedOpValues.count(V: value));
    if (failed(Result: copy))
      return failure();
    // Snapshot the use list before rewriting it (modifying uses while
    // iterating over them would be unsafe).
    SmallVector<OpOperand *> uses = llvm::to_vector(
        Range: llvm::map_range(C: value.getUses(), F: [](OpOperand &use) { return &use; }));
    for (OpOperand *use : uses) {
      // Do not update the alloc_tensor op that we just created.
      if (use->getOwner() == copy->getDefiningOp())
        continue;
      // tensor.dim ops may have been created to be used as alloc_tensor op
      // dynamic extents. Do not update these either.
      if (isa<tensor::DimOp>(Val: use->getOwner()))
        continue;
      rewriter.modifyOpInPlace(root: use->getOwner(), callable: [&]() { use->set(*copy); });
    }
  }

  return success();
}
313
314//===----------------------------------------------------------------------===//
315// OpFilter
316//===----------------------------------------------------------------------===//
317
/// Return "true" if the given op passes this filter. DENY rules take
/// precedence: any matching DENY rule rejects the op, even if ALLOW rules
/// also match. If the filter has no ALLOW rule at all, ops are allowed by
/// default (unless denied).
bool OpFilter::isOpAllowed(Operation *op) const {
  // All other ops: Allow/disallow according to filter.
  bool isAllowed = !hasAllowRule();
  for (const Entry &entry : entries) {
    bool filterResult = entry.fn(op);
    switch (entry.type) {
    case Entry::ALLOW:
      isAllowed |= filterResult;
      break;
    case Entry::DENY:
      if (filterResult)
        // DENY filter matches. This op is not allowed. (Even if other ALLOW
        // filters may match.)
        return false;
    };
  }
  return isAllowed;
}
336
337//===----------------------------------------------------------------------===//
338// BufferizationOptions
339//===----------------------------------------------------------------------===//
340
341namespace {
342
/// Default function arg type converter: Use a fully dynamic layout map so
/// that the function can accept buffers with arbitrary strides/offsets.
/// `funcOp` and `options` are unused by the default implementation.
BaseMemRefType
defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
                                func::FuncOp funcOp,
                                const BufferizationOptions &options) {
  return getMemRefTypeWithFullyDynamicLayout(tensorType: type, memorySpace);
}
/// Default unknown type converter: Use a fully dynamic layout map. This is
/// the conservative choice when nothing is known about the layout of the
/// buffer that will materialize for the tensor.
BaseMemRefType
defaultUnknownTypeConverter(TensorType tensorType, Attribute memorySpace,
                            const BufferizationOptions &options) {
  return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
}
356
357} // namespace
358
// Default constructor for BufferizationOptions: installs the fully-dynamic
// layout converters defined above for function arguments and unknown types.
BufferizationOptions::BufferizationOptions()
    : functionArgTypeConverterFn(defaultFunctionArgTypeConverter),
      unknownTypeConverterFn(defaultUnknownTypeConverter) {}
363
364bool BufferizationOptions::isOpAllowed(Operation *op) const {
365 // Special case: If function boundary bufferization is deactivated, do not
366 // allow ops that belong to the `func` dialect.
367 bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(Val: op->getDialect());
368 if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
369 return false;
370
371 return opFilter.isOpAllowed(op);
372}
373
374BufferizableOpInterface
375BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
376 if (!isOpAllowed(op))
377 return nullptr;
378 auto bufferizableOp = dyn_cast<BufferizableOpInterface>(Val: op);
379 if (!bufferizableOp)
380 return nullptr;
381 return bufferizableOp;
382}
383
/// Try to cast the op that owns the given value to BufferizableOpInterface
/// (see the Operation* overload above for the null cases).
BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  return dynCastBufferizableOp(op: getOwnerOfValue(value));
}
388
/// Configure how tensor types on function boundaries are converted to memref
/// types: identity layout, fully dynamic layout, or inferred layout
/// (`InferLayoutMap` only affects result types; arguments then use the fully
/// dynamic layout).
void BufferizationOptions::setFunctionBoundaryTypeConversion(
    LayoutMapOption layoutMapOption) {
  // Capture the chosen option by value; the lambda outlives this scope.
  functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
                                   func::FuncOp funcOp,
                                   const BufferizationOptions &options) {
    if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
      return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
                                                                  memorySpace);
    return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
                                                              memorySpace);
  };
  inferFunctionResultLayout =
      layoutMapOption == LayoutMapOption::InferLayoutMap;
}
403
404//===----------------------------------------------------------------------===//
405// Helper functions for BufferizableOpInterface
406//===----------------------------------------------------------------------===//
407
408static void setInsertionPointAfter(OpBuilder &b, Value value) {
409 if (auto bbArg = llvm::dyn_cast<BlockArgument>(Val&: value)) {
410 b.setInsertionPointToStart(bbArg.getOwner());
411 } else {
412 b.setInsertionPointAfter(value.getDefiningOp());
413 }
414}
415
416/// Determine which OpOperand* will alias with `value` if the op is bufferized
417/// in place. Return all tensor OpOperand* if the op is not bufferizable.
418AliasingOpOperandList AnalysisState::getAliasingOpOperands(Value value) const {
419 if (Operation *op = getOwnerOfValue(value))
420 if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
421 return bufferizableOp.getAliasingOpOperands(value, state: *this);
422
423 // The op is not bufferizable.
424 return detail::unknownGetAliasingOpOperands(value);
425}
426
427/// Determine which Values will alias with `opOperand` if the op is bufferized
428/// in place. Return all tensor Values if the op is not bufferizable.
429AliasingValueList AnalysisState::getAliasingValues(OpOperand &opOperand) const {
430 if (auto bufferizableOp =
431 getOptions().dynCastBufferizableOp(op: opOperand.getOwner()))
432 return bufferizableOp.getAliasingValues(opOperand, state: *this);
433
434 // The op is not bufferizable.
435 return detail::unknownGetAliasingValues(opOperand);
436}
437
438/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
439/// op is not bufferizable.
440bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
441 if (auto bufferizableOp =
442 getOptions().dynCastBufferizableOp(op: opOperand.getOwner()))
443 return bufferizableOp.bufferizesToMemoryRead(opOperand, state: *this);
444
445 // Unknown op that returns a tensor. The inplace analysis does not support it.
446 // Conservatively return true.
447 return true;
448}
449
450/// Return true if `opOperand` bufferizes to a memory write. Return
451/// `true` if the op is not bufferizable.
452bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
453 if (auto bufferizableOp =
454 getOptions().dynCastBufferizableOp(op: opOperand.getOwner()))
455 return bufferizableOp.bufferizesToMemoryWrite(opOperand, state: *this);
456
457 // Unknown op that returns a tensor. The inplace analysis does not support it.
458 // Conservatively return true.
459 return true;
460}
461
462/// Return true if `opOperand` does neither read nor write but bufferizes to an
463/// alias. Return false if the op is not bufferizable.
464bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
465 if (auto bufferizableOp =
466 getOptions().dynCastBufferizableOp(op: opOperand.getOwner()))
467 return bufferizableOp.bufferizesToAliasOnly(opOperand, state: *this);
468
469 // Unknown op that returns a tensor. The inplace analysis does not support it.
470 // Conservatively return false.
471 return false;
472}
473
474bool AnalysisState::bufferizesToMemoryWrite(Value value) const {
475 auto opResult = llvm::dyn_cast<OpResult>(Val&: value);
476 if (!opResult)
477 return true;
478 auto bufferizableOp = getOptions().dynCastBufferizableOp(value);
479 if (!bufferizableOp)
480 return true;
481 return bufferizableOp.resultBufferizesToMemoryWrite(opResult, state: *this);
482}
483
/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp): uses of such aliases are followed
/// transitively via a worklist.
bool AnalysisState::isValueRead(Value value) const {
  assert(llvm::isa<TensorType>(value.getType()) && "expected TensorType");
  // Worklist of uses that may read the value (directly or via an alias).
  SmallVector<OpOperand *> workingSet;
  DenseSet<OpOperand *> visited;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(Elt: &use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip uses that were already processed (use-def chains can re-converge).
    if (!visited.insert(V: uMaybeReading).second)
      continue;

    // Skip over all ops that neither read nor write (but create an alias):
    // enqueue the uses of all aliasing values instead.
    if (bufferizesToAliasOnly(opOperand&: *uMaybeReading))
      for (AliasingValue alias : getAliasingValues(opOperand&: *uMaybeReading))
        for (OpOperand &use : alias.value.getUses())
          workingSet.push_back(Elt: &use);
    if (bufferizesToMemoryRead(opOperand&: *uMaybeReading))
      return true;
  }

  // No (transitive) use reads the value.
  return false;
}
510
// Starting from `opOperand`, follow the use-def chain in reverse, always
// selecting the aliasing OpOperands. Find and return Values for which
// `condition` evaluates to true. Uses of such matching Values are not
// traversed any further, the visited aliasing opOperands will be preserved
// through `visitedOpOperands`.
//
// The traversal can be tuned via `config`: whether leaves (values at which
// the traversal stops without a `condition` match) are included in the
// result, whether already-visited values may be revisited, and which alias
// edges (equivalent-only, in-place-only, same-type-or-cast-only, unknown
// ops) may be followed.
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    OpOperand *opOperand, llvm::function_ref<bool(Value)> condition,
    TraversalConfig config,
    llvm::DenseSet<OpOperand *> *visitedOpOperands) const {
  llvm::DenseSet<Value> visited;
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(X: opOperand->get());

  // The starting OpOperand counts as visited.
  if (visitedOpOperands)
    visitedOpOperands->insert(V: opOperand);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();

    if (!config.revisitAlreadyVisitedValues && visited.contains(V: value)) {
      // Stop traversal if value was already visited.
      if (config.alwaysIncludeLeaves)
        result.insert(X: value);
      continue;
    }
    visited.insert(V: value);

    // A matching value terminates this branch of the traversal.
    if (condition(value)) {
      result.insert(X: value);
      continue;
    }

    if (!config.followUnknownOps && !options.dynCastBufferizableOp(value)) {
      // Stop iterating if `followUnknownOps` is unset and the op is either
      // not bufferizable or excluded in the OpFilter.
      if (config.alwaysIncludeLeaves)
        result.insert(X: value);
      continue;
    }

    AliasingOpOperandList aliases = getAliasingOpOperands(value);
    if (aliases.getNumAliases() == 0) {
      // The traversal ends naturally if there are no more OpOperands that
      // could be followed.
      if (config.alwaysIncludeLeaves)
        result.insert(X: value);
      continue;
    }

    for (AliasingOpOperand a : aliases) {
      if (config.followEquivalentOnly &&
          a.relation != BufferRelation::Equivalent) {
        // Stop iterating if `followEquivalentOnly` is set but the alias is not
        // equivalent.
        if (config.alwaysIncludeLeaves)
          result.insert(X: value);
        continue;
      }

      if (config.followInPlaceOnly && !isInPlace(opOperand&: *a.opOperand)) {
        // Stop iterating if `followInPlaceOnly` is set but the alias is
        // out-of-place.
        if (config.alwaysIncludeLeaves)
          result.insert(X: value);
        continue;
      }

      if (config.followSameTypeOrCastsOnly &&
          a.opOperand->get().getType() != value.getType() &&
          !value.getDefiningOp<CastOpInterface>()) {
        // Stop iterating if `followSameTypeOrCastsOnly` is set but the alias
        // has a different type and the op is not a cast.
        if (config.alwaysIncludeLeaves)
          result.insert(X: value);
        continue;
      }

      // Follow this alias edge: continue the reverse traversal at the
      // aliasing OpOperand's value.
      workingSet.insert(X: a.opOperand->get());
      if (visitedOpOperands)
        visitedOpOperands->insert(V: a.opOperand);
    }
  }

  return result;
}
596
597// Find the values that define the contents of the given operand's value.
598llvm::SetVector<Value>
599AnalysisState::findDefinitions(OpOperand *opOperand) const {
600 TraversalConfig config;
601 config.alwaysIncludeLeaves = false;
602 return findValueInReverseUseDefChain(
603 opOperand, condition: [&](Value v) { return this->bufferizesToMemoryWrite(value: v); },
604 config);
605}
606
/// Construct an AnalysisState with the default TypeID (used for isa/dyn_cast
/// on derived state classes).
AnalysisState::AnalysisState(const BufferizationOptions &options)
    : AnalysisState(options, TypeID::get<AnalysisState>()) {}
609
/// Construct an AnalysisState with an explicit TypeID (for derived states)
/// and run all user-provided state initializers from the options.
AnalysisState::AnalysisState(const BufferizationOptions &options, TypeID type)
    : options(options), type(type) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}
616
/// Return "true" if it is provably safe to bufferize the given OpOperand
/// out-of-place without inserting an actual copy of the tensor data.
bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
  // Do not copy if the tensor has undefined contents.
  if (hasUndefinedContents(opOperand: &opOperand))
    return true;

  // Do not copy if the buffer of the tensor is entirely overwritten (with
  // values that do not depend on the old tensor).
  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
    return true;

  // Do not copy if the tensor is never read: neither by this OpOperand nor
  // (transitively) through any of its aliasing values.
  AliasingValueList aliases = getAliasingValues(opOperand);
  if (!bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(Range&: aliases,
                    P: [&](AliasingValue a) { return isValueRead(value: a.value); }))
    return true;

  // Default: Cannot omit the copy.
  return false;
}
637
/// Return "true" if the given OpOperand should bufferize in-place. This is
/// the conservative default used when no dedicated analysis has run.
bool AnalysisState::isInPlace(OpOperand &opOperand) const {
  // ToBufferOps are always in-place.
  if (isa<ToBufferOp>(Val: opOperand.getOwner()))
    return true;

  // In the absence of analysis information, OpOperands that bufferize to a
  // memory write are out-of-place, i.e., an alloc and copy is inserted.
  return !bufferizesToMemoryWrite(opOperand);
}
647
/// Return "true" if `v1` and `v2` are guaranteed to bufferize to the same
/// buffer. Overridden by derived analysis states.
bool AnalysisState::areEquivalentBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values are
  // equivalent. The conservative answer is "false".
  return false;
}
653
/// Return "true" if `v1` and `v2` may bufferize to aliasing buffers.
/// Overridden by derived analysis states.
bool AnalysisState::areAliasingBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values may be
  // aliasing. The conservative answer is "true".
  return true;
}
659
/// Return "true" if the contents of the given OpOperand's tensor are known
/// to be undefined (so a copy can be elided). Overridden by derived states.
bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  // In the absence of analysis information, the conservative answer is "false".
  return false;
}
664
// bufferization.to_buffer is not allowed to change the rank. Debug-only
// sanity check; compiles to nothing in release builds (NDEBUG).
static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(tensor.getType());
  assert((!rankedTensorType || llvm::cast<MemRefType>(memrefType).getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_buffer would be invalid: mismatching ranks");
#endif
}
674
/// Return the buffer (memref) for the given tensor `value`. If `value` is the
/// result of a to_tensor op, simply return its memref operand. Otherwise,
/// insert a new to_buffer op right after the definition of `value`.
FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
                                          const BufferizationOptions &options,
                                          const BufferizationState &state) {
#ifndef NDEBUG
  auto tensorType = llvm::dyn_cast<TensorLikeType>(value.getType());
  assert(tensorType && "unexpected non-tensor type");
#endif // NDEBUG

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getBuffer();

  // Insert to_buffer op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(b&: rewriter, value);
  FailureOr<BufferLikeType> bufferType = getBufferType(value, options, state);
  if (failed(Result: bufferType))
    return failure();
  ensureToBufferOpIsValid(tensor: value, memrefType: *bufferType);
  return rewriter
      .create<bufferization::ToBufferOp>(location: value.getLoc(), args&: *bufferType, args&: value)
      .getResult();
}
698
/// Return the buffer type for a given Value (tensor) after bufferization.
/// Convenience overload that starts with an empty invocation stack (used to
/// detect cycles in recursive type computations).
FailureOr<BufferLikeType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
                             const BufferizationState &state) {
  SmallVector<Value> invocationStack;
  return getBufferType(value, options, state, invocationStack);
}
706
/// Return the buffer type for a given Value (tensor) after bufferization.
/// `invocationStack` records the values currently being computed; this
/// function pushes `value` for the duration of the call so that recursive
/// implementations can detect cycles.
FailureOr<BufferLikeType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
                             const BufferizationState &state,
                             SmallVector<Value> &invocationStack) {
  assert(llvm::isa<TensorLikeType>(value.getType()) &&
         "unexpected non-tensor type");
  invocationStack.push_back(Elt: value);
  // Pop `value` from the stack on every exit path.
  auto popFromStack =
      llvm::make_scope_exit(F: [&]() { invocationStack.pop_back(); });

  // Try querying BufferizableOpInterface.
  Operation *op = getOwnerOfValue(value);
  auto bufferizableOp = options.dynCastBufferizableOp(op);
  if (bufferizableOp)
    return bufferizableOp.getBufferType(value, options, state, invocationStack);

  // Op is not bufferizable. Fall back to the type's own conversion rule.
  return cast<TensorLikeType>(Val: value.getType()).getBufferType(options, emitError: [&]() {
    return op->emitError();
  });
}
729
730bool bufferization::hasTensorSemantics(Operation *op) {
731 if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(Val: op))
732 return bufferizableOp.hasTensorSemantics();
733 return detail::defaultHasTensorSemantics(op);
734}
735
/// Replace `op` with the given bufferized values (one per OpResult). Tensor
/// results are wrapped in to_tensor ops so that the remaining (not yet
/// bufferized) users still see a tensor.
void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (llvm::isa<TensorLikeType>(Val: opResult.getType())) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert(llvm::isa<BufferLikeType>(replacement.getType()) &&
             "tensor op result should be replaced with a buffer value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          location: replacement.getLoc(), args: opResult.getType(), args&: replacement);
    }
    replacements.push_back(Elt: replacement);
  }

  rewriter.replaceOp(op, newValues: replacements);
}
764
765//===----------------------------------------------------------------------===//
766// Bufferization-specific scoped alloc insertion support.
767//===----------------------------------------------------------------------===//
768
/// Create a memref allocation with the given type and dynamic extents. A
/// user-provided allocation function takes precedence; otherwise a
/// memref.alloc op is created, with an alignment attribute when
/// `bufferAlignment` is non-zero.
FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
                                                   MemRefType type,
                                                   ValueRange dynShape) const {
  if (allocationFn)
    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);

  // Default buffer allocation via AllocOp.
  if (bufferAlignment != 0)
    return b
        .create<memref::AllocOp>(location: loc, args&: type, args&: dynShape,
                                 args: b.getI64IntegerAttr(value: bufferAlignment))
        .getResult();
  return b.create<memref::AllocOp>(location: loc, args&: type, args&: dynShape).getResult();
}
784
785/// Create a memory copy between two memref buffers.
786LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
787 Value from, Value to) const {
788 if (memCpyFn)
789 return (*memCpyFn)(b, loc, from, to);
790
791 b.create<memref::CopyOp>(location: loc, args&: from, args&: to);
792 return success();
793}
794
795//===----------------------------------------------------------------------===//
796// Bufferization-specific IRMapping support with debugging.
797//===----------------------------------------------------------------------===//
798
/// Return a memref type for the given tensor type. If `layout` is specified,
/// it is used for ranked tensors; otherwise the layout is chosen by the
/// unknown-type converter from `options`. Unranked tensors map to unranked
/// memrefs (which cannot carry a layout).
BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType =
          llvm::dyn_cast<UnrankedTensorType>(Val&: tensorType)) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(elementType: unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type with specified layout.
  auto rankedTensorType = llvm::cast<RankedTensorType>(Val&: tensorType);
  if (layout) {
    return MemRefType::get(shape: rankedTensorType.getShape(),
                           elementType: rankedTensorType.getElementType(), layout,
                           memorySpace);
  }

  // Case 3: No layout specified; let the options decide.
  return options.unknownTypeConverterFn(tensorType, memorySpace, options);
}
821
822BaseMemRefType
823bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
824 Attribute memorySpace) {
825 // Case 1: Unranked memref type.
826 if (auto unrankedTensorType =
827 llvm::dyn_cast<UnrankedTensorType>(Val&: tensorType)) {
828 return UnrankedMemRefType::get(elementType: unrankedTensorType.getElementType(),
829 memorySpace);
830 }
831
832 // Case 2: Ranked memref type.
833 auto rankedTensorType = llvm::cast<RankedTensorType>(Val&: tensorType);
834 int64_t dynamicOffset = ShapedType::kDynamic;
835 SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
836 ShapedType::kDynamic);
837 auto stridedLayout = StridedLayoutAttr::get(context: tensorType.getContext(),
838 offset: dynamicOffset, strides: dynamicStrides);
839 return MemRefType::get(shape: rankedTensorType.getShape(),
840 elementType: rankedTensorType.getElementType(), layout: stridedLayout,
841 memorySpace);
842}
843
844/// Return a MemRef type with a static identity layout (i.e., no layout map). If
845/// the given tensor type is unranked, return an unranked MemRef type.
846BaseMemRefType
847bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
848 Attribute memorySpace) {
849 // Case 1: Unranked memref type.
850 if (auto unrankedTensorType =
851 llvm::dyn_cast<UnrankedTensorType>(Val&: tensorType)) {
852 return UnrankedMemRefType::get(elementType: unrankedTensorType.getElementType(),
853 memorySpace);
854 }
855
856 // Case 2: Ranked memref type.
857 auto rankedTensorType = llvm::cast<RankedTensorType>(Val&: tensorType);
858 MemRefLayoutAttrInterface layout = {};
859 return MemRefType::get(shape: rankedTensorType.getShape(),
860 elementType: rankedTensorType.getElementType(), layout,
861 memorySpace);
862}
863
864//===----------------------------------------------------------------------===//
865// Default implementations of interface methods
866//===----------------------------------------------------------------------===//
867
/// Default implementation of `resultBufferizesToMemoryWrite`: An OpResult is
/// considered a memory write if (1) it has no aliasing OpOperands, (2) an
/// aliasing OpOperand bufferizes to a memory write, or (3) a memory write is
/// found on the reverse SSA use-def chain of an aliasing OpOperand, nested
/// inside the defining op.
bool bufferization::detail::defaultResultBufferizesToMemoryWrite(
    OpResult opResult, const AnalysisState &state) {
  auto bufferizableOp = cast<BufferizableOpInterface>(Val: opResult.getDefiningOp());
  AliasingOpOperandList opOperands =
      bufferizableOp.getAliasingOpOperands(value: opResult, state);

  // Case 1: OpResults that have no aliasing OpOperand usually bufferize to
  // memory writes.
  if (opOperands.getAliases().empty())
    return true;

  // Case 2: If an aliasing OpOperand bufferizes to a memory write, the OpResult
  // may bufferize to a memory write.
  if (llvm::any_of(Range&: opOperands, P: [&](AliasingOpOperand alias) {
        return state.bufferizesToMemoryWrite(opOperand&: *alias.opOperand);
      }))
    return true;

  // Case 3: Check if a nested aliasing OpOperand value bufferizes to a memory
  // write. (Or: The reverse SSA use-def chain ends inside the region.) In that
  // case, the OpResult bufferizes to a memory write. E.g.:
  //
  // %0 = "some_writing_op" : tensor<?xf32>
  // %r = scf.if ... -> tensor<?xf32> {
  //   scf.yield %0 : tensor<?xf32>
  // } else {
  //   %1 = "another_writing_op"(%0) : tensor<?xf32>
  //   scf.yield %1 : tensor<?xf32>
  // }
  // "some_reading_op"(%r)
  //
  // %r bufferizes to a memory write because an aliasing OpOperand value (%1)
  // bufferizes to a memory write and the defining op is inside the scf.if.
  //
  // Note: This treatment of surrounding ops is useful for ops that have a
  // region but no OpOperand such as scf.if or scf.execute_region. It simplifies
  // the analysis considerably.
  //
  // "another_writing_op" in the above example should be able to bufferize
  // inplace in the absence of another read of %0. However, if the scf.if op
  // would not be considered a "write", the analysis would detect the
  // following conflict:
  //
  // * read = some_reading_op
  // * lastWrite = %0 (Note: The last write of %r would be a set: {%0, %1}.)
  // * conflictingWrite = %1
  //
  // Predicate: `v` is a memory write AND its owner is nested inside (or is)
  // the op that defines `opResult`.
  auto isMemoryWriteInsideOp = [&](Value v) {
    Operation *op = getOwnerOfValue(value: v);
    if (!opResult.getDefiningOp()->isAncestor(other: op))
      return false;
    return state.bufferizesToMemoryWrite(value: v);
  };
  TraversalConfig config;
  // Do not report chain-terminating leaves that fail the predicate; only
  // actual matches count.
  config.alwaysIncludeLeaves = false;
  // The OpResult is a write if any aliasing operand's reverse use-def chain
  // contains a nested write.
  for (AliasingOpOperand alias : opOperands) {
    if (!state
             .findValueInReverseUseDefChain(opOperand: alias.opOperand,
                                            condition: isMemoryWriteInsideOp, config)
             .empty())
      return true;
  }
  return false;
}
932
933// Compute the AliasingOpOperandList for a given Value based on
934// getAliasingValues.
935AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
936 Value value, const AnalysisState &state) {
937 Operation *op = getOwnerOfValue(value);
938 SmallVector<AliasingOpOperand> result;
939 for (OpOperand &opOperand : op->getOpOperands()) {
940 if (!llvm::isa<TensorType>(Val: opOperand.get().getType()))
941 continue;
942 AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
943 for (const auto &it : aliasingValues)
944 if (it.value == value)
945 result.emplace_back(Args: &opOperand, Args: it.relation, Args: it.isDefinite);
946 }
947 return AliasingOpOperandList(std::move(result));
948}
949
/// Default implementation of `getBufferType`: Compute the buffer type of the
/// given tensor `value`. Block arguments fall back to the options' memref
/// type computation; OpResults with an equivalent OpOperand inherit that
/// operand's buffer type; everything else uses the default memory space.
FailureOr<BufferLikeType> bufferization::detail::defaultGetBufferType(
    Value value, const BufferizationOptions &options,
    const BufferizationState &bufferizationState,
    SmallVector<Value> &invocationStack) {
  assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
  auto tensorType = cast<TensorType>(Val: value.getType());

  // No further analysis is possible for a block argument.
  if (llvm::isa<BlockArgument>(Val: value)) {
    return cast<BufferLikeType>(
        Val: bufferization::getMemRefType(tensorType, options));
  }

  // Value is an OpResult.
  Operation *op = getOwnerOfValue(value);
  auto opResult = llvm::cast<OpResult>(Val&: value);
  // NOTE(review): a fresh AnalysisState is built from `options` here rather
  // than reusing an existing one — presumably because no analysis state is
  // available in this context; verify against callers.
  AnalysisState analysisState(options);
  AliasingOpOperandList aliases = analysisState.getAliasingOpOperands(value: opResult);
  if (aliases.getNumAliases() > 0 &&
      aliases.getAliases()[0].relation == BufferRelation::Equivalent) {
    // If the OpResult has an equivalent OpOperand, both OpResult and
    // OpOperand bufferize to the exact same buffer type. Recurse on the
    // operand (the invocationStack guards the recursion).
    Value equivalentOperand = aliases.getAliases().front().opOperand->get();
    return getBufferType(value: equivalentOperand, options, state: bufferizationState,
                         invocationStack);
  }

  // If we do not know the memory space and there is no default memory space,
  // report a failure.
  auto memSpace =
      options.defaultMemorySpaceFn(cast<TensorType>(Val: value.getType()));
  if (!memSpace.has_value())
    return op->emitError(message: "could not infer memory space");

  return cast<BufferLikeType>(
      Val: getMemRefType(tensorType, options, /*layout=*/{}, memorySpace: *memSpace));
}
987
988bool bufferization::detail::defaultIsRepetitiveRegion(
989 BufferizableOpInterface bufferizableOp, unsigned index) {
990 assert(index < bufferizableOp->getNumRegions() && "invalid region index");
991 auto regionInterface =
992 dyn_cast<RegionBranchOpInterface>(Val: bufferizableOp.getOperation());
993 if (!regionInterface)
994 return false;
995 return regionInterface.isRepetitiveRegion(index);
996}
997
998AliasingOpOperandList
999bufferization::detail::unknownGetAliasingOpOperands(Value value) {
1000 // TODO: Take into account successor blocks.
1001 // No aliasing in case of non-entry blocks.
1002 if (auto bbArg = dyn_cast<BlockArgument>(Val&: value))
1003 if (bbArg.getOwner() != &bbArg.getOwner()->getParent()->getBlocks().front())
1004 return {};
1005
1006 // Unknown op: Conservatively assume that each OpResult may alias with every
1007 // OpOperand. In addition, each block argument of an entry block may alias
1008 // with every OpOperand.
1009 AliasingOpOperandList r;
1010 for (OpOperand &operand : value.getDefiningOp()->getOpOperands())
1011 if (isa<TensorType>(Val: operand.get().getType()))
1012 r.addAlias(alias: {&operand, BufferRelation::Unknown, /*isDefinite=*/false});
1013 return r;
1014}
1015
1016AliasingValueList
1017bufferization::detail::unknownGetAliasingValues(OpOperand &opOperand) {
1018 // TODO: Take into account successor blocks.
1019 // Unknown op: Conservatively assume that each OpResult may alias with every
1020 // OpOperand. In addition, each block argument of an entry block may alias
1021 // with every OpOperand.
1022 AliasingValueList r;
1023 for (OpResult result : opOperand.getOwner()->getOpResults())
1024 if (llvm::isa<TensorType>(Val: result.getType()))
1025 r.addAlias(alias: {result, BufferRelation::Unknown, /*isDefinite=*/false});
1026 for (Region &region : opOperand.getOwner()->getRegions())
1027 if (!region.getBlocks().empty())
1028 for (BlockArgument bbArg : region.getBlocks().front().getArguments())
1029 if (isa<TensorType>(Val: bbArg.getType()))
1030 r.addAlias(alias: {bbArg, BufferRelation::Unknown, /*isDefinite=*/false});
1031 return r;
1032}
1033
1034bool bufferization::detail::defaultHasTensorSemantics(Operation *op) {
1035 auto isaTensor = [](Type t) { return isa<TensorLikeType>(Val: t); };
1036 bool hasTensorBlockArgument = any_of(Range: op->getRegions(), P: [&](Region &r) {
1037 return any_of(Range&: r.getBlocks(), P: [&](Block &b) {
1038 return any_of(Range: b.getArguments(), P: [&](BlockArgument bbArg) {
1039 return isaTensor(bbArg.getType());
1040 });
1041 });
1042 });
1043 if (hasTensorBlockArgument)
1044 return true;
1045
1046 if (any_of(Range: op->getResultTypes(), P: isaTensor))
1047 return true;
1048 return any_of(Range: op->getOperandTypes(), P: isaTensor);
1049}
1050
1051FailureOr<BaseMemRefType>
1052bufferization::detail::asMemRefType(FailureOr<BufferLikeType> bufferType) {
1053 if (failed(Result: bufferType))
1054 return failure();
1055 return cast<BaseMemRefType>(Val&: *bufferType);
1056}
1057
1058bool bufferization::detail::typesMatchAfterBufferization(Operation &op,
1059 Value tensor,
1060 Value buffer) {
1061 return mlir::succeeded(
1062 Result: cast<TensorLikeType>(Val: tensor.getType())
1063 .verifyCompatibleBufferType(bufferType: cast<BufferLikeType>(Val: buffer.getType()),
1064 emitError: [&]() { return op.emitError(); }));
1065}
1066

// source code of mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp