//===- BufferizableOpInterface.cpp - Bufferizable Ops ---=----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"

//===----------------------------------------------------------------------===//
// BufferizableOpInterface
//===----------------------------------------------------------------------===//

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::AnalysisState)

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

static bool isRepetitiveRegion(Region *region,
                               const BufferizationOptions &options) {
  Operation *op = region->getParentOp();
  if (auto bufferizableOp = options.dynCastBufferizableOp(op))
    if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
      return true;
  return false;
}

Region *AnalysisState::getEnclosingRepetitiveRegion(
    Operation *op, const BufferizationOptions &options) {
  if (!op->getBlock())
    return nullptr;
  if (auto iter = enclosingRepetitiveRegionCache.find_as(op);
      iter != enclosingRepetitiveRegionCache.end())
    return iter->second;
  return enclosingRepetitiveRegionCache[op] =
             getEnclosingRepetitiveRegion(op->getBlock(), options);
}

Region *AnalysisState::getEnclosingRepetitiveRegion(
    Value value, const BufferizationOptions &options) {
  if (auto iter = enclosingRepetitiveRegionCache.find_as(value);
      iter != enclosingRepetitiveRegionCache.end())
    return iter->second;

  Region *region = value.getParentRegion();
  // Collect all visited regions, since we only know which repetitive region to
  // map them to once the walk has finished.
  SmallVector<Region *> visitedRegions;
  while (region) {
    visitedRegions.push_back(region);
    if (isRepetitiveRegion(region, options))
      break;
    region = region->getParentRegion();
  }
  enclosingRepetitiveRegionCache[value] = region;
  for (Region *r : visitedRegions)
    enclosingRepetitiveRegionCache[r] = region;
  return region;
}

Region *AnalysisState::getEnclosingRepetitiveRegion(
    Block *block, const BufferizationOptions &options) {
  if (auto iter = enclosingRepetitiveRegionCache.find_as(block);
      iter != enclosingRepetitiveRegionCache.end())
    return iter->second;

  Region *region = block->getParent();
  Operation *op = nullptr;
  // Collect all visited regions, since we only know which repetitive region to
  // map them to once the walk has finished.
  SmallVector<Region *> visitedRegions;
  do {
    op = region->getParentOp();
    if (isRepetitiveRegion(region, options))
      break;
  } while ((region = op->getParentRegion()));

  enclosingRepetitiveRegionCache[block] = region;
  for (Region *r : visitedRegions)
    enclosingRepetitiveRegionCache[r] = region;
  return region;
}

void AnalysisState::resetCache() { enclosingRepetitiveRegionCache.clear(); }

Region *bufferization::getNextEnclosingRepetitiveRegion(
    Region *region, const BufferizationOptions &options) {
  assert(isRepetitiveRegion(region, options) && "expected repetitive region");
  while ((region = region->getParentRegion())) {
    if (isRepetitiveRegion(region, options))
      break;
  }
  return region;
}

Region *bufferization::getParallelRegion(Region *region,
                                         const BufferizationOptions &options) {
  while (region) {
    auto bufferizableOp = options.dynCastBufferizableOp(region->getParentOp());
    if (bufferizableOp &&
        bufferizableOp.isParallelRegion(region->getRegionNumber())) {
      assert(isRepetitiveRegion(region, options) &&
             "expected that all parallel regions are also repetitive regions");
      return region;
    }
    region = region->getParentRegion();
  }
  return nullptr;
}

Operation *bufferization::getOwnerOfValue(Value value) {
  if (auto opResult = llvm::dyn_cast<OpResult>(value))
    return opResult.getDefiningOp();
  return llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
}

/// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
/// shaped value is copied. Otherwise, a tensor with undefined contents is
/// allocated.
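///
/// Rough illustration (hypothetical IR, assuming a value %t of type
/// tensor<?xf32>): with `copy = false` this creates
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
///   %0 = bufferization.alloc_tensor(%d) : tensor<?xf32>
/// and with `copy = true` it creates
///   %0 = bufferization.alloc_tensor() copy(%t) : tensor<?xf32>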
FailureOr<Value> bufferization::allocateTensorForShapedValue(
    OpBuilder &b, Location loc, Value shapedValue,
    const BufferizationOptions &options, bool copy) {
  Value tensor;
  if (llvm::isa<RankedTensorType>(shapedValue.getType())) {
    tensor = shapedValue;
  } else if (llvm::isa<MemRefType>(shapedValue.getType())) {
    tensor = b.create<ToTensorOp>(loc, shapedValue);
  } else if (llvm::isa<UnrankedTensorType>(shapedValue.getType()) ||
             llvm::isa<UnrankedMemRefType>(shapedValue.getType())) {
    return getOwnerOfValue(shapedValue)
        ->emitError("copying of unranked tensors is not implemented");
  } else {
    llvm_unreachable("expected RankedTensorType or MemRefType");
  }
  RankedTensorType tensorType = llvm::cast<RankedTensorType>(tensor.getType());
  SmallVector<Value> dynamicSizes;
  if (!copy) {
    // Compute the dynamic part of the shape.
    // First try to query the shape via ReifyRankedShapedTypeOpInterface.
    bool reifiedShapes = false;
    if (llvm::isa<RankedTensorType>(shapedValue.getType()) &&
        llvm::isa<OpResult>(shapedValue)) {
      ReifiedRankedShapedTypeDims resultDims;
      if (succeeded(
              reifyResultShapes(b, shapedValue.getDefiningOp(), resultDims))) {
        reifiedShapes = true;
        auto &shape =
            resultDims[llvm::cast<OpResult>(shapedValue).getResultNumber()];
        for (const auto &dim : enumerate(tensorType.getShape()))
          if (ShapedType::isDynamic(dim.value()))
            dynamicSizes.push_back(shape[dim.index()].get<Value>());
      }
    }

    // If the shape could not be reified, create DimOps.
    if (!reifiedShapes)
      populateDynamicDimSizes(b, loc, tensor, dynamicSizes);
  }

  // Create AllocTensorOp.
  auto allocTensorOp = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
                                               copy ? tensor : Value());

  // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
  if (copy)
    return allocTensorOp.getResult();
  FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
  if (failed(copyBufferType))
    return failure();
  std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
  if (!memorySpace)
    memorySpace = options.defaultMemorySpaceFn(tensorType);
  if (memorySpace.has_value())
    allocTensorOp.setMemorySpaceAttr(memorySpace.value());
  return allocTensorOp.getResult();
}

LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
    RewriterBase &rewriter, const AnalysisState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = getOperation();
  SmallVector<OpOperand *> outOfPlaceOpOperands;
  DenseSet<OpOperand *> copiedOpOperands;
  SmallVector<Value> outOfPlaceValues;
  DenseSet<Value> copiedOpValues;

  // Find all out-of-place OpOperands.
  for (OpOperand &opOperand : op->getOpOperands()) {
    Type operandType = opOperand.get().getType();
    if (!llvm::isa<TensorType>(operandType))
      continue;
    if (state.isInPlace(opOperand))
      continue;
    if (llvm::isa<UnrankedTensorType>(operandType))
      return op->emitError("copying of unranked tensors is not implemented");

    AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
    if (aliasingValues.getNumAliases() == 1 &&
        isa<OpResult>(aliasingValues.getAliases()[0].value) &&
        !state.bufferizesToMemoryWrite(opOperand) &&
        state.getAliasingOpOperands(aliasingValues.getAliases()[0].value)
                .getNumAliases() == 1 &&
        !isa<UnrankedTensorType>(
            aliasingValues.getAliases()[0].value.getType())) {
      // The op itself does not write but may create exactly one alias. Instead
      // of copying the OpOperand, copy the OpResult. The OpResult can
      // sometimes be smaller than the OpOperand (e.g., in the case of an
      // extract_slice, where the result is usually a smaller part of the
      // source). Do not apply this optimization if the OpResult is an unranked
      // tensor (because those cannot be copied at the moment).
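      //
      // Illustrative (hypothetical) example: for
      //   %0 = tensor.extract_slice %t[0][4][1] : tensor<16xf32> to tensor<4xf32>
      // it is cheaper to copy the 4-element result %0 than the 16-element
      // source %t.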
      Value value = aliasingValues.getAliases()[0].value;
      outOfPlaceValues.push_back(value);
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpValues.insert(value);
    } else {
      // In all other cases, make a copy of the OpOperand.
      outOfPlaceOpOperands.push_back(&opOperand);
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpOperands.insert(&opOperand);
    }
  }

  // Insert copies of OpOperands.
  rewriter.setInsertionPoint(op);
  for (OpOperand *opOperand : outOfPlaceOpOperands) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opOperand->get(), state.getOptions(),
        copiedOpOperands.contains(opOperand));
    if (failed(copy))
      return failure();
    rewriter.modifyOpInPlace(op, [&]() { opOperand->set(*copy); });
  }

  // Insert copies of Values.
  rewriter.setInsertionPointAfter(op);
  for (Value value : outOfPlaceValues) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), value, state.getOptions(),
        copiedOpValues.count(value));
    if (failed(copy))
      return failure();
    SmallVector<OpOperand *> uses = llvm::to_vector(
        llvm::map_range(value.getUses(), [](OpOperand &use) { return &use; }));
    for (OpOperand *use : uses) {
      // Do not update the alloc_tensor op that we just created.
      if (use->getOwner() == copy->getDefiningOp())
        continue;
      // tensor.dim ops may have been created to be used as alloc_tensor op
      // dynamic extents. Do not update these either.
      if (isa<tensor::DimOp>(use->getOwner()))
        continue;
      rewriter.modifyOpInPlace(use->getOwner(), [&]() { use->set(*copy); });
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// OpFilter
//===----------------------------------------------------------------------===//

bool OpFilter::isOpAllowed(Operation *op) const {
  // All other ops: Allow/disallow according to filter.
  bool isAllowed = !hasAllowRule();
  for (const Entry &entry : entries) {
    bool filterResult = entry.fn(op);
    switch (entry.type) {
    case Entry::ALLOW:
      isAllowed |= filterResult;
      break;
    case Entry::DENY:
      if (filterResult)
        // DENY filter matches. This op is not allowed. (Even if other ALLOW
        // filters may match.)
        return false;
    };
  }
  return isAllowed;
}

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

namespace {

/// Default function arg type converter: Use a fully dynamic layout map.
BaseMemRefType
defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
                                func::FuncOp funcOp,
                                const BufferizationOptions &options) {
  return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
}
/// Default unknown type converter: Use a fully dynamic layout map.
BaseMemRefType
defaultUnknownTypeConverter(Value value, Attribute memorySpace,
                            const BufferizationOptions &options) {
  return getMemRefTypeWithFullyDynamicLayout(
      llvm::cast<TensorType>(value.getType()), memorySpace);
}

} // namespace

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions()
    : functionArgTypeConverterFn(defaultFunctionArgTypeConverter),
      unknownTypeConverterFn(defaultUnknownTypeConverter) {}

bool BufferizationOptions::isOpAllowed(Operation *op) const {
  // Special case: If function boundary bufferization is deactivated, do not
  // allow ops that belong to the `func` dialect.
  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
    return false;

  return opFilter.isOpAllowed(op);
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  if (!isOpAllowed(op))
    return nullptr;
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  return bufferizableOp;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  return dynCastBufferizableOp(getOwnerOfValue(value));
}

void BufferizationOptions::setFunctionBoundaryTypeConversion(
    LayoutMapOption layoutMapOption) {
  functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
                                   func::FuncOp funcOp,
                                   const BufferizationOptions &options) {
    if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
      return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
                                                                  memorySpace);
    return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
                                                               memorySpace);
  };
  inferFunctionResultLayout =
      layoutMapOption == LayoutMapOption::InferLayoutMap;
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = llvm::dyn_cast<BlockArgument>(value)) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperand* will alias with `value` if the op is bufferized
/// in place. Return all tensor OpOperand* if the op is not bufferizable.
AliasingOpOperandList AnalysisState::getAliasingOpOperands(Value value) const {
  if (Operation *op = getOwnerOfValue(value))
    if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
      return bufferizableOp.getAliasingOpOperands(value, *this);

  // The op is not bufferizable.
  return detail::unknownGetAliasingOpOperands(value);
}

/// Determine which Values will alias with `opOperand` if the op is bufferized
/// in place. Return all tensor Values if the op is not bufferizable.
AliasingValueList AnalysisState::getAliasingValues(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.getAliasingValues(opOperand, *this);

  // The op is not bufferizable.
  return detail::unknownGetAliasingValues(opOperand);
}

/// Return true if `opOperand` bufferizes to a memory read. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Return
/// `true` if the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` neither reads nor writes but bufferizes to an
/// alias. Return false if the op is not bufferizable.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return false.
  return false;
}

bool AnalysisState::bufferizesToMemoryWrite(Value value) const {
  auto opResult = llvm::dyn_cast<OpResult>(value);
  if (!opResult)
    return true;
  auto bufferizableOp = getOptions().dynCastBufferizableOp(value);
  if (!bufferizableOp)
    return true;
  return bufferizableOp.resultBufferizesToMemoryWrite(opResult, *this);
}

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
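///
/// Hypothetical example: in
///   %0 = tensor.extract_slice %t[0][4][1] : tensor<16xf32> to tensor<4xf32>
///   %1 = tensor.extract %0[%c0] : tensor<4xf32>
/// the slice %0 only creates an alias, but its user reads it, so both %t and
/// %0 count as "read".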
bool AnalysisState::isValueRead(Value value) const {
  assert(llvm::isa<TensorType>(value.getType()) && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  DenseSet<OpOperand *> visited;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    if (visited.contains(uMaybeReading))
      continue;
    visited.insert(uMaybeReading);

    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (AliasingValue alias : getAliasingValues(*uMaybeReading))
        for (OpOperand &use : alias.value.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}

// Starting from `value`, follow the use-def chain in reverse, always selecting
// the aliasing OpOperands. Find and return Values for which `condition`
// evaluates to true. OpOperands of such matching Values are not traversed any
// further.
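//
// Hypothetical example: with `condition` = "bufferizes to a memory write" and
//   %0 = tensor.empty() : tensor<8xf32>
//   %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<8xf32>) -> tensor<8xf32>
//   %2 = tensor.extract_slice %1[0][4][1] : tensor<8xf32> to tensor<4xf32>
// starting the traversal at %2 skips over the alias-only extract_slice and
// returns {%1}, the value that last wrote the aliased data.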
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition,
    TraversalConfig config) const {
  llvm::DenseSet<Value> visited;
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();

    if (!config.revisitAlreadyVisitedValues && visited.contains(value)) {
      // Stop traversal if value was already visited.
      if (config.alwaysIncludeLeaves)
        result.insert(value);
      continue;
    }
    visited.insert(value);

    if (condition(value)) {
      result.insert(value);
      continue;
    }

    if (!config.followUnknownOps && !options.dynCastBufferizableOp(value)) {
      // Stop iterating if `followUnknownOps` is unset and the op is either
      // not bufferizable or excluded in the OpFilter.
      if (config.alwaysIncludeLeaves)
        result.insert(value);
      continue;
    }

    AliasingOpOperandList aliases = getAliasingOpOperands(value);
    if (aliases.getNumAliases() == 0) {
      // The traversal ends naturally if there are no more OpOperands that
      // could be followed.
      if (config.alwaysIncludeLeaves)
        result.insert(value);
      continue;
    }

    for (AliasingOpOperand a : aliases) {
      if (config.followEquivalentOnly &&
          a.relation != BufferRelation::Equivalent) {
        // Stop iterating if `followEquivalentOnly` is set but the alias is not
        // equivalent.
        if (config.alwaysIncludeLeaves)
          result.insert(value);
        continue;
      }

      if (config.followInPlaceOnly && !isInPlace(*a.opOperand)) {
        // Stop iterating if `followInPlaceOnly` is set but the alias is
        // out-of-place.
        if (config.alwaysIncludeLeaves)
          result.insert(value);
        continue;
      }

      if (config.followSameTypeOrCastsOnly &&
          a.opOperand->get().getType() != value.getType() &&
          !value.getDefiningOp<CastOpInterface>()) {
        // Stop iterating if `followSameTypeOrCastsOnly` is set but the alias
        // has a different type and the op is not a cast.
        if (config.alwaysIncludeLeaves)
          result.insert(value);
        continue;
      }

      workingSet.insert(a.opOperand->get());
    }
  }

  return result;
}

// Find the values that define the contents of the given value.
llvm::SetVector<Value> AnalysisState::findDefinitions(Value value) const {
  TraversalConfig config;
  config.alwaysIncludeLeaves = false;
  return findValueInReverseUseDefChain(
      value, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, config);
}

AnalysisState::AnalysisState(const BufferizationOptions &options)
    : AnalysisState(options, TypeID::get<AnalysisState>()) {}

AnalysisState::AnalysisState(const BufferizationOptions &options, TypeID type)
    : options(options), type(type) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
  // Do not copy if the tensor has undefined contents.
  if (hasUndefinedContents(&opOperand))
    return true;

  // Do not copy if the buffer of the tensor is entirely overwritten (with
  // values that do not depend on the old tensor).
  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
    return true;

  // Do not copy if the tensor is never read.
  AliasingValueList aliases = getAliasingValues(opOperand);
  if (!bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliases,
                    [&](AliasingValue a) { return isValueRead(a.value); }))
    return true;

  // Default: Cannot omit the copy.
  return false;
}

bool AnalysisState::isInPlace(OpOperand &opOperand) const {
  // ToMemrefOps are always in-place.
  if (isa<ToMemrefOp>(opOperand.getOwner()))
    return true;

  // In the absence of analysis information, OpOperands that bufferize to a
  // memory write are out-of-place, i.e., an alloc and copy is inserted.
  return !bufferizesToMemoryWrite(opOperand);
}

bool AnalysisState::areEquivalentBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values are
  // equivalent. The conservative answer is "false".
  return false;
}

bool AnalysisState::areAliasingBufferizedValues(Value v1, Value v2) const {
  // In the absence of analysis information, we do not know if the values may
  // be aliasing. The conservative answer is "true".
  return true;
}

bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  // In the absence of analysis information, the conservative answer is
  // "false".
  return false;
}

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(tensor.getType());
  assert((!rankedTensorType || llvm::cast<MemRefType>(memrefType).getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
                                          const BufferizationOptions &options) {
#ifndef NDEBUG
  auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
  assert(tensorType && "unexpected non-tensor type");
#endif // NDEBUG

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, value);
  FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
  if (failed(memrefType))
    return failure();
  ensureToMemrefOpIsValid(value, *memrefType);
  return rewriter
      .create<bufferization::ToMemrefOp>(value.getLoc(), *memrefType, value)
      .getResult();
}

/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options) {
  SmallVector<Value> invocationStack;
  return getBufferType(value, options, invocationStack);
}

/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
                             SmallVector<Value> &invocationStack) {
  assert(llvm::isa<TensorType>(value.getType()) &&
         "unexpected non-tensor type");
  invocationStack.push_back(value);
  auto popFromStack =
      llvm::make_scope_exit([&]() { invocationStack.pop_back(); });

  // Try querying BufferizableOpInterface.
  Operation *op = getOwnerOfValue(value);
  auto bufferizableOp = options.dynCastBufferizableOp(op);
  if (bufferizableOp)
    return bufferizableOp.getBufferType(value, options, invocationStack);

  // Op is not bufferizable.
  auto memSpace =
      options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
  if (!memSpace.has_value())
    return op->emitError("could not infer memory space");

  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
}

bool bufferization::hasTensorSemantics(Operation *op) {
  if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
    return bufferizableOp.hasTensorSemantics();
  return detail::defaultHasTensorSemantics(op);
}

void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (llvm::isa<TensorType>(opResult.getType())) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((llvm::isa<MemRefType>(replacement.getType()) ||
              llvm::isa<UnrankedMemRefType>(replacement.getType())) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc insertion support.
//===----------------------------------------------------------------------===//

/// Create a memref allocation with the given type and dynamic extents.
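/// Unless a custom `allocationFn` is installed, the default path below emits a
/// plain memref.alloc. Rough illustration (hypothetical values; the alignment
/// comes from `bufferAlignment`):
///   %buf = memref.alloc(%d) {alignment = 64 : i64} : memref<?xf32>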
FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
                                                   MemRefType type,
                                                   ValueRange dynShape) const {
  if (allocationFn)
    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);

  // Default buffer allocation via AllocOp.
  if (bufferAlignment != 0)
    return b
        .create<memref::AllocOp>(loc, type, dynShape,
                                 b.getI64IntegerAttr(bufferAlignment))
        .getResult();
  return b.create<memref::AllocOp>(loc, type, dynShape).getResult();
}

/// Create a memory copy between two memref buffers.
LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
                                                 Value from, Value to) const {
  if (memCpyFn)
    return (*memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific IRMapping support with debugging.
//===----------------------------------------------------------------------===//

BaseMemRefType bufferization::getMemRefType(Value value,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            Attribute memorySpace) {
  auto tensorType = llvm::cast<TensorType>(value.getType());

  // Case 1: Unranked memref type.
  if (auto unrankedTensorType =
          llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type with specified layout.
  auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);
  if (layout) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpace);
  }

  return options.unknownTypeConverterFn(value, memorySpace, options);
}

BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                                   Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType =
          llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);
  int64_t dynamicOffset = ShapedType::kDynamic;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamic);
  auto stridedLayout = StridedLayoutAttr::get(tensorType.getContext(),
                                              dynamicOffset, dynamicStrides);
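  // Illustration (hypothetical): a tensor<4x?xf32> in the default memory space
  // becomes
  //   memref<4x?xf32, strided<[?, ?], offset: ?>>
  // i.e., both the strides and the offset remain dynamic.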
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpace);
}

/// Return a MemRef type with a static identity layout (i.e., no layout map).
/// If the given tensor type is unranked, return an unranked MemRef type.
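///
/// Illustration (hypothetical): tensor<4x?xf32> becomes memref<4x?xf32> (plus
/// the memory space, if one is provided).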
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                     Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType =
          llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layout,
                         memorySpace);
}

//===----------------------------------------------------------------------===//
// Default implementations of interface methods
//===----------------------------------------------------------------------===//

bool bufferization::detail::defaultResultBufferizesToMemoryWrite(
    OpResult opResult, const AnalysisState &state) {
  auto bufferizableOp = cast<BufferizableOpInterface>(opResult.getDefiningOp());
  AliasingOpOperandList opOperands =
      bufferizableOp.getAliasingOpOperands(opResult, state);

  // Case 1: OpResults that have no aliasing OpOperand usually bufferize to
  // memory writes.
  if (opOperands.getAliases().empty())
    return true;

  // Case 2: If an aliasing OpOperand bufferizes to a memory write, the
  // OpResult may bufferize to a memory write.
  if (llvm::any_of(opOperands, [&](AliasingOpOperand alias) {
        return state.bufferizesToMemoryWrite(*alias.opOperand);
      }))
    return true;

  // Case 3: Check if a nested aliasing OpOperand value bufferizes to a memory
  // write. (Or: The reverse SSA use-def chain ends inside the region.) In that
  // case, the OpResult bufferizes to a memory write. E.g.:
  //
  // %0 = "some_writing_op" : tensor<?xf32>
  // %r = scf.if ... -> tensor<?xf32> {
  //   scf.yield %0 : tensor<?xf32>
  // } else {
  //   %1 = "another_writing_op"(%0) : tensor<?xf32>
  //   scf.yield %1 : tensor<?xf32>
  // }
  // "some_reading_op"(%r)
  //
  // %r bufferizes to a memory write because an aliasing OpOperand value (%1)
  // bufferizes to a memory write and the defining op is inside the scf.if.
  //
  // Note: This treatment of surrounding ops is useful for ops that have a
  // region but no OpOperand such as scf.if or scf.execute_region. It
  // simplifies the analysis considerably.
  //
  // "another_writing_op" in the above example should be able to bufferize
  // inplace in the absence of another read of %0. However, if the scf.if op
  // would not be considered a "write", the analysis would detect the
  // following conflict:
  //
  // * read = some_reading_op
  // * lastWrite = %0 (Note: The last write of %r would be a set: {%0, %1}.)
  // * conflictingWrite = %1
  //
  auto isMemoryWriteInsideOp = [&](Value v) {
    Operation *op = getOwnerOfValue(v);
    if (!opResult.getDefiningOp()->isAncestor(op))
      return false;
    return state.bufferizesToMemoryWrite(v);
  };
  TraversalConfig config;
  config.alwaysIncludeLeaves = false;
  for (AliasingOpOperand alias : opOperands) {
    if (!state
             .findValueInReverseUseDefChain(alias.opOperand->get(),
                                            isMemoryWriteInsideOp, config)
             .empty())
      return true;
  }
  return false;
}

// Compute the AliasingOpOperandList for a given Value based on
// getAliasingValues.
AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
    Value value, const AnalysisState &state) {
  Operation *op = getOwnerOfValue(value);
  SmallVector<AliasingOpOperand> result;
  for (OpOperand &opOperand : op->getOpOperands()) {
    if (!llvm::isa<TensorType>(opOperand.get().getType()))
      continue;
    AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
    for (const auto &it : aliasingValues)
      if (it.value == value)
        result.emplace_back(&opOperand, it.relation, it.isDefinite);
  }
  return AliasingOpOperandList(std::move(result));
}

FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
    Value value, const BufferizationOptions &options,
    SmallVector<Value> &invocationStack) {
  assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");

  // No further analysis is possible for a block argument.
  if (llvm::isa<BlockArgument>(value))
    return bufferization::getMemRefType(value, options);

  // Value is an OpResult.
  Operation *op = getOwnerOfValue(value);
  auto opResult = llvm::cast<OpResult>(value);
  AnalysisState state(options);
  AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
  if (aliases.getNumAliases() > 0 &&
      aliases.getAliases()[0].relation == BufferRelation::Equivalent) {
    // If the OpResult has an equivalent OpOperand, both OpResult and
    // OpOperand bufferize to the exact same buffer type.
    Value equivalentOperand = aliases.getAliases().front().opOperand->get();
    return getBufferType(equivalentOperand, options, invocationStack);
  }

  // If we do not know the memory space and there is no default memory space,
  // report a failure.
  auto memSpace =
      options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
  if (!memSpace.has_value())
    return op->emitError("could not infer memory space");

  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
}

bool bufferization::detail::defaultIsRepetitiveRegion(
    BufferizableOpInterface bufferizableOp, unsigned index) {
  assert(index < bufferizableOp->getNumRegions() && "invalid region index");
  auto regionInterface =
      dyn_cast<RegionBranchOpInterface>(bufferizableOp.getOperation());
  if (!regionInterface)
    return false;
  return regionInterface.isRepetitiveRegion(index);
}

AliasingOpOperandList
bufferization::detail::unknownGetAliasingOpOperands(Value value) {
  // TODO: Take into account successor blocks.
  // No aliasing in case of non-entry blocks.
  if (auto bbArg = dyn_cast<BlockArgument>(value))
    if (bbArg.getOwner() != &bbArg.getOwner()->getParent()->getBlocks().front())
      return {};

  // Unknown op: Conservatively assume that each OpResult may alias with every
  // OpOperand. In addition, each block argument of an entry block may alias
  // with every OpOperand.
  AliasingOpOperandList r;
  for (OpOperand &operand : value.getDefiningOp()->getOpOperands())
    if (isa<TensorType>(operand.get().getType()))
      r.addAlias({&operand, BufferRelation::Unknown, /*isDefinite=*/false});
  return r;
}

AliasingValueList
bufferization::detail::unknownGetAliasingValues(OpOperand &opOperand) {
  // TODO: Take into account successor blocks.
  // Unknown op: Conservatively assume that each OpResult may alias with every
  // OpOperand. In addition, each block argument of an entry block may alias
  // with every OpOperand.
  AliasingValueList r;
  for (OpResult result : opOperand.getOwner()->getOpResults())
    if (llvm::isa<TensorType>(result.getType()))
      r.addAlias({result, BufferRelation::Unknown, /*isDefinite=*/false});
  for (Region &region : opOperand.getOwner()->getRegions())
    if (!region.getBlocks().empty())
      for (BlockArgument bbArg : region.getBlocks().front().getArguments())
        if (isa<TensorType>(bbArg.getType()))
          r.addAlias({bbArg, BufferRelation::Unknown, /*isDefinite=*/false});
  return r;
}

bool bufferization::detail::defaultHasTensorSemantics(Operation *op) {
  auto isaTensor = [](Type t) { return isa<TensorType>(t); };
  bool hasTensorBlockArgument = any_of(op->getRegions(), [&](Region &r) {
    return any_of(r.getBlocks(), [&](Block &b) {
      return any_of(b.getArguments(), [&](BlockArgument bbArg) {
        return isaTensor(bbArg.getType());
      });
    });
  });
  if (hasTensorBlockArgument)
    return true;

  if (any_of(op->getResultTypes(), isaTensor))
    return true;
  return any_of(op->getOperandTypes(), isaTensor);
}

