//===- BufferizableOpInterface.cpp - Bufferizable Ops --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"

//===----------------------------------------------------------------------===//
// BufferizableOpInterface
//===----------------------------------------------------------------------===//

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::AnalysisState)

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

static bool isRepetitiveRegion(Region *region,
                               const BufferizationOptions &options) {
  Operation *op = region->getParentOp();
  if (auto bufferizableOp = options.dynCastBufferizableOp(op))
    if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
      return true;
  return false;
}

Region *AnalysisState::getEnclosingRepetitiveRegion(
    Operation *op, const BufferizationOptions &options) {
  if (!op->getBlock())
    return nullptr;
  if (auto iter = enclosingRepetitiveRegionCache.find_as(op);
      iter != enclosingRepetitiveRegionCache.end())
    return iter->second;
  return enclosingRepetitiveRegionCache[op] =
             getEnclosingRepetitiveRegion(op->getBlock(), options);
}

Region *AnalysisState::getEnclosingRepetitiveRegion(
    Value value, const BufferizationOptions &options) {
  if (auto iter = enclosingRepetitiveRegionCache.find_as(value);
      iter != enclosingRepetitiveRegionCache.end())
    return iter->second;

  Region *region = value.getParentRegion();
  // Collect all visited regions since we only know the repetitive region we
  // want to map it to later on.
  SmallVector<Region *> visitedRegions;
  while (region) {
    visitedRegions.push_back(region);
    if (isRepetitiveRegion(region, options))
      break;
    region = region->getParentRegion();
  }
  enclosingRepetitiveRegionCache[value] = region;
  for (Region *r : visitedRegions)
    enclosingRepetitiveRegionCache[r] = region;
  return region;
}

Region *AnalysisState::getEnclosingRepetitiveRegion(
    Block *block, const BufferizationOptions &options) {
  if (auto iter = enclosingRepetitiveRegionCache.find_as(block);
      iter != enclosingRepetitiveRegionCache.end())
    return iter->second;

  Region *region = block->getParent();
  Operation *op = nullptr;
  // Collect all visited regions since we only know the repetitive region we
  // want to map it to later on.
  SmallVector<Region *> visitedRegions;
  do {
    op = region->getParentOp();
    if (isRepetitiveRegion(region, options))
      break;
  } while ((region = op->getParentRegion()));

  enclosingRepetitiveRegionCache[block] = region;
  for (Region *r : visitedRegions)
    enclosingRepetitiveRegionCache[r] = region;
  return region;
}

bool AnalysisState::insideMutuallyExclusiveRegions(Operation *op0,
                                                   Operation *op1) {
  auto key = std::make_pair(op0, op1);
  if (auto iter = insideMutuallyExclusiveRegionsCache.find(key);
      iter != insideMutuallyExclusiveRegionsCache.end())
    return iter->second;
  bool result = ::mlir::insideMutuallyExclusiveRegions(op0, op1);
  // Populate results for both orderings of the ops.
  insideMutuallyExclusiveRegionsCache[key] = result;
  insideMutuallyExclusiveRegionsCache[std::make_pair(op1, op0)] = result;
  return result;
}

void AnalysisState::resetCache() {
  enclosingRepetitiveRegionCache.clear();
  insideMutuallyExclusiveRegionsCache.clear();
}

SymbolTableCollection &BufferizationState::getSymbolTables() {
  return symbolTables;
}

Region *bufferization::getNextEnclosingRepetitiveRegion(
    Region *region, const BufferizationOptions &options) {
  assert(isRepetitiveRegion(region, options) && "expected repetitive region");
  while ((region = region->getParentRegion())) {
    if (isRepetitiveRegion(region, options))
      break;
  }
  return region;
}

Region *bufferization::getParallelRegion(Region *region,
                                         const BufferizationOptions &options) {
  while (region) {
    auto bufferizableOp = options.dynCastBufferizableOp(region->getParentOp());
    if (bufferizableOp &&
        bufferizableOp.isParallelRegion(region->getRegionNumber())) {
      assert(isRepetitiveRegion(region, options) &&
             "expected that all parallel regions are also repetitive regions");
      return region;
    }
    region = region->getParentRegion();
  }
  return nullptr;
}

Operation *bufferization::getOwnerOfValue(Value value) {
  if (auto opResult = llvm::dyn_cast<OpResult>(value))
    return opResult.getDefiningOp();
  return llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
}

/// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
/// shaped value is copied. Otherwise, a tensor with undefined contents is
/// allocated.
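///
/// For example, allocating (without copy) for a value of type
/// `tensor<?x10xf32>` produces IR roughly along these lines (names are
/// illustrative):
///   %d = tensor.dim %shaped, %c0 : tensor<?x10xf32>
///   %alloc = bufferization.alloc_tensor(%d) : tensor<?x10xf32>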
FailureOr<Value> bufferization::allocateTensorForShapedValue(
    OpBuilder &b, Location loc, Value shapedValue,
    const BufferizationOptions &options, const BufferizationState &state,
    bool copy) {
  Value tensor;
  if (llvm::isa<RankedTensorType>(shapedValue.getType())) {
    tensor = shapedValue;
  } else if (llvm::isa<MemRefType>(shapedValue.getType())) {
    tensor = b.create<ToTensorOp>(loc, shapedValue);
  } else if (llvm::isa<UnrankedTensorType>(shapedValue.getType()) ||
             llvm::isa<UnrankedMemRefType>(shapedValue.getType())) {
    return getOwnerOfValue(shapedValue)
        ->emitError("copying of unranked tensors is not implemented");
  } else {
    llvm_unreachable("expected RankedTensorType or MemRefType");
  }
  RankedTensorType tensorType = llvm::cast<RankedTensorType>(tensor.getType());
  SmallVector<Value> dynamicSizes;
  if (!copy) {
    // Compute the dynamic part of the shape.
    // First try to query the shape via ReifyRankedShapedTypeOpInterface.
    bool reifiedShapes = false;
    if (llvm::isa<RankedTensorType>(shapedValue.getType()) &&
        llvm::isa<OpResult>(shapedValue)) {
      ReifiedRankedShapedTypeDims resultDims;
      if (succeeded(
              reifyResultShapes(b, shapedValue.getDefiningOp(), resultDims))) {
        reifiedShapes = true;
        auto &shape =
            resultDims[llvm::cast<OpResult>(shapedValue).getResultNumber()];
        for (const auto &dim : enumerate(tensorType.getShape()))
          if (ShapedType::isDynamic(dim.value()))
            dynamicSizes.push_back(cast<Value>(shape[dim.index()]));
      }
    }

    // If the shape could not be reified, create DimOps.
    if (!reifiedShapes)
      populateDynamicDimSizes(b, loc, tensor, dynamicSizes);
  }

  // Create AllocTensorOp.
  auto allocTensorOp = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
                                               copy ? tensor : Value());

  // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
  if (copy)
    return allocTensorOp.getResult();
  FailureOr<BaseMemRefType> copyBufferType =
      getBufferType(tensor, options, state);
  if (failed(copyBufferType))
    return failure();
  std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
  if (!memorySpace)
    memorySpace = options.defaultMemorySpaceFn(tensorType);
  if (memorySpace.has_value())
    allocTensorOp.setMemorySpaceAttr(memorySpace.value());
  return allocTensorOp.getResult();
}

LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
    RewriterBase &rewriter, const AnalysisState &analysisState,
    const BufferizationState &bufferizationState) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = getOperation();
  SmallVector<OpOperand *> outOfPlaceOpOperands;
  DenseSet<OpOperand *> copiedOpOperands;
  SmallVector<Value> outOfPlaceValues;
  DenseSet<Value> copiedOpValues;

  // Find all out-of-place OpOperands.
  for (OpOperand &opOperand : op->getOpOperands()) {
    Type operandType = opOperand.get().getType();
    if (!llvm::isa<TensorType>(operandType))
      continue;
    if (analysisState.isInPlace(opOperand))
      continue;
    if (llvm::isa<UnrankedTensorType>(operandType))
      return op->emitError("copying of unranked tensors is not implemented");

    AliasingValueList aliasingValues =
        analysisState.getAliasingValues(opOperand);
    if (aliasingValues.getNumAliases() == 1 &&
        isa<OpResult>(aliasingValues.getAliases()[0].value) &&
        !analysisState.bufferizesToMemoryWrite(opOperand) &&
        analysisState
                .getAliasingOpOperands(aliasingValues.getAliases()[0].value)
                .getNumAliases() == 1 &&
        !isa<UnrankedTensorType>(
            aliasingValues.getAliases()[0].value.getType())) {
      // The op itself does not write but may create exactly one alias. Instead
      // of copying the OpOperand, copy the OpResult. The OpResult can
      // sometimes be smaller than the OpOperand (e.g., in the case of an
      // extract_slice, where the result is usually a smaller part of the
      // source). Do not apply this optimization if the OpResult is an unranked
      // tensor (because those cannot be copied at the moment).
      Value value = aliasingValues.getAliases()[0].value;
      outOfPlaceValues.push_back(value);
      if (!analysisState.canOmitTensorCopy(opOperand))
        copiedOpValues.insert(value);
    } else {
      // In all other cases, make a copy of the OpOperand.
      outOfPlaceOpOperands.push_back(&opOperand);
      if (!analysisState.canOmitTensorCopy(opOperand))
        copiedOpOperands.insert(&opOperand);
    }
  }

  // Insert copies of OpOperands.
  rewriter.setInsertionPoint(op);
  for (OpOperand *opOperand : outOfPlaceOpOperands) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opOperand->get(), analysisState.getOptions(),
        bufferizationState, copiedOpOperands.contains(opOperand));
    if (failed(copy))
      return failure();
    rewriter.modifyOpInPlace(op, [&]() { opOperand->set(*copy); });
  }

  // Insert copies of Values.
  rewriter.setInsertionPointAfter(op);
  for (Value value : outOfPlaceValues) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), value, analysisState.getOptions(),
        bufferizationState, copiedOpValues.count(value));
    if (failed(copy))
      return failure();
    SmallVector<OpOperand *> uses = llvm::to_vector(
        llvm::map_range(value.getUses(), [](OpOperand &use) { return &use; }));
    for (OpOperand *use : uses) {
      // Do not update the alloc_tensor op that we just created.
      if (use->getOwner() == copy->getDefiningOp())
        continue;
      // tensor.dim ops may have been created to be used as alloc_tensor op
      // dynamic extents. Do not update these either.
      if (isa<tensor::DimOp>(use->getOwner()))
        continue;
      rewriter.modifyOpInPlace(use->getOwner(), [&]() { use->set(*copy); });
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// OpFilter
//===----------------------------------------------------------------------===//

bool OpFilter::isOpAllowed(Operation *op) const {
  // Allow or disallow the op according to the filter entries.
  bool isAllowed = !hasAllowRule();
  for (const Entry &entry : entries) {
    bool filterResult = entry.fn(op);
    switch (entry.type) {
    case Entry::ALLOW:
      isAllowed |= filterResult;
      break;
    case Entry::DENY:
      if (filterResult)
        // DENY filter matches. This op is not allowed. (Even if other ALLOW
        // filters may match.)
        return false;
    };
  }
  return isAllowed;
}

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

namespace {

/// Default function arg type converter: Use a fully dynamic layout map.
BaseMemRefType
defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
                                func::FuncOp funcOp,
                                const BufferizationOptions &options) {
  return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
}
/// Default unknown type converter: Use a fully dynamic layout map.
BaseMemRefType
defaultUnknownTypeConverter(Value value, Attribute memorySpace,
                            const BufferizationOptions &options) {
  return getMemRefTypeWithFullyDynamicLayout(
      llvm::cast<TensorType>(value.getType()), memorySpace);
}

} // namespace

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions()
    : functionArgTypeConverterFn(defaultFunctionArgTypeConverter),
      unknownTypeConverterFn(defaultUnknownTypeConverter) {}

bool BufferizationOptions::isOpAllowed(Operation *op) const {
  // Special case: If function boundary bufferization is deactivated, do not
  // allow ops that belong to the `func` dialect.
  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
    return false;

  return opFilter.isOpAllowed(op);
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  if (!isOpAllowed(op))
    return nullptr;
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  return bufferizableOp;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  return dynCastBufferizableOp(getOwnerOfValue(value));
}

void BufferizationOptions::setFunctionBoundaryTypeConversion(
    LayoutMapOption layoutMapOption) {
  functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
                                   func::FuncOp funcOp,
                                   const BufferizationOptions &options) {
    if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
      return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
                                                                  memorySpace);
    return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
                                                              memorySpace);
  };
  inferFunctionResultLayout =
      layoutMapOption == LayoutMapOption::InferLayoutMap;
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = llvm::dyn_cast<BlockArgument>(value)) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperand* will alias with `value` if the op is bufferized
/// in place. Return all tensor OpOperand* if the op is not bufferizable.
AliasingOpOperandList AnalysisState::getAliasingOpOperands(Value value) const {
  if (Operation *op = getOwnerOfValue(value))
    if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
      return bufferizableOp.getAliasingOpOperands(value, *this);

  // The op is not bufferizable.
  return detail::unknownGetAliasingOpOperands(value);
}

/// Determine which Values will alias with `opOperand` if the op is bufferized
/// in place. Return all tensor Values if the op is not bufferizable.
AliasingValueList AnalysisState::getAliasingValues(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.getAliasingValues(opOperand, *this);

  // The op is not bufferizable.
  return detail::unknownGetAliasingValues(opOperand);
}

/// Return true if `opOperand` bufferizes to a memory read. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` does neither read nor write but bufferizes to an
/// alias. Return false if the op is not bufferizable.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return false.
  return false;
}

bool AnalysisState::bufferizesToMemoryWrite(Value value) const {
  auto opResult = llvm::dyn_cast<OpResult>(value);
  if (!opResult)
    return true;
  auto bufferizableOp = getOptions().dynCastBufferizableOp(value);
  if (!bufferizableOp)
    return true;
  return bufferizableOp.resultBufferizesToMemoryWrite(opResult, *this);
}

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
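///
/// For example, a tensor.extract_slice creates an alias of its source without
/// reading it; if that alias is later read (e.g., by a tensor.extract), the
/// original value is still considered "read".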
bool AnalysisState::isValueRead(Value value) const {
  assert(llvm::isa<TensorType>(value.getType()) && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  DenseSet<OpOperand *> visited;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    if (!visited.insert(uMaybeReading).second)
      continue;

    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (AliasingValue alias : getAliasingValues(*uMaybeReading))
        for (OpOperand &use : alias.value.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}

// Starting from `opOperand`, follow the use-def chain in reverse, always
// selecting the aliasing OpOperands. Find and return Values for which
// `condition` evaluates to true. Uses of such matching Values are not
// traversed any further. The visited aliasing OpOperands are collected in
// `visitedOpOperands`.
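//
// For example (illustrative IR), with a `condition` that matches values that
// bufferize to a memory write, starting from the operand of "reading_op" the
// traversal steps backwards through the alias-only extract_slice and returns
// {%0}:
//   %0 = "writing_op"(...) : ... -> tensor<10xf32>
//   %1 = tensor.extract_slice %0[0] [5] [1] : tensor<10xf32> to tensor<5xf32>
//   "reading_op"(%1) : ...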
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    OpOperand *opOperand, llvm::function_ref<bool(Value)> condition,
    TraversalConfig config,
    llvm::DenseSet<OpOperand *> *visitedOpOperands) const {
  llvm::DenseSet<Value> visited;
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(opOperand->get());

  if (visitedOpOperands)
    visitedOpOperands->insert(opOperand);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();

    if (!config.revisitAlreadyVisitedValues && visited.contains(value)) {
      // Stop traversal if value was already visited.
      if (config.alwaysIncludeLeaves)
        result.insert(value);
      continue;
    }
    visited.insert(value);

    if (condition(value)) {
      result.insert(value);
      continue;
    }

    if (!config.followUnknownOps && !options.dynCastBufferizableOp(value)) {
      // Stop iterating if `followUnknownOps` is unset and the op is either
      // not bufferizable or excluded in the OpFilter.
      if (config.alwaysIncludeLeaves)
        result.insert(value);
      continue;
    }

    AliasingOpOperandList aliases = getAliasingOpOperands(value);
    if (aliases.getNumAliases() == 0) {
      // The traversal ends naturally if there are no more OpOperands that
      // could be followed.
      if (config.alwaysIncludeLeaves)
        result.insert(value);
      continue;
    }

    for (AliasingOpOperand a : aliases) {
      if (config.followEquivalentOnly &&
          a.relation != BufferRelation::Equivalent) {
        // Stop iterating if `followEquivalentOnly` is set but the alias is not
        // equivalent.
        if (config.alwaysIncludeLeaves)
          result.insert(value);
        continue;
      }

      if (config.followInPlaceOnly && !isInPlace(*a.opOperand)) {
        // Stop iterating if `followInPlaceOnly` is set but the alias is
        // out-of-place.
        if (config.alwaysIncludeLeaves)
          result.insert(value);
        continue;
      }

      if (config.followSameTypeOrCastsOnly &&
          a.opOperand->get().getType() != value.getType() &&
          !value.getDefiningOp<CastOpInterface>()) {
        // Stop iterating if `followSameTypeOrCastsOnly` is set but the alias
        // has a different type and the op is not a cast.
        if (config.alwaysIncludeLeaves)
          result.insert(value);
        continue;
      }

      workingSet.insert(a.opOperand->get());
      if (visitedOpOperands)
        visitedOpOperands->insert(a.opOperand);
    }
  }

  return result;
}

// Find the values that define the contents of the given operand's value.
llvm::SetVector<Value>
AnalysisState::findDefinitions(OpOperand *opOperand) const {
  TraversalConfig config;
  config.alwaysIncludeLeaves = false;
  return findValueInReverseUseDefChain(
      opOperand, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
      config);
}

AnalysisState::AnalysisState(const BufferizationOptions &options)
    : AnalysisState(options, TypeID::get<AnalysisState>()) {}

AnalysisState::AnalysisState(const BufferizationOptions &options, TypeID type)
    : options(options), type(type) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
  // Do not copy if the tensor has undefined contents.
  if (hasUndefinedContents(&opOperand))
    return true;

  // Do not copy if the buffer of the tensor is entirely overwritten (with
  // values that do not depend on the old tensor).
  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
    return true;

  // Do not copy if the tensor is never read.
  AliasingValueList aliases = getAliasingValues(opOperand);
  if (!bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliases,
                    [&](AliasingValue a) { return isValueRead(a.value); }))
    return true;

  // Default: Cannot omit the copy.
  return false;
}

bool AnalysisState::isInPlace(OpOperand &opOperand) const {
  // ToBufferOps are always in-place.
  if (isa<ToBufferOp>(opOperand.getOwner()))
    return true;

  // In the absence of analysis information, OpOperands that bufferize to a
  // memory write are out-of-place, i.e., an alloc and copy is inserted.
  return !bufferizesToMemoryWrite(opOperand);
}

bool AnalysisState::areEquivalentBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values are
  // equivalent. The conservative answer is "false".
  return false;
}

bool AnalysisState::areAliasingBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values may
  // be aliasing. The conservative answer is "true".
  return true;
}

bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  // In the absence of analysis information, the conservative answer is
  // "false".
  return false;
}

// bufferization.to_buffer is not allowed to change the rank.
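// E.g., a `tensor<4x?xf32>` may only be converted to a rank-2 memref such as
// `memref<4x?xf32>`, never to a memref of a different rank.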
static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(tensor.getType());
  assert((!rankedTensorType || llvm::cast<MemRefType>(memrefType).getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_buffer would be invalid: mismatching ranks");
#endif
}

FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
                                          const BufferizationOptions &options,
                                          const BufferizationState &state) {
#ifndef NDEBUG
  auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
  assert(tensorType && "unexpected non-tensor type");
#endif // NDEBUG

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref();

  // Insert to_buffer op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, value);
  FailureOr<BaseMemRefType> memrefType = getBufferType(value, options, state);
  if (failed(memrefType))
    return failure();
  ensureToBufferOpIsValid(value, *memrefType);
  return rewriter
      .create<bufferization::ToBufferOp>(value.getLoc(), *memrefType, value)
      .getResult();
}

/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
                             const BufferizationState &state) {
  SmallVector<Value> invocationStack;
  return getBufferType(value, options, state, invocationStack);
}

/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
                             const BufferizationState &state,
                             SmallVector<Value> &invocationStack) {
  assert(llvm::isa<TensorType>(value.getType()) &&
         "unexpected non-tensor type");
  invocationStack.push_back(value);
  auto popFromStack =
      llvm::make_scope_exit([&]() { invocationStack.pop_back(); });

  // Try querying BufferizableOpInterface.
  Operation *op = getOwnerOfValue(value);
  auto bufferizableOp = options.dynCastBufferizableOp(op);
  if (bufferizableOp)
    return bufferizableOp.getBufferType(value, options, state, invocationStack);

  // Op is not bufferizable.
  auto memSpace =
      options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
  if (!memSpace.has_value())
    return op->emitError("could not infer memory space");

  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
}

bool bufferization::hasTensorSemantics(Operation *op) {
  if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
    return bufferizableOp.hasTensorSemantics();
  return detail::defaultHasTensorSemantics(op);
}

void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (llvm::isa<TensorType>(opResult.getType())) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((llvm::isa<MemRefType>(replacement.getType()) ||
              llvm::isa<UnrankedMemRefType>(replacement.getType())) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), opResult.getType(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc insertion support.
//===----------------------------------------------------------------------===//

/// Create a memref allocation with the given type and dynamic extents.
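/// Unless a custom `allocationFn` is set, this emits a `memref.alloc`; for a
/// dynamically shaped type the result looks roughly like (illustrative):
///   %buf = memref.alloc(%d) {alignment = 64 : i64} : memref<?xf32>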
FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
                                                   MemRefType type,
                                                   ValueRange dynShape) const {
  if (allocationFn)
    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);

  // Default buffer allocation via AllocOp.
  if (bufferAlignment != 0)
    return b
        .create<memref::AllocOp>(loc, type, dynShape,
                                 b.getI64IntegerAttr(bufferAlignment))
        .getResult();
  return b.create<memref::AllocOp>(loc, type, dynShape).getResult();
}

/// Create a memory copy between two memref buffers.
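/// Unless a custom `memCpyFn` is set, this lowers to a plain copy op, e.g.:
///   memref.copy %from, %to : memref<?xf32> to memref<?xf32>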
LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
                                                 Value from, Value to) const {
  if (memCpyFn)
    return (*memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific IRMapping support with debugging.
//===----------------------------------------------------------------------===//

BaseMemRefType bufferization::getMemRefType(Value value,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            Attribute memorySpace) {
  auto tensorType = llvm::cast<TensorType>(value.getType());

  // Case 1: Unranked memref type.
  if (auto unrankedTensorType =
          llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type with specified layout.
  auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);
  if (layout) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpace);
  }

  return options.unknownTypeConverterFn(value, memorySpace, options);
}

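/// Return a MemRef type with a fully dynamic layout. For example, a
/// `tensor<4x?xf32>` becomes `memref<4x?xf32, strided<[?, ?], offset: ?>>`.
/// If the given tensor type is unranked, return an unranked MemRef type.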
BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                                   Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType =
          llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);
  int64_t dynamicOffset = ShapedType::kDynamic;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamic);
  auto stridedLayout = StridedLayoutAttr::get(tensorType.getContext(),
                                              dynamicOffset, dynamicStrides);
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpace);
}

/// Return a MemRef type with a static identity layout (i.e., no layout map).
/// If the given tensor type is unranked, return an unranked MemRef type.
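/// For example, a `tensor<4x?xf32>` becomes `memref<4x?xf32>`.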
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                     Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType =
          llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layout,
                         memorySpace);
}

//===----------------------------------------------------------------------===//
// Default implementations of interface methods
//===----------------------------------------------------------------------===//

bool bufferization::detail::defaultResultBufferizesToMemoryWrite(
    OpResult opResult, const AnalysisState &state) {
  auto bufferizableOp = cast<BufferizableOpInterface>(opResult.getDefiningOp());
  AliasingOpOperandList opOperands =
      bufferizableOp.getAliasingOpOperands(opResult, state);

  // Case 1: OpResults that have no aliasing OpOperand usually bufferize to
  // memory writes.
  if (opOperands.getAliases().empty())
    return true;

  // Case 2: If an aliasing OpOperand bufferizes to a memory write, the
  // OpResult may bufferize to a memory write.
  if (llvm::any_of(opOperands, [&](AliasingOpOperand alias) {
        return state.bufferizesToMemoryWrite(*alias.opOperand);
      }))
    return true;

  // Case 3: Check if a nested aliasing OpOperand value bufferizes to a memory
  // write. (Or: The reverse SSA use-def chain ends inside the region.) In that
  // case, the OpResult bufferizes to a memory write. E.g.:
  //
  // %0 = "some_writing_op" : tensor<?xf32>
  // %r = scf.if ... -> tensor<?xf32> {
  //   scf.yield %0 : tensor<?xf32>
  // } else {
  //   %1 = "another_writing_op"(%0) : tensor<?xf32>
  //   scf.yield %1 : tensor<?xf32>
  // }
  // "some_reading_op"(%r)
  //
  // %r bufferizes to a memory write because an aliasing OpOperand value (%1)
  // bufferizes to a memory write and the defining op is inside the scf.if.
  //
  // Note: This treatment of surrounding ops is useful for ops that have a
  // region but no OpOperand such as scf.if or scf.execute_region. It
  // simplifies the analysis considerably.
  //
  // "another_writing_op" in the above example should be able to bufferize
  // inplace in the absence of another read of %0. However, if the scf.if op
  // would not be considered a "write", the analysis would detect the
  // following conflict:
  //
  // * read = some_reading_op
  // * lastWrite = %0 (Note: The last write of %r would be a set: {%0, %1}.)
  // * conflictingWrite = %1
  //
  auto isMemoryWriteInsideOp = [&](Value v) {
    Operation *op = getOwnerOfValue(v);
    if (!opResult.getDefiningOp()->isAncestor(op))
      return false;
    return state.bufferizesToMemoryWrite(v);
  };
  TraversalConfig config;
  config.alwaysIncludeLeaves = false;
  for (AliasingOpOperand alias : opOperands) {
    if (!state
             .findValueInReverseUseDefChain(alias.opOperand,
                                            isMemoryWriteInsideOp, config)
             .empty())
      return true;
  }
  return false;
}

// Compute the AliasingOpOperandList for a given Value based on
// getAliasingValues.
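// In other words, this is the inverse query: every tensor OpOperand of the
// defining op whose aliasing value set contains `value` is reported as an
// aliasing OpOperand of `value`.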
AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
    Value value, const AnalysisState &state) {
  Operation *op = getOwnerOfValue(value);
  SmallVector<AliasingOpOperand> result;
  for (OpOperand &opOperand : op->getOpOperands()) {
    if (!llvm::isa<TensorType>(opOperand.get().getType()))
      continue;
    AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
    for (const auto &it : aliasingValues)
      if (it.value == value)
        result.emplace_back(&opOperand, it.relation, it.isDefinite);
  }
  return AliasingOpOperandList(std::move(result));
}

FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
    Value value, const BufferizationOptions &options,
    const BufferizationState &bufferizationState,
    SmallVector<Value> &invocationStack) {
  assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");

  // No further analysis is possible for a block argument.
  if (llvm::isa<BlockArgument>(value))
    return bufferization::getMemRefType(value, options);

  // Value is an OpResult.
  Operation *op = getOwnerOfValue(value);
  auto opResult = llvm::cast<OpResult>(value);
  AnalysisState analysisState(options);
  AliasingOpOperandList aliases = analysisState.getAliasingOpOperands(opResult);
  if (aliases.getNumAliases() > 0 &&
      aliases.getAliases()[0].relation == BufferRelation::Equivalent) {
    // If the OpResult has an equivalent OpOperand, both OpResult and
    // OpOperand bufferize to the exact same buffer type.
    Value equivalentOperand = aliases.getAliases().front().opOperand->get();
    return getBufferType(equivalentOperand, options, bufferizationState,
                         invocationStack);
  }

  // If we do not know the memory space and there is no default memory space,
  // report a failure.
  auto memSpace =
      options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
  if (!memSpace.has_value())
    return op->emitError("could not infer memory space");

  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
}

bool bufferization::detail::defaultIsRepetitiveRegion(
    BufferizableOpInterface bufferizableOp, unsigned index) {
  assert(index < bufferizableOp->getNumRegions() && "invalid region index");
  auto regionInterface =
      dyn_cast<RegionBranchOpInterface>(bufferizableOp.getOperation());
  if (!regionInterface)
    return false;
  return regionInterface.isRepetitiveRegion(index);
}

AliasingOpOperandList
bufferization::detail::unknownGetAliasingOpOperands(Value value) {
  // TODO: Take into account successor blocks.
  // No aliasing in case of non-entry blocks.
  if (auto bbArg = dyn_cast<BlockArgument>(value))
    if (bbArg.getOwner() !=
        &bbArg.getOwner()->getParent()->getBlocks().front())
      return {};

  // Unknown op: Conservatively assume that each OpResult may alias with every
  // OpOperand. In addition, each block argument of an entry block may alias
  // with every OpOperand.
  AliasingOpOperandList r;
  for (OpOperand &operand : value.getDefiningOp()->getOpOperands())
    if (isa<TensorType>(operand.get().getType()))
      r.addAlias({&operand, BufferRelation::Unknown, /*isDefinite=*/false});
  return r;
}

AliasingValueList
bufferization::detail::unknownGetAliasingValues(OpOperand &opOperand) {
  // TODO: Take into account successor blocks.
  // Unknown op: Conservatively assume that each OpResult may alias with every
  // OpOperand. In addition, each block argument of an entry block may alias
  // with every OpOperand.
  AliasingValueList r;
  for (OpResult result : opOperand.getOwner()->getOpResults())
    if (llvm::isa<TensorType>(result.getType()))
      r.addAlias({result, BufferRelation::Unknown, /*isDefinite=*/false});
  for (Region &region : opOperand.getOwner()->getRegions())
    if (!region.getBlocks().empty())
      for (BlockArgument bbArg : region.getBlocks().front().getArguments())
        if (isa<TensorType>(bbArg.getType()))
          r.addAlias({bbArg, BufferRelation::Unknown, /*isDefinite=*/false});
  return r;
}

bool bufferization::detail::defaultHasTensorSemantics(Operation *op) {
  auto isaTensor = [](Type t) { return isa<TensorType>(t); };
  bool hasTensorBlockArgument = any_of(op->getRegions(), [&](Region &r) {
    return any_of(r.getBlocks(), [&](Block &b) {
      return any_of(b.getArguments(), [&](BlockArgument bbArg) {
        return isaTensor(bbArg.getType());
      });
    });
  });
  if (hasTensorBlockArgument)
    return true;

  if (any_of(op->getResultTypes(), isaTensor))
    return true;
  return any_of(op->getOperandTypes(), isaTensor);
}
