//===- OneShotAnalysis.cpp - One-Shot (Single Pass) Analysis --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// One-Shot Analysis analyzes function bodies. By default, function boundaries
// (FuncOp bbArgs, CallOps, ReturnOps) are treated as "unknown" ops.
// OneShotModuleBufferization.cpp is an extension of One-Shot Analysis for
// simple call graphs without loops.
//
// One-Shot Bufferize consists of three phases.
//
// 1. Analyze ops to decide which OpOperands can bufferize inplace, i.e.,
//    without inserting buffer copies. The analysis queries op bufferization
//    semantics via `BufferizableOpInterface`.
// 2. Insert copies for OpOperands that were decided to bufferize out-of-place
//    in tensor land during `TensorCopyInsertion`.
// 3. Bufferize ops by calling `BufferizableOpInterface::bufferize`.
//
// This file contains only the analysis. For convenience, this file also
// contains a helper function `runOneShotBufferize` that analyzes an op (and
// its nested ops) and then bufferizes it.
//
// Inplace bufferization decisions are passed from the analysis to the
// `TensorCopyInsertion` phase via `AnalysisState`. They can be printed for
// debugging purposes with `testAnalysisOnly`.
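//
// For example, when running with `testAnalysisOnly`, an op may be annotated
// as follows (illustrative IR; the exact values depend on the analysis
// results). Operand #1 (%t) was decided to bufferize in-place; non-tensor
// operands are marked "none":
//
//   %0 = tensor.insert %f into %t[%idx]
//       {__inplace_operands_attr__ = ["none", "true", "none"]}
//       : tensor<?xf32>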
//
// Ops that do not implement `BufferizableOpInterface` can be analyzed but are
// treated conservatively. E.g., the analysis has to assume that their tensor
// OpOperands bufferize to memory writes. While such ops can be analyzed, they
// are not bufferized and remain in the IR. to_tensor and to_buffer ops are
// inserted at the bufferization boundary.
//
// This analysis caters to high-performance codegen where buffer reuse is
// deemed critical: the analysis should fail if the bufferized form of the
// function needs to return a buffer, unless `allowReturnAllocs` is enabled.

#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"

#include <optional>
#include <random>

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"

MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::OneShotAnalysisState)

// Run mlir-opt with `-debug-only="one-shot-analysis"` for detailed debug
// output.
#define DEBUG_TYPE "one-shot-analysis"

using namespace mlir;
using namespace mlir::bufferization;

static bool isaTensor(Type t) { return isa<TensorType>(t); }

//===----------------------------------------------------------------------===//
// Bufferization-specific attribute manipulation.
// These are for testing and debugging only. Bufferization information is
// stored in `OneShotAnalysisState`. When run with `testAnalysisOnly`, the IR
// is annotated with the results of the analysis, so that they can be checked
// in tests.
//===----------------------------------------------------------------------===//

/// Attribute marker to specify op operands that bufferize in-place.
constexpr StringLiteral kInPlaceOperandsAttrName = "__inplace_operands_attr__";

constexpr StringLiteral kOpResultAliasSetAttrName =
    "__opresult_alias_set_attr__";

constexpr StringLiteral kBbArgAliasSetAttrName = "__bbarg_alias_set_attr__";

/// Mark whether the given OpOperand will be bufferized in-place.
static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
  Operation *op = opOperand.getOwner();
  SmallVector<StringRef> inPlaceVector;
  if (auto attr = op->getAttr(kInPlaceOperandsAttrName)) {
    inPlaceVector = SmallVector<StringRef>(llvm::to_vector<4>(
        cast<ArrayAttr>(attr).getAsValueRange<StringAttr>()));
  } else {
    inPlaceVector = SmallVector<StringRef>(op->getNumOperands(), "none");
    for (OpOperand &opOperand : op->getOpOperands())
      if (isa<TensorType>(opOperand.get().getType()))
        inPlaceVector[opOperand.getOperandNumber()] = "false";
  }
  inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false";
  op->setAttr(kInPlaceOperandsAttrName,
              OpBuilder(op).getStrArrayAttr(inPlaceVector));
}
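
// Example (hypothetical "test.op"): marking operand #0 of an op with one
// tensor operand and one index operand as in-place yields:
//
//   "test.op"(%t, %idx) {__inplace_operands_attr__ = ["true", "none"]}
//       : (tensor<?xf32>, index) -> ()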

//===----------------------------------------------------------------------===//
// OneShotAnalysisState
//===----------------------------------------------------------------------===//

OneShotAnalysisState::OneShotAnalysisState(
    Operation *op, const OneShotBufferizationOptions &options)
    : AnalysisState(options, TypeID::get<OneShotAnalysisState>()) {
  // Set up alias sets.
  op->walk([&](Operation *op) {
    for (Value v : op->getResults())
      if (isa<TensorType>(v.getType()))
        createAliasInfoEntry(v);
    for (Region &r : op->getRegions())
      for (Block &b : r.getBlocks())
        for (auto bbArg : b.getArguments())
          if (isa<TensorType>(bbArg.getType()))
            createAliasInfoEntry(bbArg);
  });

  // Mark OpOperands in-place that must bufferize in-place.
  op->walk([&](BufferizableOpInterface bufferizableOp) {
    if (!options.isOpAllowed(bufferizableOp))
      return WalkResult::skip();
    for (OpOperand &opOperand : bufferizableOp->getOpOperands())
      if (isa<TensorType>(opOperand.get().getType()))
        if (bufferizableOp.mustBufferizeInPlace(opOperand, *this))
          bufferizeInPlace(opOperand);
    return WalkResult::advance();
  });
}

void OneShotAnalysisState::applyOnEquivalenceClass(
    Value v, function_ref<void(Value)> fun) const {
  auto leaderIt = equivalentInfo.findLeader(v);
  for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
       ++mit) {
    fun(*mit);
  }
}

void OneShotAnalysisState::applyOnAliases(
    Value v, function_ref<void(Value)> fun) const {
  auto leaderIt = aliasInfo.findLeader(v);
  for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
    fun(*mit);
  }
}

bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1,
                                                         Value v2) const {
  return equivalentInfo.isEquivalent(v1, v2);
}

bool OneShotAnalysisState::areAliasingBufferizedValues(Value v1,
                                                       Value v2) const {
  return aliasInfo.isEquivalent(v1, v2);
}

void OneShotAnalysisState::bufferizeInPlace(OpOperand &operand) {
  if (inplaceBufferized.contains(&operand))
    return;
  inplaceBufferized.insert(&operand);
  for (AliasingValue alias : getAliasingValues(operand))
    aliasInfo.unionSets(alias.value, operand.get());
  ++statNumTensorInPlace;
}

void OneShotAnalysisState::bufferizeOutOfPlace(OpOperand &operand) {
  assert(!inplaceBufferized.contains(&operand) &&
         "OpOperand was already decided to bufferize inplace");
  ++statNumTensorOutOfPlace;
}

void OneShotAnalysisState::createAliasInfoEntry(Value v) {
  aliasInfo.insert(v);
  equivalentInfo.insert(v);
}

void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
  op->walk([&](Operation *op) {
    // Skip unknown ops.
    auto bufferizableOp = getOptions().dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return WalkResult::skip();

    // Check all tensor OpResults.
    for (OpResult opResult : op->getOpResults()) {
      if (!isa<TensorType>(opResult.getType()))
        continue;

      // If there is no preceding definition, the tensor contents are
      // undefined.
      if (opResult.getUses().empty())
        continue;
      // It does not really matter which use to take to search about
      // the value's definitions.
      OpOperand *opOperand = &(*opResult.getUses().begin());
      if (findDefinitionsCached(opOperand).empty())
        for (OpOperand &use : opResult.getUses())
          undefinedTensorUses.insert(&use);
    }

    return WalkResult::advance();
  });
}

bool OneShotAnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  return undefinedTensorUses.contains(opOperand);
}

bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const {
  return inplaceBufferized.contains(&opOperand);
}

bool OneShotAnalysisState::isValueWritten(Value value) const {
  bool isWritten = false;
  applyOnAliases(value, [&](Value val) {
    for (OpOperand &use : val.getUses())
      if (isInPlace(use) && bufferizesToMemoryWrite(use))
        isWritten = true;
  });
  return isWritten;
}

bool OneShotAnalysisState::isWritable(Value value) const {
  // TODO: Out-of-place bufferized value could be considered writable.
  // Query BufferizableOpInterface to see if the value is writable.
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(getOwnerOfValue(value)))
    return bufferizableOp.isWritable(value, *this);

  // Not a bufferizable op: The conservative answer is "not writable".
  return false;
}

void OneShotAnalysisState::unionAliasSets(Value v1, Value v2) {
  aliasInfo.unionSets(v1, v2);
}

void OneShotAnalysisState::unionEquivalenceClasses(Value v1, Value v2) {
  equivalentInfo.unionSets(v1, v2);
}

OneShotAnalysisState::Extension::~Extension() = default;

//===----------------------------------------------------------------------===//
// Bufferization-specific alias analysis.
//===----------------------------------------------------------------------===//

/// Return true if `opOperand` bufferizes to a memory write and has been
/// decided to bufferize in-place.
static bool isInplaceMemoryWrite(OpOperand &opOperand,
                                 const OneShotAnalysisState &state) {
  // OpOperands that do not bufferize to a memory write do not write in-place.
  if (!state.bufferizesToMemoryWrite(opOperand))
    return false;
  // Check current bufferization decisions.
  return state.isInPlace(opOperand);
}

/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
/// properly dominates `b` and `b` is not inside `a`.
static bool happensBefore(Operation *a, Operation *b,
                          const DominanceInfo &domInfo) {
  do {
    // TODO: Instead of isProperAncestor + properlyDominates, we should use
    // properlyDominatesImpl(a, b, /*enclosingOpOk=*/false).
    if (a->isProperAncestor(b))
      return false;
    if (domInfo.properlyDominates(a, b))
      return true;
  } while ((a = a->getParentOp()));
  return false;
}

/// Return `true` if op dominance can be used to rule out read-after-write
/// conflicts based on the ordering of ops. Returns `false` if op dominance
/// cannot be used due to region-based loops.
///
/// Generalized op dominance can often be used to rule out potential conflicts
/// due to "read happens before write". E.g., the following IR is not a RaW
/// conflict because the read happens *before* the write.
///
/// Example 1:
/// %0 = ... : tensor<?xf32>                                // DEF
/// "reading_op"(%0) : tensor<?xf32>                        // READ
/// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>  // WRITE
///
/// This is no longer true inside loops (or repetitive regions). In such cases,
/// there may not be a meaningful `happensBefore` relationship because ops
/// could be executed multiple times. E.g.:
///
/// Example 2:
/// %0 = ... : tensor<?xf32>                                  // DEF
/// scf.for ... {
///   "reading_op"(%0) : tensor<?xf32>                        // READ
///   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>  // WRITE
///   ...
/// }
///
/// In the above example, reading_op happens before writing_op according to
/// op dominance. However, both ops may happen multiple times; in
/// particular, the second execution of reading_op happens after the first
/// execution of writing_op. This is problematic because the tensor %0 they
/// operate on (i.e., the "definition") is defined outside of the loop.
///
/// On a high level, there is a potential RaW in a program if there exists a
/// possible program execution such that there is a sequence of DEF, followed
/// by WRITE, followed by READ. Each additional DEF resets the sequence.
///
/// E.g.:
/// No conflict:        DEF, WRITE, DEF, READ
/// Potential conflict: DEF, READ, WRITE, READ, WRITE
///
/// Example 1 has no conflict:          DEF, READ, WRITE
/// Example 2 has a potential conflict: DEF, (READ, WRITE)*
///
/// Example 3:
/// scf.for ... {
///   %0 = ... : tensor<?xf32>
///   "reading_op"(%0) : tensor<?xf32>
///   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
///   ...
/// }
/// This has no conflict: (DEF, READ, WRITE)*
///
/// Example 4:
/// %0 = ... : tensor<?xf32>
/// scf.for ... {
///   scf.for ... { "reading_op"(%0) }
///   %1 = "writing_op"(%0)
/// }
/// This has a potential conflict: DEF, ((READ)*, WRITE)*
///
/// Example 5:
/// %0 = ... : tensor<?xf32>
/// scf.for ... { %1 = "writing_op"(%0) }
/// scf.for ... { "reading_op"(%0) }
/// This has a potential conflict: DEF, WRITE*, READ*
///
/// The following rules are used to rule out RaW conflicts via ordering of ops:
///
/// 1. If the closest enclosing repetitive region of DEF is a proper ancestor
///    of a repetitive region that encloses both READ and WRITE, we cannot rule
///    out a RaW conflict due to the ordering of ops.
/// 2. Otherwise: There are no loops that interfere with our analysis; for
///    analysis purposes, we can assume that there are no loops/repetitive
///    regions. I.e., we can rule out a RaW conflict if READ happensBefore
///    WRITE or WRITE happensBefore DEF. (Checked in
///    `hasReadAfterWriteInterference`.)
///
static bool canUseOpDominanceDueToRegions(OpOperand *uRead, OpOperand *uWrite,
                                          const SetVector<Value> &definitions,
                                          AnalysisState &state) {
  const BufferizationOptions &options = state.getOptions();
  for (Value def : definitions) {
    Region *rRead =
        state.getEnclosingRepetitiveRegion(uRead->getOwner(), options);
    Region *rDef = state.getEnclosingRepetitiveRegion(def, options);

    // READ and DEF are in the same repetitive region. `happensBefore` can be
    // used to rule out RaW conflicts due to op ordering.
    if (rRead == rDef)
      continue;

    // Find the enclosing repetitive region of READ that is closest to DEF but
    // not the repetitive region of DEF itself.
    while (true) {
      Region *nextRegion = getNextEnclosingRepetitiveRegion(rRead, options);
      if (nextRegion == rDef)
        break;
      assert(nextRegion && "expected to find another repetitive region");
      rRead = nextRegion;
    }

    // We cannot use op dominance if WRITE is inside the same repetitive
    // region.
    if (rRead->getParentOp()->isAncestor(uWrite->getOwner()))
      return false;
  }

  return true;
}

/// Return `true` if op dominance can be used to rule out read-after-write
/// conflicts based on the ordering of ops. Returns `false` if op dominance
/// cannot be used due to block-based loops within a region.
///
/// Refer to the `canUseOpDominanceDueToRegions` documentation for details on
/// how op dominance is used during RaW conflict detection.
///
/// On a high level, there is a potential RaW in a program if there exists a
/// possible program execution such that there is a sequence of DEF, followed
/// by WRITE, followed by READ. Each additional DEF resets the sequence.
///
/// Op dominance cannot be used if there is a path from block(READ) to
/// block(WRITE) and a path from block(WRITE) to block(READ). block(DEF) should
/// not appear on that path.
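///
/// As an illustrative (hypothetical) example, consider two blocks in the same
/// region that form a loop via unstructured control flow:
///
///   ^bb1:
///     "reading_op"(%0)                 // READ
///     cf.br ^bb2
///   ^bb2:
///     %1 = "writing_op"(%0)            // WRITE
///     cf.cond_br %c, ^bb1, ^bb3
///
/// block(READ) and block(WRITE) are mutually reachable without passing
/// through the block that defines %0 (DEF), so op dominance cannot be used to
/// rule out a RaW conflict.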
static bool canUseOpDominanceDueToBlocks(OpOperand *uRead, OpOperand *uWrite,
                                         const SetVector<Value> &definitions,
                                         AnalysisState &state) {
  // Fast path: If READ and WRITE are in different regions, their blocks cannot
  // reach each other just via unstructured control flow. (Loops due to regions
  // are covered by `canUseOpDominanceDueToRegions`.)
  if (uRead->getOwner()->getParentRegion() !=
      uWrite->getOwner()->getParentRegion())
    return true;

  Block *readBlock = uRead->getOwner()->getBlock();
  Block *writeBlock = uWrite->getOwner()->getBlock();
  for (Value def : definitions) {
    Block *defBlock = def.getParentBlock();
    if (readBlock->isReachable(writeBlock, {defBlock}) &&
        writeBlock->isReachable(readBlock, {defBlock}))
      return false;
  }

  return true;
}

static bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite,
                              const SetVector<Value> &definitions,
                              AnalysisState &state) {
  return canUseOpDominanceDueToRegions(uRead, uWrite, definitions, state) &&
         canUseOpDominanceDueToBlocks(uRead, uWrite, definitions, state);
}

/// Annotate IR with details about the detected RaW conflict.
static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
                             Value definition) {
  static uint64_t counter = 0;
  Operation *readingOp = uRead->getOwner();
  Operation *conflictingWritingOp = uConflictingWrite->getOwner();

  OpBuilder b(conflictingWritingOp->getContext());
  std::string id = "C_" + std::to_string(counter++);

  std::string conflictingWriteAttr =
      id +
      "[CONFL-WRITE: " + std::to_string(uConflictingWrite->getOperandNumber()) +
      "]";
  conflictingWritingOp->setAttr(conflictingWriteAttr, b.getUnitAttr());

  std::string readAttr =
      id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]";
  readingOp->setAttr(readAttr, b.getUnitAttr());

  if (auto opResult = dyn_cast<OpResult>(definition)) {
    std::string defAttr =
        id + "[DEF: result " + std::to_string(opResult.getResultNumber()) + "]";
    opResult.getDefiningOp()->setAttr(defAttr, b.getUnitAttr());
  } else {
    auto bbArg = cast<BlockArgument>(definition);
    std::string defAttr =
        id + "[DEF: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
    bbArg.getOwner()->getParentOp()->setAttr(defAttr, b.getUnitAttr());
  }
}

/// Return 'true' if a tensor that is equivalent to `other` can be found in the
/// reverse use-def chain of `start`. Note: If an OpOperand bufferizes out of
/// place along that use-def chain, the two tensors may not materialize as
/// equivalent buffers (but as separate allocations).
///
/// Note: This function also requires that the two tensors have equivalent
/// indexing. I.e., the tensor types do not change along the use-def chain,
/// apart from static <-> dynamic dim casts.
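///
/// Example (hypothetical IR): Searching from the operand of "some_use" for a
/// value equivalent to %0 succeeds, because tensor.cast only casts a static
/// dim to a dynamic dim:
///
///   %0 = ... : tensor<5xf32>
///   %1 = tensor.cast %0 : tensor<5xf32> to tensor<?xf32>
///   "some_use"(%1) : (tensor<?xf32>) -> ()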
static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
                                                   OpOperand *start,
                                                   Value other) {
  TraversalConfig config;
  config.followEquivalentOnly = true;
  config.alwaysIncludeLeaves = false;
  config.followSameTypeOrCastsOnly = true;
  return !state
              .findValueInReverseUseDefChain(
                  start, [&](Value v) { return v == other; }, config)
              .empty();
}

/// Return "true" if the given operand's value originates from a subset that is
/// equivalent to the subset that `subsetOp` inserts into.
static bool matchesInsertDestination(const AnalysisState &state,
                                     OpOperand *opOperand,
                                     SubsetInsertionOpInterface subsetOp) {
  auto matchingSubset = [&](Value val) {
    if (auto opResult = dyn_cast<OpResult>(val))
      if (subsetOp.isEquivalentSubset(opResult, [&](Value v1, Value v2) {
            return state.areEquivalentBufferizedValues(v1, v2);
          }))
        return true;
    return false;
  };
  // There may be multiple leaves at which the reverse SSA use-def chain lookup
  // terminates. All of them must be equivalent subsets.
  SetVector<Value> backwardSlice =
      state.findValueInReverseUseDefChain(opOperand, matchingSubset);
  return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset));
}

/// Return "true" if the given "read" and potentially conflicting "write" are
/// not conflicting due to their subset relationship. The comments in this
/// function are expressed in terms of tensor.extract_slice/tensor.insert_slice
/// pairs, but apply to any subset ops that implement the
/// `SubsetInsertionOpInterface`.
static bool areNonConflictingSubsets(OpOperand *uRead,
                                     OpOperand *uConflictingWrite,
                                     const AnalysisState &state) {
  Operation *readingOp = uRead->getOwner();
  Operation *conflictingWritingOp = uConflictingWrite->getOwner();

  // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
  // uRead is an InsertSliceOp...
  if (auto subsetOp = dyn_cast<SubsetInsertionOpInterface>(readingOp)) {
    // As an example, consider the following IR.
    //
    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
    // %1 = linalg.fill %cst, %0 {inplace = [true] }
    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
    //     {inplace = [true] }

    if (uRead == &subsetOp.getDestinationOperand() &&
        matchesInsertDestination(state, uConflictingWrite, subsetOp))
      // Case 1: The main insight is that InsertSliceOp reads only part of
      // the destination tensor. The overwritten area is not read. If
      // uConflictingWrite writes into exactly the memory location that is
      // being read by uRead, this is not a conflict.
      //
      // In the above example:
      // uRead             = OpOperand 1 (%t) of tensor.insert_slice
      // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
      //
      // The read of %t does not conflict with the write of the FillOp
      // (same aliases!) because the area that the FillOp operates on is
      // exactly the one that is *not* read via %t.
      return true;

    if (uRead == &subsetOp.getSourceOperand() &&
        uConflictingWrite == &subsetOp.getDestinationOperand() &&
        matchesInsertDestination(state, uRead, subsetOp))
      // Case 2: The read of the source tensor and the write to the dest
      // tensor via an InsertSliceOp is not a conflict if the read is
      // reading exactly that part of an equivalent tensor that the
      // InsertSliceOp is writing.
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of tensor.insert_slice
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      return true;
  }

  // If uConflictingWrite is an InsertSliceOp...
  if (auto subsetOp =
          dyn_cast<SubsetInsertionOpInterface>(conflictingWritingOp))
    // As an example, consider the following IR.
    //
    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
    // %1 = linalg.fill %cst, %0 {inplace = [true] }
    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
    //     {inplace = [true] }
    // %3 = vector.transfer_read %1, %cst
    //
    // In the above example:
    // uRead             = OpOperand 0 (%1) of vector.transfer_read
    // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
    // definition        = %1
    //
    // This is not a conflict because the InsertSliceOp overwrites the
    // memory segment of %1 with the exact same data. (Effectively, there
    // is no memory write here.)
    if (uConflictingWrite == &subsetOp.getDestinationOperand() &&
        state.areEquivalentBufferizedValues(
            uRead->get(), subsetOp.getSourceOperand().get()) &&
        matchesInsertDestination(state, &subsetOp.getSourceOperand(), subsetOp))
      return true;

  return false;
}

/// Given sets of uses and writes, return true if there is a RaW conflict under
/// the assumption that all given reads/writes alias the same buffer and that
/// all given writes bufferize inplace.
///
/// A conflict is: According to SSA use-def chains, a read R is supposed to
/// read the result of a definition W1. But because of bufferization decisions,
/// R actually reads another definition W2.
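///
/// Example (hypothetical IR): R is expected to read the data produced by W1.
/// But if W2 also writes in-place into the same buffer, R would observe W2's
/// data instead:
///
///   %w1 = "W1"(%t) : tensor<?xf32> -> tensor<?xf32>  // in-place WRITE
///   %w2 = "W2"(%t) : tensor<?xf32> -> tensor<?xf32>  // in-place WRITE
///   "R"(%w1)       : tensor<?xf32> -> ()             // READ, expects W1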
static bool
hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
                              const DenseSet<OpOperand *> &usesWrite,
                              const DominanceInfo &domInfo,
                              OneShotAnalysisState &state) {
  const BufferizationOptions &options = state.getOptions();

  // Before going through the main RaW analysis, find cases where a buffer must
  // be privatized due to parallelism. If the result of a write is never read,
  // privatization is not necessary (and large parts of the IR are likely
  // dead).
  if (options.checkParallelRegions && !usesRead.empty()) {
    for (OpOperand *uConflictingWrite : usesWrite) {
      // Find the allocation point or last write (definition) of the buffer.
      // Note: In contrast to `findDefinitions`, this also returns results of
      // ops that do not bufferize to a memory write when no other definition
      // could be found. E.g., "bufferization.alloc_tensor" would be included,
      // even though that op just bufferizes to an allocation and does not
      // define the contents of the buffer.
      SetVector<Value> definitionsOrLeaves =
          state.findValueInReverseUseDefChain(uConflictingWrite, [&](Value v) {
            return state.bufferizesToMemoryWrite(v);
          });
      assert(!definitionsOrLeaves.empty() &&
             "expected at least one definition or leaf");

      // The writing op must bufferize out-of-place if the definition is in a
      // different parallel region than this write.
      for (Value def : definitionsOrLeaves) {
        if (getParallelRegion(def.getParentRegion(), options) !=
            getParallelRegion(uConflictingWrite->getOwner()->getParentRegion(),
                              options)) {
          LLVM_DEBUG(
              llvm::dbgs()
              << "\n- bufferizes out-of-place due to parallel region:\n");
          LLVM_DEBUG(llvm::dbgs()
                     << "  uConflictingWrite = operand "
                     << uConflictingWrite->getOperandNumber() << " of "
                     << *uConflictingWrite->getOwner() << "\n");
          return true;
        }
      }
    }
  }

  for (OpOperand *uRead : usesRead) {
    Operation *readingOp = uRead->getOwner();
    LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n");
    LLVM_DEBUG(llvm::dbgs() << "  uRead = operand " << uRead->getOperandNumber()
                            << " of " << *readingOp << "\n");

    // Find the definition of uRead by following the SSA use-def chain.
    // E.g.:
    //
    // %0 = "writing_op"(%t) : tensor<?xf32> -> tensor<?xf32>
    // %1 = "aliasing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
    // %2 = "reading_op"(%1) : tensor<?xf32> -> not_a_tensor_type
    //
    // In the above example, if uRead is the OpOperand of reading_op, the
    // definition is %0. Note that operations that create an alias but do not
    // bufferize to a memory write (such as ExtractSliceOp) are skipped.
    const SetVector<Value> &definitions = state.findDefinitionsCached(uRead);
    if (definitions.empty()) {
      // Fast path: No conflict if there are no definitions.
      LLVM_DEBUG(llvm::dbgs()
                 << "  no conflict: read value has no definitions\n");
      continue;
    }

    // Look for conflicting memory writes. Potential conflicts are writes to an
    // alias that have been decided to bufferize inplace.
    for (OpOperand *uConflictingWrite : usesWrite) {
      LLVM_DEBUG(llvm::dbgs() << "  uConflictingWrite = operand "
                              << uConflictingWrite->getOperandNumber() << " of "
                              << *uConflictingWrite->getOwner() << "\n");

      // Check if op dominance can be used to rule out read-after-write
      // conflicts.
      bool useDominance =
          canUseOpDominance(uRead, uConflictingWrite, definitions, state);
      LLVM_DEBUG(llvm::dbgs() << "\n- useDominance = " << useDominance << "\n");

      // Throughout this loop, check for multiple requirements that have to be
      // met for uConflictingWrite to be an actual conflict.
      Operation *conflictingWritingOp = uConflictingWrite->getOwner();

      // Inside of repetitive regions, ops may be executed multiple times and op
      // dominance cannot be used to rule out conflicts.
      if (useDominance) {
        // No conflict if the readingOp dominates conflictingWritingOp, i.e.,
        // the write is not visible when reading.
        //
        // Note: If ops are executed multiple times (e.g., because they are
        //       inside a loop), there may be no meaningful `happensBefore`
        //       relationship.
        if (happensBefore(readingOp, conflictingWritingOp, domInfo)) {
          LLVM_DEBUG(llvm::dbgs()
                     << "  no conflict: read happens before write\n");
          continue;
        }

        // No conflict if the reading use equals the use of the conflicting
        // write. A use cannot conflict with itself.
        //
        // Note: Just being the same op is not enough. It has to be the same
        //       use.
        // Note: If the op is executed multiple times (e.g., because it is
        //       inside a loop), it may be conflicting with itself.
        if (uConflictingWrite == uRead) {
          LLVM_DEBUG(llvm::dbgs()
                     << "  no conflict: read and write are same use\n");
          continue;
        }

        // Ops are not conflicting if they are in mutually exclusive regions.
        //
        // Note: If ops are executed multiple times (e.g., because they are
        //       inside a loop), mutually exclusive regions may be executed
        //       multiple times.
        if (state.insideMutuallyExclusiveRegions(readingOp,
                                                 conflictingWritingOp)) {
          LLVM_DEBUG(llvm::dbgs() << "  no conflict: read and write are in "
                                     "mutually exclusive regions\n");
          continue;
        }

        // Two equivalent operands of the same op are not conflicting if the op
        // bufferizes to element-wise access. I.e., all loads at a position
        // happen before all stores to the same position.
        if (conflictingWritingOp == readingOp) {
          if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
            if (bufferizableOp.bufferizesToElementwiseAccess(
                    state, {uRead, uConflictingWrite})) {
              if (hasEquivalentValueInReverseUseDefChain(
                      state, uRead, uConflictingWrite->get()) ||
                  hasEquivalentValueInReverseUseDefChain(
                      state, uConflictingWrite, uRead->get())) {
                LLVM_DEBUG(
                    llvm::dbgs()
                    << "  no conflict: op bufferizes to element-wise access\n");
                continue;
              }
            }
          }
        }
      }

      // No conflict if the operands are non-conflicting subsets.
      if (areNonConflictingSubsets(uRead, uConflictingWrite, state)) {
        LLVM_DEBUG(llvm::dbgs() << "  no conflict: non-conflicting subsets\n");
        continue;
      }

      // No conflict if the op interface says so.
      if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
        if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) {
          LLVM_DEBUG(llvm::dbgs()
                     << "  no conflict: op interface of reading op says "
                        "'no'\n");
          continue;
        }
      }

      if (conflictingWritingOp != readingOp) {
        if (auto bufferizableOp =
                options.dynCastBufferizableOp(conflictingWritingOp)) {
          if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
                                              state)) {
            LLVM_DEBUG(
                llvm::dbgs()
                << "  no conflict: op interface of writing op says 'no'\n");
            continue;
          }
        }
      }

      // Check all possible definitions.
      for (Value definition : definitions) {
        LLVM_DEBUG(llvm::dbgs() << "  * definition = " << definition << "\n");

        // No conflict if the conflicting write happens before the definition.
        if (Operation *defOp = definition.getDefiningOp()) {
          if (happensBefore(conflictingWritingOp, defOp, domInfo)) {
            // conflictingWritingOp happens before defOp. No conflict.
            LLVM_DEBUG(llvm::dbgs()
                       << "  no conflict: write happens before definition\n");
            continue;
          }
          // No conflict if conflictingWritingOp is contained in defOp.
          if (defOp->isProperAncestor(conflictingWritingOp)) {
            LLVM_DEBUG(
                llvm::dbgs()
                << "  no conflict: write is contained in definition\n");
            continue;
          }
        } else {
          auto bbArg = cast<BlockArgument>(definition);
          Block *block = bbArg.getOwner();
          if (!block->findAncestorOpInBlock(*conflictingWritingOp)) {
            LLVM_DEBUG(llvm::dbgs() << "  no conflict: definition is bbArg "
                                       "and write happens outside of block\n");
            // conflictingWritingOp happens outside of the block. No
            // conflict.
            continue;
          }
        }

        // No conflict if the conflicting write and the definition are the same
        // use.
        AliasingValueList aliases = state.getAliasingValues(*uConflictingWrite);
        if (aliases.getNumAliases() == 1 &&
            aliases.getAliases()[0].value == definition) {
          LLVM_DEBUG(llvm::dbgs()
                     << "  no conflict: definition and write are same\n");
          continue;
        }

        // All requirements are met. Conflict found!

        if (options.printConflicts)
          annotateConflict(uRead, uConflictingWrite, definition);
        LLVM_DEBUG(llvm::dbgs() << "  => RaW CONFLICT FOUND\n");
        return true;
      }
    }
  }

  return false;
}

// Helper function to iterate on aliases of `root` and capture the writes.
static void getAliasingInplaceWrites(DenseSet<OpOperand *> &res, Value root,
                                     const OneShotAnalysisState &state) {
  state.applyOnAliases(root, [&](Value alias) {
    for (auto &use : alias.getUses())
      // Inplace write to a value that aliases root.
      if (isInplaceMemoryWrite(use, state))
        res.insert(&use);
  });
}

// Helper function to iterate on aliases of `root` and capture the reads.
static void getAliasingReads(DenseSet<OpOperand *> &res, Value root,
                             const OneShotAnalysisState &state) {
  state.applyOnAliases(root, [&](Value alias) {
    for (auto &use : alias.getUses()) {
      // Read of a value that aliases root.
      if (state.bufferizesToMemoryRead(use)) {
        res.insert(&use);
        continue;
      }

      // Read of a dependent value in the SSA use-def chain. E.g.:
      //
      // %0 = ...
      // %1 = tensor.extract_slice %0 {not_analyzed_yet}
      // "read"(%1)
      //
      // In the above example, getAliasingReads(%0) includes the first OpOperand
      // of the tensor.extract_slice op. The extract_slice itself does not read
      // but its aliasing result is eventually fed into an op that does.
      //
      // Note: This is considered a "read" only if the use does not bufferize to
      // a memory write. (We already ruled out memory reads. In case of a memory
      // write, the buffer would be entirely overwritten; in the above example
      // there would then be no flow of data from the extract_slice operand to
      // its result's uses.)
      if (!state.bufferizesToMemoryWrite(use)) {
        AliasingValueList aliases = state.getAliasingValues(use);
        if (llvm::any_of(aliases, [&](AliasingValue a) {
              return state.isValueRead(a.value);
            }))
          res.insert(&use);
      }
    }
  });
}

/// Return true if bufferizing `operand` inplace would create a conflict. A
/// read R and a write W of the same alias set constitute a conflict if inplace
/// bufferization of W changes the value read by R to a value different from
/// the one that would be expected by tracing back R's origin through SSA
/// use-def chains. A conflict can only be introduced by a new alias and/or an
/// inplace bufferization decision.
///
/// Example:
/// %0 = tensor.extract_slice %t[...][...][1, 1] {inplace?}
/// %1 = vector.transfer_write %v1, %t {inplace} : vector<5xf32>, tensor<?xf32>
/// %e = tensor.extract_slice %1
/// %2 = vector.transfer_write %v2, %0 {inplace} : vector<6xf32>, tensor<?xf32>
/// %3 = vector.transfer_read %e, %cst : tensor<?xf32>, vector<7xf32>
///
/// In the above example, the two TransferWriteOps have already been decided to
/// bufferize inplace. Bufferizing the ExtractSliceOp inplace would create a
/// conflict because:
/// * According to SSA use-def chains, we expect to read the result of %1.
/// * However, adding an alias {%0, %t} would mean that the second
///   TransferWriteOp overwrites the result of the first one. Therefore, the
///   TransferReadOp would no longer be reading the result of %1.
///
/// If `checkConsistencyOnly` is true, this function checks if there is a
/// read-after-write conflict without bufferizing `operand` inplace. This would
/// indicate a problem with the current inplace bufferization decisions.
///
/// Note: If `checkConsistencyOnly`, this function may be called with a null
/// OpResult. In that case, only the consistency of bufferization decisions
/// involving aliases of the given OpOperand are checked.
static bool wouldCreateReadAfterWriteInterference(
    OpOperand &operand, const DominanceInfo &domInfo,
    OneShotAnalysisState &state, bool checkConsistencyOnly = false) {
  // Collect reads and writes of all aliases of OpOperand and OpResult.
  DenseSet<OpOperand *> usesRead, usesWrite;
  getAliasingReads(usesRead, operand.get(), state);
  getAliasingInplaceWrites(usesWrite, operand.get(), state);
  for (AliasingValue alias : state.getAliasingValues(operand)) {
    getAliasingReads(usesRead, alias.value, state);
    getAliasingInplaceWrites(usesWrite, alias.value, state);
  }
  if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
    usesWrite.insert(&operand);

  return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state);
}

/// Annotate IR with details about the detected non-writability conflict.
static void annotateNonWritableTensor(Value value) {
  static int64_t counter = 0;
  OpBuilder b(value.getContext());
  std::string id = "W_" + std::to_string(counter++);
  if (auto opResult = dyn_cast<OpResult>(value)) {
    std::string attr = id + "[NOT-WRITABLE: result " +
                       std::to_string(opResult.getResultNumber()) + "]";
    opResult.getDefiningOp()->setAttr(attr, b.getUnitAttr());
  } else {
    auto bbArg = cast<BlockArgument>(value);
    std::string attr = id + "[NOT-WRITABLE: bbArg " +
                       std::to_string(bbArg.getArgNumber()) + "]";
    bbArg.getOwner()->getParentOp()->setAttr(attr, b.getUnitAttr());
  }
}
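
// For example (illustrative IR), if a function bbArg is found to be
// non-writable, the annotation lands on the parent op of the bbArg's block:
//
//   func.func @fn(%t: tensor<?xf32>) attributes {"W_0[NOT-WRITABLE: bbArg 0]"}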

/// Return true if bufferizing `operand` inplace would create a write to a
/// non-writable buffer.
static bool
wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
                                    OneShotAnalysisState &state,
                                    bool checkConsistencyOnly = false) {
  bool foundWrite =
      !checkConsistencyOnly && state.bufferizesToMemoryWrite(operand);

  if (!foundWrite) {
    // Collect writes of all aliases of OpOperand and OpResult.
    DenseSet<OpOperand *> usesWrite;
    getAliasingInplaceWrites(usesWrite, operand.get(), state);
    for (AliasingValue alias : state.getAliasingValues(operand))
      getAliasingInplaceWrites(usesWrite, alias.value, state);
    foundWrite = !usesWrite.empty();
  }

  if (!foundWrite)
    return false;

  // Look for a read-only tensor among all aliases.
  bool foundReadOnly = false;
  auto checkReadOnly = [&](Value v) {
    if (!state.isWritable(v)) {
      foundReadOnly = true;
      if (state.getOptions().printConflicts)
        annotateNonWritableTensor(v);
    }
  };
  state.applyOnAliases(operand.get(), checkReadOnly);
  for (AliasingValue alias : state.getAliasingValues(operand))
    state.applyOnAliases(alias.value, checkReadOnly);
  if (foundReadOnly) {
    LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n");
    return true;
  }

  return false;
}

//===----------------------------------------------------------------------===//
// Bufferization analyses.
//===----------------------------------------------------------------------===//

// Find the values that define the contents of the given operand's value.
const llvm::SetVector<Value> &
OneShotAnalysisState::findDefinitionsCached(OpOperand *opOperand) {
  Value value = opOperand->get();
  if (!cachedDefinitions.count(value))
    cachedDefinitions[value] = findDefinitions(opOperand);
  return cachedDefinitions[value];
}

void OneShotAnalysisState::resetCache() {
  AnalysisState::resetCache();
  cachedDefinitions.clear();
}

/// Determine if `operand` can be bufferized in-place.
static LogicalResult
bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state,
                                const DominanceInfo &domInfo) {
  LLVM_DEBUG(
      llvm::dbgs() << "//===-------------------------------------------===//\n"
                   << "Analyzing operand #" << operand.getOperandNumber()
                   << " of " << *operand.getOwner() << "\n");

  bool foundInterference =
      wouldCreateWriteToNonWritableBuffer(operand, state) ||
      wouldCreateReadAfterWriteInterference(operand, domInfo, state);

  if (foundInterference)
    state.bufferizeOutOfPlace(operand);
  else
    state.bufferizeInPlace(operand);

  LLVM_DEBUG(llvm::dbgs()
             << "//===-------------------------------------------===//\n");
  return success();
}

LogicalResult
OneShotAnalysisState::analyzeSingleOp(Operation *op,
                                      const DominanceInfo &domInfo) {
  for (OpOperand &opOperand : op->getOpOperands())
    if (isa<TensorType>(opOperand.get().getType()))
      if (failed(bufferizableInPlaceAnalysisImpl(opOperand, *this, domInfo)))
        return failure();
  return success();
}

/// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
static void equivalenceAnalysis(SmallVector<Operation *> &ops,
                                OneShotAnalysisState &state) {
  for (Operation *op : ops) {
    if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) {
      for (OpResult opResult : op->getOpResults()) {
        if (!isa<TensorType>(opResult.getType()))
          continue;
        AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
        if (aliases.getNumAliases() == 0)
          // Nothing to do if there are no aliasing OpOperands.
          continue;

        Value firstOperand = aliases.begin()->opOperand->get();
        bool allEquivalent = true;
        for (AliasingOpOperand alias : aliases) {
          bool isEquiv = alias.relation == BufferRelation::Equivalent;
          bool isInPlace = state.isInPlace(*alias.opOperand);
          Value operand = alias.opOperand->get();
          if (isEquiv && isInPlace && alias.isDefinite) {
            // Found a definite, equivalent alias. Merge equivalence sets.
            // There can only be one definite alias, so we can stop here.
            state.unionEquivalenceClasses(opResult, operand);
            allEquivalent = false;
            break;
          }
          if (!isEquiv || !isInPlace)
            allEquivalent = false;
          if (!state.areEquivalentBufferizedValues(operand, firstOperand))
            allEquivalent = false;
        }

        // If all "maybe" aliases are equivalent and the OpResult is not a new
        // allocation, it is a definite, equivalent alias. E.g.:
        //
        // aliasingOpOperands(%r) = {(%t0, EQUIV, MAYBE), (%t1, EQUIV, MAYBE)}
        // aliasingValues(%t0) = {(%r, EQUIV, MAYBE)}
        // aliasingValues(%t1) = {(%r, EQUIV, MAYBE)}
        // %r = arith.select %c, %t0, %t1 : tensor<?xf32>
        //
        // If %t0 and %t1 are equivalent, it is safe to union the equivalence
        // classes of %r, %t0 and %t1.
        if (allEquivalent && !bufferizableOp.bufferizesToAllocation(opResult))
          state.unionEquivalenceClasses(opResult, firstOperand);
      }
    }
  }
}

/// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained
/// in `op`.
static void equivalenceAnalysis(Operation *op, OneShotAnalysisState &state) {
  // Traverse ops in PostOrder: Nested ops first, then enclosing ops.
  SmallVector<Operation *> ops;
  op->walk<WalkOrder::PostOrder>([&](Operation *op) {
    // No tensors => no buffers.
    if (none_of(op->getResultTypes(), isaTensor))
      return;
    ops.push_back(op);
  });

  equivalenceAnalysis(ops, state);
}

/// "Bottom-up from terminators" heuristic: Collect the ops that feed into
/// region terminators first, by following the reverse SSA use-def chains of
/// all yielded tensor values (within the terminator's region), then append all
/// remaining ops with tensor semantics in reverse order.
static SmallVector<Operation *>
bottomUpFromTerminatorsHeuristic(Operation *op,
                                 const OneShotAnalysisState &state) {
  SetVector<Operation *> traversedOps;

  // Find region terminators.
  op->walk<WalkOrder::PostOrder>([&](RegionBranchTerminatorOpInterface term) {
    if (!traversedOps.insert(term))
      return;
    // Follow the reverse SSA use-def chain from each yielded value as long as
    // we stay within the same region.
    SmallVector<OpResult> worklist;
    for (Value v : term->getOperands()) {
      if (!isa<TensorType>(v.getType()))
        continue;
      auto opResult = dyn_cast<OpResult>(v);
      if (!opResult)
        continue;
      worklist.push_back(opResult);
    }
    while (!worklist.empty()) {
      OpResult opResult = worklist.pop_back_val();
      Operation *defOp = opResult.getDefiningOp();
      if (!traversedOps.insert(defOp))
        continue;
      if (!term->getParentRegion()->findAncestorOpInRegion(*defOp))
        continue;
      AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
      for (auto alias : aliases) {
        Value v = alias.opOperand->get();
        if (!isa<TensorType>(v.getType()))
          continue;
        auto opResult = dyn_cast<OpResult>(v);
        if (!opResult)
          continue;
        worklist.push_back(opResult);
      }
    }
  });

  // Analyze traversed ops, then all remaining ops.
  SmallVector<Operation *> result(traversedOps.begin(), traversedOps.end());
  op->walk<WalkOrder::PostOrder, ReverseIterator>([&](Operation *op) {
    if (!traversedOps.contains(op) && hasTensorSemantics(op))
      result.push_back(op);
  });
  return result;
}

LogicalResult OneShotAnalysisState::analyzeOp(Operation *op,
                                              const DominanceInfo &domInfo) {
  OneShotBufferizationOptions::AnalysisHeuristic heuristic =
      getOptions().analysisHeuristic;

  SmallVector<Operation *> orderedOps;
  if (heuristic ==
      OneShotBufferizationOptions::AnalysisHeuristic::BottomUpFromTerminators) {
    orderedOps = bottomUpFromTerminatorsHeuristic(op, *this);
  } else {
    op->walk([&](Operation *op) {
      // No tensors => no buffers.
      if (!hasTensorSemantics(op))
        return;
      orderedOps.push_back(op);
    });
    switch (heuristic) {
    case OneShotBufferizationOptions::AnalysisHeuristic::BottomUp: {
      // Default: Walk ops in reverse for better interference analysis.
      std::reverse(orderedOps.begin(), orderedOps.end());
      break;
    }
    case OneShotBufferizationOptions::AnalysisHeuristic::TopDown: {
      // Ops are already sorted top-down in `orderedOps`.
      break;
    }
    case OneShotBufferizationOptions::AnalysisHeuristic::Fuzzer: {
      assert(getOptions().analysisFuzzerSeed &&
             "expected that fuzzer seed is set");
      // This is a fuzzer. For testing purposes only. Randomize the order in
      // which operations are analyzed. The bufferization quality is likely
      // worse, but we want to make sure that no assertions are triggered
      // anywhere.
      std::mt19937 g(getOptions().analysisFuzzerSeed);
      llvm::shuffle(orderedOps.begin(), orderedOps.end(), g);
      break;
    }
    default: {
      llvm_unreachable("unsupported heuristic");
    }
    }
  }

  // Analyze ops in the computed order.
  for (Operation *op : orderedOps)
    if (failed(analyzeSingleOp(op, domInfo)))
      return failure();

  equivalenceAnalysis(op, *this);
  return success();
}

/// Perform various checks on the input IR to see if it contains IR constructs
/// that are unsupported by One-Shot Bufferize.
static LogicalResult
checkPreBufferizationAssumptions(Operation *op, const DominanceInfo &domInfo,
                                 OneShotAnalysisState &state) {
  const BufferizationOptions &options = state.getOptions();

  // Note: This walk cannot be combined with the one below because interface
  // methods of invalid/unsupported ops may be called during the second walk.
  // (On ops different from `op`.)
  WalkResult walkResult = op->walk([&](BufferizableOpInterface op) {
    // Skip ops that are not in the filter.
    if (!options.isOpAllowed(op.getOperation()))
      return WalkResult::advance();

    // Check for unsupported unstructured control flow.
    if (!op.supportsUnstructuredControlFlow()) {
      for (Region &r : op->getRegions()) {
        if (r.getBlocks().size() > 1) {
          op->emitOpError("op or BufferizableOpInterface implementation does "
                          "not support unstructured control flow, but at least "
                          "one region has multiple blocks");
          return WalkResult::interrupt();
        }
      }
    }

    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted())
    return failure();

  walkResult = op->walk([&](BufferizableOpInterface op) {
    // Skip ops that are not in the filter.
    if (!options.isOpAllowed(op.getOperation()))
      return WalkResult::advance();

    // Input IR may not contain any ToTensorOps without the "restrict"
    // attribute. Such tensors may alias any other tensor, which is currently
    // not handled in the analysis.
    if (auto toTensorOp = dyn_cast<ToTensorOp>(op.getOperation())) {
      if (!toTensorOp.getRestrict() && !toTensorOp->getUses().empty()) {
        op->emitOpError("to_tensor ops without `restrict` are not supported by "
                        "One-Shot Analysis");
        return WalkResult::interrupt();
      }
    }

    for (OpOperand &opOperand : op->getOpOperands()) {
      if (isa<TensorType>(opOperand.get().getType())) {
        if (wouldCreateReadAfterWriteInterference(
                opOperand, domInfo, state,
                /*checkConsistencyOnly=*/true)) {
          // This error can happen if certain "mustBufferizeInPlace" interface
          // methods are implemented incorrectly, such that the IR already has
          // a RaW conflict before making any bufferization decisions. It can
          // also happen if the bufferization.materialize_in_destination op is
          // used in such a way that a RaW conflict is not avoidable.
          op->emitOpError("not bufferizable under the given constraints: "
                          "cannot avoid RaW conflict");
          return WalkResult::interrupt();
        }

        if (state.isInPlace(opOperand) &&
            wouldCreateWriteToNonWritableBuffer(
                opOperand, state, /*checkConsistencyOnly=*/true)) {
          op->emitOpError("not bufferizable under the given constraints: would "
                          "write to read-only buffer");
          return WalkResult::interrupt();
        }
      }
    }

    return WalkResult::advance();
  });

  return success(!walkResult.wasInterrupted());
}

/// Annotate the IR with the result of the analysis. For testing/debugging
/// only.
static void
annotateOpsWithBufferizationMarkers(Operation *op,
                                    const OneShotAnalysisState &state) {
  // Add __inplace_operands_attr__.
  op->walk([&](Operation *op) {
    for (OpOperand &opOperand : op->getOpOperands())
      if (isa<TensorType>(opOperand.get().getType()))
        setInPlaceOpOperand(opOperand, state.isInPlace(opOperand));
  });
}

static void annotateOpsWithAliasSets(Operation *op,
                                     const OneShotAnalysisState &state) {
  AsmState asmState(op);
  Builder b(op->getContext());
  // Helper function to build an array attribute of aliasing SSA value strings.
  auto buildAliasesArray = [&](Value v) {
    SmallVector<Attribute> aliases;
    state.applyOnAliases(v, [&](Value alias) {
      std::string buffer;
      llvm::raw_string_ostream stream(buffer);
      alias.printAsOperand(stream, asmState);
      aliases.push_back(b.getStringAttr(buffer));
    });
    return b.getArrayAttr(aliases);
  };

  op->walk([&](Operation *op) {
    // Build alias set array for every OpResult.
    SmallVector<Attribute> opResultAliasSets;
    for (OpResult opResult : op->getOpResults()) {
      if (llvm::isa<TensorType>(opResult.getType())) {
        opResultAliasSets.push_back(buildAliasesArray(opResult));
      }
    }
    if (!opResultAliasSets.empty())
      op->setAttr(kOpResultAliasSetAttrName, b.getArrayAttr(opResultAliasSets));

    // Build alias set array for every BlockArgument.
    SmallVector<Attribute> regionAliasSets;
    bool hasTensorBbArg = false;
    for (Region &r : op->getRegions()) {
      SmallVector<Attribute> blockAliasSets;
      for (Block &block : r.getBlocks()) {
        SmallVector<Attribute> bbArgAliasSets;
        for (BlockArgument bbArg : block.getArguments()) {
          if (llvm::isa<TensorType>(bbArg.getType())) {
            bbArgAliasSets.push_back(buildAliasesArray(bbArg));
            hasTensorBbArg = true;
          }
        }
        blockAliasSets.push_back(b.getArrayAttr(bbArgAliasSets));
      }
      regionAliasSets.push_back(b.getArrayAttr(blockAliasSets));
    }
    if (hasTensorBbArg)
      op->setAttr(kBbArgAliasSetAttrName, b.getArrayAttr(regionAliasSets));
  });
}

LogicalResult bufferization::analyzeOp(Operation *op,
                                       OneShotAnalysisState &state,
                                       BufferizationStatistics *statistics) {
  DominanceInfo domInfo(op);
  const OneShotBufferizationOptions &options = state.getOptions();

  if (failed(checkPreBufferizationAssumptions(op, domInfo, state)))
    return failure();

  // If the analysis fails, just return.
  if (failed(state.analyzeOp(op, domInfo)))
    return failure();

  if (statistics) {
    statistics->numTensorInPlace = state.getStatNumTensorInPlace();
    statistics->numTensorOutOfPlace = state.getStatNumTensorOutOfPlace();
  }

  bool failedAnalysis = false;

  // Gather some extra analysis data.
  state.gatherUndefinedTensorUses(op);

  // Analysis verification: After setting up alias/equivalence sets, each op
  // can check for expected invariants/limitations and fail the analysis if
  // necessary.
  op->walk([&](Operation *op) {
    if (BufferizableOpInterface bufferizableOp =
            options.dynCastBufferizableOp(op))
      failedAnalysis |= failed(bufferizableOp.verifyAnalysis(state));
  });

  // Annotate operations if we only want to report the analysis.
  if (options.testAnalysisOnly)
    annotateOpsWithBufferizationMarkers(op, state);
  if (options.dumpAliasSets)
    annotateOpsWithAliasSets(op, state);

  return success(!failedAnalysis);
}

LogicalResult bufferization::runOneShotBufferize(
    Operation *op, const OneShotBufferizationOptions &options,
    BufferizationState &state, BufferizationStatistics *statistics) {
  // copy-before-write deactivates the analysis. It cannot be used together
  // with test-analysis-only.
  assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
         "invalid combination of bufferization flags");

  if (options.copyBeforeWrite) {
    // Copy buffer before each write. No analysis is needed.
  } else {
    // Run One-Shot Analysis and insert buffer copies (on the tensor level)
    // only where needed. This is the default and much more efficient than
    // copy-before-write.
    if (failed(insertTensorCopies(op, options, state, statistics)))
      return failure();

    // If test-analysis-only is set, the IR was annotated with RaW conflict
    // markers (attributes) during One-Shot Analysis.
    if (options.testAnalysisOnly)
      return success();
  }

  // Bufferize the op and its nested ops. If options.copyBeforeWrite is set,
  // a new buffer copy is allocated every time a buffer is written to.
  return bufferizeOp(op, options, state, statistics);
}
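
// Typical usage (a sketch; exact pass options may vary across MLIR versions):
// the analysis and bufferization above are exposed via the
// `-one-shot-bufferize` pass, e.g.:
//
//   mlir-opt input.mlir -one-shot-bufferize="test-analysis-only"
//
// which annotates the IR with `__inplace_operands_attr__` markers instead of
// bufferizing it.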