//===- OneShotAnalysis.cpp - One-Shot (Single Pass) Analysis -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// One-Shot Analysis analyzes function bodies. By default, function boundaries
// (FuncOp bbArgs, CallOps, ReturnOps) are treated as "unknown" ops.
// OneShotModuleBufferization.cpp is an extension of One-Shot Analysis for
// simple call graphs without loops.
//
// One-Shot Bufferize consists of three phases:
//
// 1. Analyze ops to decide which OpOperands can bufferize in-place, i.e.,
//    without inserting buffer copies. The analysis queries op bufferization
//    semantics via `BufferizableOpInterface`.
// 2. Insert copies for OpOperands that were decided to bufferize out-of-place
//    in tensor land during `TensorCopyInsertion`.
// 3. Bufferize ops by calling `BufferizableOpInterface::bufferize`.
//
// This file contains only the analysis. For convenience, it also contains a
// helper function `runOneShotBufferize` that analyzes an op (and its nested
// ops) and then bufferizes it.
//
// In-place bufferization decisions are passed from the analysis to the
// `TensorCopyInsertion` phase via `AnalysisState`. They can be printed for
// debugging purposes with `testAnalysisOnly`.
//
// Ops that do not implement `BufferizableOpInterface` can be analyzed but are
// treated conservatively. E.g., the analysis has to assume that their tensor
// OpOperands bufferize to memory writes. While such ops can be analyzed, they
// are not bufferized and remain in the IR. to_tensor and to_buffer ops are
// inserted at the bufferization boundary.
//
// This analysis caters to high-performance codegen where buffer reuse is
// deemed critical: the analysis should fail if the bufferized form of the
// function needs to return a buffer, unless `allowReturnAllocs` is enabled.
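//
// Typical usage (illustrative sketch; the exact pass option names may differ
// between MLIR versions): run only the analysis and annotate the IR with its
// decisions, or run the full analysis + bufferization pipeline:
//
//   mlir-opt --one-shot-bufferize="test-analysis-only" input.mlir
//   mlir-opt --one-shot-bufferize input.mlir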

#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"

#include <random>

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"

MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::OneShotAnalysisState)

// Run mlir-opt with `-debug-only="one-shot-analysis"` for detailed debug
// output.
#define DEBUG_TYPE "one-shot-analysis"

using namespace mlir;
using namespace mlir::bufferization;

static bool isaTensor(Type t) { return isa<TensorType>(t); }

//===----------------------------------------------------------------------===//
// Bufferization-specific attribute manipulation.
// These are for testing and debugging only. Bufferization information is
// stored in OneShotAnalysisState. When run with `testAnalysisOnly`, the IR is
// annotated with the results of the analysis, so that they can be checked in
// tests.
//===----------------------------------------------------------------------===//

/// Attribute marker to specify op operands that bufferize in-place.
constexpr StringLiteral kInPlaceOperandsAttrName = "__inplace_operands_attr__";

constexpr StringLiteral kOpResultAliasSetAttrName =
    "__opresult_alias_set_attr__";

constexpr StringLiteral kBbArgAliasSetAttrName = "__bbarg_alias_set_attr__";

/// Mark whether OpOperand will be bufferized inplace.
static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
  Operation *op = opOperand.getOwner();
  SmallVector<StringRef> inPlaceVector;
  if (auto attr = op->getAttr(kInPlaceOperandsAttrName)) {
    inPlaceVector = SmallVector<StringRef>(llvm::to_vector<4>(
        cast<ArrayAttr>(attr).getAsValueRange<StringAttr>()));
  } else {
    inPlaceVector = SmallVector<StringRef>(op->getNumOperands(), "none");
    for (OpOperand &opOperand : op->getOpOperands())
      if (isa<TensorType>(opOperand.get().getType()))
        inPlaceVector[opOperand.getOperandNumber()] = "false";
  }
  inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false";
  op->setAttr(kInPlaceOperandsAttrName,
              OpBuilder(op).getStrArrayAttr(inPlaceVector));
}
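
// For example (illustrative only; op and types are made up), an op whose
// second operand is a tensor that was decided to bufferize in-place could be
// annotated as:
//
//   %0 = "test.op"(%idx, %t) {__inplace_operands_attr__ = ["none", "true"]}
//       : (index, tensor<?xf32>) -> tensor<?xf32>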

//===----------------------------------------------------------------------===//
// OneShotAnalysisState
//===----------------------------------------------------------------------===//

OneShotAnalysisState::OneShotAnalysisState(
    Operation *op, const OneShotBufferizationOptions &options)
    : AnalysisState(options, TypeID::get<OneShotAnalysisState>()) {
  // Set up alias sets.
  op->walk([&](Operation *op) {
    for (Value v : op->getResults())
      if (isa<TensorType>(v.getType()))
        createAliasInfoEntry(v);
    for (Region &r : op->getRegions())
      for (Block &b : r.getBlocks())
        for (auto bbArg : b.getArguments())
          if (isa<TensorType>(bbArg.getType()))
            createAliasInfoEntry(bbArg);
  });

  // Mark OpOperands in-place that must bufferize in-place.
  op->walk([&](BufferizableOpInterface bufferizableOp) {
    if (!options.isOpAllowed(bufferizableOp))
      return WalkResult::skip();
    for (OpOperand &opOperand : bufferizableOp->getOpOperands())
      if (isa<TensorType>(opOperand.get().getType()))
        if (bufferizableOp.mustBufferizeInPlace(opOperand, *this))
          bufferizeInPlace(opOperand);
    return WalkResult::advance();
  });
}

void OneShotAnalysisState::applyOnEquivalenceClass(
    Value v, function_ref<void(Value)> fun) const {
  auto leaderIt = equivalentInfo.findLeader(v);
  for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
       ++mit) {
    fun(*mit);
  }
}

void OneShotAnalysisState::applyOnAliases(Value v,
                                          function_ref<void(Value)> fun) const {
  auto leaderIt = aliasInfo.findLeader(v);
  for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
    fun(*mit);
  }
}

bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1,
                                                         Value v2) const {
  return equivalentInfo.isEquivalent(v1, v2);
}

bool OneShotAnalysisState::areAliasingBufferizedValues(Value v1,
                                                       Value v2) const {
  return aliasInfo.isEquivalent(v1, v2);
}

void OneShotAnalysisState::bufferizeInPlace(OpOperand &operand) {
  if (inplaceBufferized.contains(&operand))
    return;
  inplaceBufferized.insert(&operand);
  for (AliasingValue alias : getAliasingValues(operand))
    aliasInfo.unionSets(alias.value, operand.get());
  ++statNumTensorInPlace;
}

void OneShotAnalysisState::bufferizeOutOfPlace(OpOperand &operand) {
  assert(!inplaceBufferized.contains(&operand) &&
         "OpOperand was already decided to bufferize inplace");
  ++statNumTensorOutOfPlace;
}

void OneShotAnalysisState::createAliasInfoEntry(Value v) {
  aliasInfo.insert(v);
  equivalentInfo.insert(v);
}
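
// Note (illustrative example): in the IR below, the contents of %0 are
// undefined because "bufferization.alloc_tensor" only allocates and does not
// define the buffer contents. All uses of %0 would be recorded as undefined
// tensor uses, which later allows copies of %0 to be elided.
//
//   %0 = bufferization.alloc_tensor() : tensor<5xf32>
//   %1 = tensor.insert %f into %0[%i] : tensor<5xf32>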
void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
  op->walk([&](Operation *op) {
    // Skip unknown ops.
    auto bufferizableOp = getOptions().dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return WalkResult::skip();

    // Check all tensor OpResults.
    for (OpResult opResult : op->getOpResults()) {
      if (!isa<TensorType>(opResult.getType()))
        continue;

      // If there is no preceding definition, the tensor contents are
      // undefined.
      if (opResult.getUses().empty())
        continue;
      // It does not matter which use we pick when searching for the value's
      // definitions.
      OpOperand *opOperand = &(*opResult.getUses().begin());
      if (findDefinitionsCached(opOperand).empty())
        for (OpOperand &use : opResult.getUses())
          undefinedTensorUses.insert(&use);
    }

    return WalkResult::advance();
  });
}

bool OneShotAnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  return undefinedTensorUses.contains(opOperand);
}

bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const {
  return inplaceBufferized.contains(&opOperand);
}

bool OneShotAnalysisState::isValueWritten(Value value) const {
  bool isWritten = false;
  applyOnAliases(value, [&](Value val) {
    for (OpOperand &use : val.getUses())
      if (isInPlace(use) && bufferizesToMemoryWrite(use))
        isWritten = true;
  });
  return isWritten;
}

bool OneShotAnalysisState::isWritable(Value value) const {
  // TODO: Out-of-place bufferized value could be considered writable.
  // Query BufferizableOpInterface to see if the BlockArgument is writable.
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(getOwnerOfValue(value)))
    return bufferizableOp.isWritable(value, *this);

  // Not a bufferizable op: The conservative answer is "not writable".
  return false;
}

void OneShotAnalysisState::unionAliasSets(Value v1, Value v2) {
  aliasInfo.unionSets(v1, v2);
}

void OneShotAnalysisState::unionEquivalenceClasses(Value v1, Value v2) {
  equivalentInfo.unionSets(v1, v2);
}

OneShotAnalysisState::Extension::~Extension() = default;

//===----------------------------------------------------------------------===//
// Bufferization-specific alias analysis.
//===----------------------------------------------------------------------===//

/// Return true if `opOperand` bufferizes to a memory write and has been
/// decided to bufferize in-place.
static bool isInplaceMemoryWrite(OpOperand &opOperand,
                                 const OneShotAnalysisState &state) {
  // OpOperands that do not bufferize to a memory write do not write in-place.
  if (!state.bufferizesToMemoryWrite(opOperand))
    return false;
  // Check current bufferization decisions.
  return state.isInPlace(opOperand);
}

/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
/// properly dominates `b` and `b` is not inside `a`.
static bool happensBefore(Operation *a, Operation *b,
                          const DominanceInfo &domInfo) {
  do {
    // TODO: Instead of isProperAncestor + properlyDominates, we should use
    // properlyDominatesImpl(a, b, /*enclosingOpOk=*/false)
    if (a->isProperAncestor(b))
      return false;
    if (domInfo.properlyDominates(a, b))
      return true;
  } while ((a = a->getParentOp()));
  return false;
}

/// Return `true` if op dominance can be used to rule out a read-after-write
/// conflict based on the ordering of ops. Returns `false` if op dominance
/// cannot be used due to region-based loops.
///
/// Generalized op dominance can often be used to rule out potential conflicts
/// due to "read happens before write". E.g., the following IR is not a RaW
/// conflict because the read happens *before* the write.
///
/// Example 1:
/// %0 = ... : tensor<?xf32>                                // DEF
/// "reading_op"(%0) : tensor<?xf32>                        // READ
/// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>  // WRITE
///
/// This is no longer true inside loops (or repetitive regions). In such cases,
/// there may not be a meaningful `happensBefore` relationship because ops
/// could be executed multiple times. E.g.:
///
/// Example 2:
/// %0 = ... : tensor<?xf32>                                  // DEF
/// scf.for ... {
///   "reading_op"(%0) : tensor<?xf32>                        // READ
///   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>  // WRITE
///   ...
/// }
///
/// In the above example, reading_op happens before writing_op according to
/// op dominance. However, both ops may happen multiple times; in
/// particular, the second execution of reading_op happens after the first
/// execution of writing_op. This is problematic because the tensor %0 they
/// operate on (i.e., the "definition") is defined outside of the loop.
///
/// On a high-level, there is a potential RaW in a program if there exists a
/// possible program execution such that there is a sequence of DEF, followed
/// by WRITE, followed by READ. Each additional DEF resets the sequence.
///
/// E.g.:
/// No conflict:        DEF, WRITE, DEF, READ
/// Potential conflict: DEF, READ, WRITE, READ, WRITE
///
/// Example 1 has no conflict:          DEF, READ, WRITE
/// Example 2 has a potential conflict: DEF, (READ, WRITE)*
///
/// Example 3:
/// scf.for ... {
///   %0 = ... : tensor<?xf32>
///   "reading_op"(%0) : tensor<?xf32>
///   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
///   ...
/// }
/// This has no conflict: (DEF, READ, WRITE)*
///
/// Example 4:
/// %0 = ... : tensor<?xf32>
/// scf.for ... {
///   scf.for ... { "reading_op"(%0) }
///   %1 = "writing_op"(%0)
/// }
/// This has a potential conflict: DEF, ((READ)*, WRITE)*
///
/// Example 5:
/// %0 = ... : tensor<?xf32>
/// scf.for ... { %1 = "writing_op"(%0) }
/// scf.for ... { "reading_op"(%0) }
/// This has a potential conflict: DEF, WRITE*, READ*
///
/// The following rules are used to rule out RaW conflicts via ordering of ops:
///
/// 1. If the closest enclosing repetitive region of DEF is a proper ancestor
///    of a repetitive region that encloses both READ and WRITE, we cannot rule
///    out a RaW conflict due to the ordering of ops.
/// 2. Otherwise: There are no loops that interfere with our analysis; for
///    analysis purposes, we can assume that there are no loops/repetitive
///    regions. I.e., we can rule out a RaW conflict if READ happensBefore
///    WRITE or WRITE happensBefore DEF. (Checked in
///    `hasReadAfterWriteInterference`.)
///
static bool canUseOpDominanceDueToRegions(OpOperand *uRead, OpOperand *uWrite,
                                          const SetVector<Value> &definitions,
                                          AnalysisState &state) {
  const BufferizationOptions &options = state.getOptions();
  for (Value def : definitions) {
    Region *rRead =
        state.getEnclosingRepetitiveRegion(uRead->getOwner(), options);
    Region *rDef = state.getEnclosingRepetitiveRegion(def, options);

    // READ and DEF are in the same repetitive region. `happensBefore` can be
    // used to rule out RaW conflicts due to op ordering.
    if (rRead == rDef)
      continue;

    // Find the enclosing repetitive region of READ that is closest to DEF but
    // not the repetitive region of DEF itself.
    while (true) {
      Region *nextRegion = getNextEnclosingRepetitiveRegion(rRead, options);
      if (nextRegion == rDef)
        break;
      assert(nextRegion && "expected to find another repetitive region");
      rRead = nextRegion;
    }

    // We cannot use op dominance if WRITE is inside the same repetitive region.
    if (rRead->getParentOp()->isAncestor(uWrite->getOwner()))
      return false;
  }

  return true;
}

/// Return `true` if op dominance can be used to rule out a read-after-write
/// conflict based on the ordering of ops. Returns `false` if op dominance
/// cannot be used due to block-based loops within a region.
///
/// Refer to the `canUseOpDominanceDueToRegions` documentation for details on
/// how op dominance is used during RaW conflict detection.
///
/// On a high-level, there is a potential RaW in a program if there exists a
/// possible program execution such that there is a sequence of DEF, followed
/// by WRITE, followed by READ. Each additional DEF resets the sequence.
///
/// Op dominance cannot be used if there is a path from block(READ) to
/// block(WRITE) and a path from block(WRITE) to block(READ). block(DEF) should
/// not appear on that path.
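///
/// Example (illustrative, using unstructured control flow): READ and WRITE
/// sit in a block-based loop and are mutually reachable without passing
/// through block(DEF), so op dominance cannot be used:
///
///   %0 = ... : tensor<?xf32>     // DEF (in the entry block)
///   cf.br ^bb1
/// ^bb1:
///   "reading_op"(%0) ...         // READ
///   %1 = "writing_op"(%0) ...    // WRITE
///   cf.cond_br %c, ^bb1, ^bb2
/// ^bb2:
///   ...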
static bool canUseOpDominanceDueToBlocks(OpOperand *uRead, OpOperand *uWrite,
                                         const SetVector<Value> &definitions,
                                         AnalysisState &state) {
  // Fast path: If READ and WRITE are in different regions, their blocks cannot
  // be reachable just via unstructured control flow. (Loops due to regions are
  // covered by `canUseOpDominanceDueToRegions`.)
  if (uRead->getOwner()->getParentRegion() !=
      uWrite->getOwner()->getParentRegion())
    return true;

  Block *readBlock = uRead->getOwner()->getBlock();
  Block *writeBlock = uWrite->getOwner()->getBlock();
  for (Value def : definitions) {
    Block *defBlock = def.getParentBlock();
    if (readBlock->isReachable(writeBlock, {defBlock}) &&
        writeBlock->isReachable(readBlock, {defBlock}))
      return false;
  }

  return true;
}

static bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite,
                              const SetVector<Value> &definitions,
                              AnalysisState &state) {
  return canUseOpDominanceDueToRegions(uRead, uWrite, definitions, state) &&
         canUseOpDominanceDueToBlocks(uRead, uWrite, definitions, state);
}

/// Annotate IR with details about the detected RaW conflict.
static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
                             Value definition) {
  static uint64_t counter = 0;
  Operation *readingOp = uRead->getOwner();
  Operation *conflictingWritingOp = uConflictingWrite->getOwner();

  OpBuilder b(conflictingWritingOp->getContext());
  std::string id = "C_" + std::to_string(counter++);

  std::string conflictingWriteAttr =
      id +
      "[CONFL-WRITE: " + std::to_string(uConflictingWrite->getOperandNumber()) +
      "]";
  conflictingWritingOp->setAttr(conflictingWriteAttr, b.getUnitAttr());

  std::string readAttr =
      id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]";
  readingOp->setAttr(readAttr, b.getUnitAttr());

  if (auto opResult = dyn_cast<OpResult>(definition)) {
    std::string defAttr =
        id + "[DEF: result " + std::to_string(opResult.getResultNumber()) + "]";
    opResult.getDefiningOp()->setAttr(defAttr, b.getUnitAttr());
  } else {
    auto bbArg = cast<BlockArgument>(definition);
    std::string defAttr =
        id + "[DEF: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
    bbArg.getOwner()->getParentOp()->setAttr(defAttr, b.getUnitAttr());
  }
}

/// Return 'true' if a tensor that is equivalent to `other` can be found in the
/// reverse use-def chain of `start`. Note: If an OpOperand bufferizes out of
/// place along that use-def chain, the two tensors may not materialize as
/// equivalent buffers (but separate allocations).
///
/// Note: This function also requires that the two tensors have equivalent
/// indexing. I.e., the tensor types do not change along the use-def chain,
/// apart from static <-> dynamic dim casts.
static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
                                                   OpOperand *start,
                                                   Value other) {
  TraversalConfig config;
  config.followEquivalentOnly = true;
  config.alwaysIncludeLeaves = false;
  config.followSameTypeOrCastsOnly = true;
  return !state
              .findValueInReverseUseDefChain(
                  start, [&](Value v) { return v == other; }, config)
              .empty();
}

/// Return "true" if the given operand's value originates from a subset that is
/// equivalent to the subset that `subsetOp` inserts into.
static bool matchesInsertDestination(const AnalysisState &state,
                                     OpOperand *opOperand,
                                     SubsetInsertionOpInterface subsetOp) {
  auto matchingSubset = [&](Value val) {
    if (auto opResult = dyn_cast<OpResult>(val))
      if (subsetOp.isEquivalentSubset(opResult, [&](Value v1, Value v2) {
            return state.areEquivalentBufferizedValues(v1, v2);
          }))
        return true;
    return false;
  };
  // There may be multiple leaves at which the reverse SSA use-def chain lookup
  // terminates. All of them must be equivalent subsets.
  SetVector<Value> backwardSlice =
      state.findValueInReverseUseDefChain(opOperand, matchingSubset);
  return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset));
}

/// Return "true" if the given "read" and potentially conflicting "write" are
/// not conflicting due to their subset relationship. The comments in this
/// function are expressed in terms of tensor.extract_slice/tensor.insert_slice
/// pairs, but apply to any subset ops that implement the
/// `SubsetInsertionOpInterface`.
static bool areNonConflictingSubsets(OpOperand *uRead,
                                     OpOperand *uConflictingWrite,
                                     const AnalysisState &state) {
  Operation *readingOp = uRead->getOwner();
  Operation *conflictingWritingOp = uConflictingWrite->getOwner();

  // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
  // uRead is an InsertSliceOp...
  if (auto subsetOp = dyn_cast<SubsetInsertionOpInterface>(readingOp)) {
    // As an example, consider the following IR.
    //
    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
    // %1 = linalg.fill %cst, %0 {inplace = [true] }
    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
    //     {inplace = [true] }

    if (uRead == &subsetOp.getDestinationOperand() &&
        matchesInsertDestination(state, uConflictingWrite, subsetOp))
      // Case 1: The main insight is that InsertSliceOp reads only part of
      // the destination tensor. The overwritten area is not read. If
      // uConflictingWrite writes into exactly the memory location that is
      // being read by uRead, this is not a conflict.
      //
      // In the above example:
      // uRead             = OpOperand 1 (%t) of tensor.insert_slice
      // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
      //
      // The read of %t does not conflict with the write of the FillOp
      // (same aliases!) because the area that the FillOp operates on is
      // exactly the one that is *not* read via %t.
      return true;

    if (uRead == &subsetOp.getSourceOperand() &&
        uConflictingWrite == &subsetOp.getDestinationOperand() &&
        matchesInsertDestination(state, uRead, subsetOp))
      // Case 2: The read of the source tensor and the write to the dest
      // tensor via an InsertSliceOp is not a conflict if the read is
      // reading exactly that part of an equivalent tensor that the
      // InsertSliceOp is writing.
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of tensor.insert_slice
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      return true;
  }

  // If uConflictingWrite is an InsertSliceOp...
  if (auto subsetOp =
          dyn_cast<SubsetInsertionOpInterface>(conflictingWritingOp))
    // As an example, consider the following IR.
    //
    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
    // %1 = linalg.fill %cst, %0 {inplace = [true] }
    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
    //     {inplace = [true] }
    // %3 = vector.transfer_read %1, %cst
    //
    // In the above example:
    // uRead             = OpOperand 0 (%1) of vector.transfer_read
    // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
    // definition        = %1
    //
    // This is not a conflict because the InsertSliceOp overwrites the
    // memory segment of %1 with the exact same data. (Effectively, there
    // is no memory write here.)
    if (uConflictingWrite == &subsetOp.getDestinationOperand() &&
        state.areEquivalentBufferizedValues(
            uRead->get(), subsetOp.getSourceOperand().get()) &&
        matchesInsertDestination(state, &subsetOp.getSourceOperand(), subsetOp))
      return true;

  return false;
}

/// Given sets of uses and writes, return true if there is a RaW conflict under
/// the assumption that all given reads/writes alias the same buffer and that
/// all given writes bufferize inplace.
///
/// A conflict is: According to SSA use-def chains, a read R is supposed to read
/// the result of a definition W1. But because of bufferization decisions, R
/// actually reads another definition W2.
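///
/// Example (illustrative):
///
///   %t = "def_op"() ...          // DEF (W1)
///   %w = "writing_op"(%t) ...    // WRITE (W2)
///   "reading_op"(%t) ...         // READ
///
/// According to SSA use-def chains, reading_op should observe the value
/// defined by def_op. If writing_op bufferizes in-place, reading_op would
/// observe the value written by writing_op instead: a RaW conflict.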
static bool
hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
                              const DenseSet<OpOperand *> &usesWrite,
                              const DominanceInfo &domInfo,
                              OneShotAnalysisState &state) {
  const BufferizationOptions &options = state.getOptions();

  // Before going through the main RaW analysis, find cases where a buffer must
  // be privatized due to parallelism. If the result of a write is never read,
  // privatization is not necessary (and large parts of the IR are likely dead).
  if (options.checkParallelRegions && !usesRead.empty()) {
    for (OpOperand *uConflictingWrite : usesWrite) {
      // Find the allocation point or last write (definition) of the buffer.
      // Note: In contrast to `findDefinitions`, this also returns results of
      // ops that do not bufferize to memory write when no other definition
      // could be found. E.g., "bufferization.alloc_tensor" would be included,
      // even though that op just bufferizes to an allocation but does not
      // define the contents of the buffer.
      SetVector<Value> definitionsOrLeaves =
          state.findValueInReverseUseDefChain(uConflictingWrite, [&](Value v) {
            return state.bufferizesToMemoryWrite(v);
          });
      assert(!definitionsOrLeaves.empty() &&
             "expected at least one definition or leaf");

      // The writing op must bufferize out-of-place if the definition is in a
      // different parallel region than this write.
      for (Value def : definitionsOrLeaves) {
        if (getParallelRegion(def.getParentRegion(), options) !=
            getParallelRegion(uConflictingWrite->getOwner()->getParentRegion(),
                              options)) {
          LLVM_DEBUG(
              llvm::dbgs()
              << "\n- bufferizes out-of-place due to parallel region:\n");
          LLVM_DEBUG(llvm::dbgs()
                     << "  uConflictingWrite = operand "
                     << uConflictingWrite->getOperandNumber() << " of "
                     << *uConflictingWrite->getOwner() << "\n");
          return true;
        }
      }
    }
  }

  for (OpOperand *uRead : usesRead) {
    Operation *readingOp = uRead->getOwner();
    LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n");
    LLVM_DEBUG(llvm::dbgs() << "  uRead = operand " << uRead->getOperandNumber()
                            << " of " << *readingOp << "\n");

    // Find the definition of uRead by following the SSA use-def chain.
    // E.g.:
    //
    // %0 = "writing_op"(%t) : tensor<?xf32> -> tensor<?xf32>
    // %1 = "aliasing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
    // %2 = "reading_op"(%1) : tensor<?xf32> -> not_a_tensor_type
    //
    // In the above example, if uRead is the OpOperand of reading_op, the
    // definition is %0. Note that operations that create an alias but do not
    // bufferize to a memory write (such as ExtractSliceOp) are skipped.
    const SetVector<Value> &definitions = state.findDefinitionsCached(uRead);
    if (definitions.empty()) {
      // Fast path: No conflict if there are no definitions.
      LLVM_DEBUG(llvm::dbgs()
                 << "  no conflict: read value has no definitions\n");
      continue;
    }

    // Look for conflicting memory writes. Potential conflicts are writes to an
    // alias that have been decided to bufferize inplace.
    for (OpOperand *uConflictingWrite : usesWrite) {
      LLVM_DEBUG(llvm::dbgs() << "  uConflictingWrite = operand "
                              << uConflictingWrite->getOperandNumber() << " of "
                              << *uConflictingWrite->getOwner() << "\n");

      // Check if op dominance can be used to rule out read-after-write
      // conflicts.
      bool useDominance =
          canUseOpDominance(uRead, uConflictingWrite, definitions, state);
      LLVM_DEBUG(llvm::dbgs() << "\n- useDominance = " << useDominance << "\n");

      // Throughout this loop, check for multiple requirements that have to be
      // met for uConflictingWrite to be an actual conflict.
      Operation *conflictingWritingOp = uConflictingWrite->getOwner();

      // Inside of repetitive regions, ops may be executed multiple times and op
      // dominance cannot be used to rule out conflicts.
      if (useDominance) {
        // No conflict if the readingOp dominates conflictingWritingOp, i.e.,
        // the write is not visible when reading.
        //
        // Note: If ops are executed multiple times (e.g., because they are
        // inside a loop), there may be no meaningful `happensBefore`
        // relationship.
        if (happensBefore(readingOp, conflictingWritingOp, domInfo)) {
          LLVM_DEBUG(llvm::dbgs()
                     << "  no conflict: read happens before write\n");
          continue;
        }

        // No conflict if the reading use equals the use of the conflicting
        // write. A use cannot conflict with itself.
        //
        // Note: Just being the same op is not enough. It has to be the same
        // use.
        // Note: If the op is executed multiple times (e.g., because it is
        // inside a loop), it may be conflicting with itself.
        if (uConflictingWrite == uRead) {
          LLVM_DEBUG(llvm::dbgs()
                     << "  no conflict: read and write are same use\n");
          continue;
        }

        // Ops are not conflicting if they are in mutually exclusive regions.
        //
        // Note: If ops are executed multiple times (e.g., because they are
        // inside a loop), mutually exclusive regions may be executed
        // multiple times.
        if (state.insideMutuallyExclusiveRegions(readingOp,
                                                 conflictingWritingOp)) {
          LLVM_DEBUG(llvm::dbgs() << "  no conflict: read and write are in "
                                     "mutually exclusive regions\n");
          continue;
        }

        // Two equivalent operands of the same op are not conflicting if the op
        // bufferizes to element-wise access. I.e., all loads at a position
        // happen before all stores to the same position.
        if (conflictingWritingOp == readingOp) {
          if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
            if (bufferizableOp.bufferizesToElementwiseAccess(
                    state, {uRead, uConflictingWrite})) {
              if (hasEquivalentValueInReverseUseDefChain(
                      state, uRead, uConflictingWrite->get()) ||
                  hasEquivalentValueInReverseUseDefChain(
                      state, uConflictingWrite, uRead->get())) {
                LLVM_DEBUG(
                    llvm::dbgs()
                    << "  no conflict: op bufferizes to element-wise access\n");
                continue;
              }
            }
          }
        }
      }

      // No conflict if the operands are non-conflicting subsets.
      if (areNonConflictingSubsets(uRead, uConflictingWrite, state)) {
        LLVM_DEBUG(llvm::dbgs() << "  no conflict: non-conflicting subsets\n");
        continue;
      }

      // No conflict if the op interface says so.
      if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
        if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) {
          LLVM_DEBUG(llvm::dbgs()
                     << "  no conflict: op interface of reading op says 'no'\n");
          continue;
        }
      }

      if (conflictingWritingOp != readingOp) {
        if (auto bufferizableOp =
                options.dynCastBufferizableOp(conflictingWritingOp)) {
          if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
                                              state)) {
            LLVM_DEBUG(
                llvm::dbgs()
                << "  no conflict: op interface of writing op says 'no'\n");
            continue;
          }
        }
      }

      // Check all possible definitions.
      for (Value definition : definitions) {
        LLVM_DEBUG(llvm::dbgs() << "  * definition = " << definition << "\n");

        // No conflict if the conflicting write happens before the definition.
        if (Operation *defOp = definition.getDefiningOp()) {
          if (happensBefore(conflictingWritingOp, defOp, domInfo)) {
            // conflictingWritingOp happens before defOp. No conflict.
            LLVM_DEBUG(llvm::dbgs()
                       << "  no conflict: write happens before definition\n");
            continue;
          }
          // No conflict if conflictingWritingOp is contained in defOp.
          if (defOp->isProperAncestor(conflictingWritingOp)) {
            LLVM_DEBUG(
                llvm::dbgs()
                << "  no conflict: write is contained in definition\n");
            continue;
          }
        } else {
          auto bbArg = cast<BlockArgument>(definition);
          Block *block = bbArg.getOwner();
          if (!block->findAncestorOpInBlock(*conflictingWritingOp)) {
            LLVM_DEBUG(llvm::dbgs() << "  no conflict: definition is bbArg "
                                       "and write happens outside of block\n");
            // conflictingWritingOp happens outside of the block. No
            // conflict.
            continue;
          }
        }

        // No conflict if the conflicting write and the definition are the same
        // use.
        AliasingValueList aliases = state.getAliasingValues(*uConflictingWrite);
        if (aliases.getNumAliases() == 1 &&
            aliases.getAliases()[0].value == definition) {
          LLVM_DEBUG(llvm::dbgs()
                     << "  no conflict: definition and write are same\n");
          continue;
        }

        // All requirements are met. Conflict found!

        if (options.printConflicts)
          annotateConflict(uRead, uConflictingWrite, definition);
        LLVM_DEBUG(llvm::dbgs() << "  => RaW CONFLICT FOUND\n");
        return true;
      }
    }
  }

  return false;
}

// Helper function to iterate on aliases of `root` and capture the writes.
static void getAliasingInplaceWrites(DenseSet<OpOperand *> &res, Value root,
                                     const OneShotAnalysisState &state) {
  state.applyOnAliases(root, [&](Value alias) {
    for (auto &use : alias.getUses())
      // Inplace write to a value that aliases root.
      if (isInplaceMemoryWrite(use, state))
        res.insert(&use);
  });
}

// Helper function to iterate on aliases of `root` and capture the reads.
static void getAliasingReads(DenseSet<OpOperand *> &res, Value root,
                             const OneShotAnalysisState &state) {
  state.applyOnAliases(root, [&](Value alias) {
    for (auto &use : alias.getUses()) {
      // Read of a value that aliases root.
      if (state.bufferizesToMemoryRead(use)) {
        res.insert(&use);
        continue;
      }

      // Read of a dependent value in the SSA use-def chain. E.g.:
      //
      // %0 = ...
      // %1 = tensor.extract_slice %0 {not_analyzed_yet}
      // "read"(%1)
      //
      // In the above example, getAliasingReads(%0) includes the first OpOperand
      // of the tensor.extract_slice op. The extract_slice itself does not read
      // but its aliasing result is eventually fed into an op that does.
      //
      // Note: This is considered a "read" only if the use does not bufferize to
      // a memory write. (We already ruled out memory reads. In case of a memory
      // write, the buffer would be entirely overwritten; in the above example
      // there would then be no flow of data from the extract_slice operand to
      // its result's uses.)
      if (!state.bufferizesToMemoryWrite(use)) {
        AliasingValueList aliases = state.getAliasingValues(use);
        if (llvm::any_of(aliases, [&](AliasingValue a) {
              return state.isValueRead(a.value);
            }))
          res.insert(&use);
      }
    }
  });
}

/// Return true if bufferizing `operand` inplace would create a conflict. A read
/// R and a write W of the same alias set constitute a conflict if inplace
/// bufferization of W changes the value read by R to a value different from the
/// one that would be expected by tracing back R's origin through SSA use-def
/// chains. A conflict can only be introduced by a new alias and/or an inplace
/// bufferization decision.
///
/// Example:
/// %0 = tensor.extract_slice %t[...][...][1, 1] {inplace?}
/// %1 = vector.transfer_write %v1, %t {inplace} : vector<5xf32>, tensor<?xf32>
/// %e = tensor.extract_slice %1
/// %2 = vector.transfer_write %v2, %0 {inplace} : vector<6xf32>, tensor<?xf32>
/// %3 = vector.transfer_read %e, %cst : tensor<?xf32>, vector<7xf32>
///
/// In the above example, the two TransferWriteOps have already been decided to
/// bufferize inplace. Bufferizing the ExtractSliceOp inplace would create a
/// conflict because:
/// * According to SSA use-def chains, we expect to read the result of %1.
/// * However, adding an alias {%0, %t} would mean that the second
///   TransferWriteOp overwrites the result of the first one. Therefore, the
///   TransferReadOp would no longer be reading the result of %1.
///
/// If `checkConsistencyOnly` is true, this function checks if there is a
/// read-after-write conflict without bufferizing `operand` inplace. This would
/// indicate a problem with the current inplace bufferization decisions.
///
/// Note: If `checkConsistencyOnly`, this function may be called with a null
/// OpResult. In that case, only the consistency of bufferization decisions
/// involving aliases of the given OpOperand are checked.
static bool wouldCreateReadAfterWriteInterference(
    OpOperand &operand, const DominanceInfo &domInfo,
    OneShotAnalysisState &state, bool checkConsistencyOnly = false) {
  // Collect reads and writes of all aliases of OpOperand and OpResult.
  DenseSet<OpOperand *> usesRead, usesWrite;
  getAliasingReads(usesRead, operand.get(), state);
  getAliasingInplaceWrites(usesWrite, operand.get(), state);
  for (AliasingValue alias : state.getAliasingValues(operand)) {
    getAliasingReads(usesRead, alias.value, state);
    getAliasingInplaceWrites(usesWrite, alias.value, state);
  }
  if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
    usesWrite.insert(&operand);

  return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state);
}

/// Annotate IR with details about the detected non-writability conflict.
static void annotateNonWritableTensor(Value value) {
  static int64_t counter = 0;
  OpBuilder b(value.getContext());
  std::string id = "W_" + std::to_string(counter++);
  if (auto opResult = dyn_cast<OpResult>(value)) {
    std::string attr = id + "[NOT-WRITABLE: result " +
                       std::to_string(opResult.getResultNumber()) + "]";
    opResult.getDefiningOp()->setAttr(attr, b.getUnitAttr());
  } else {
    auto bbArg = cast<BlockArgument>(value);
    std::string attr = id + "[NOT-WRITABLE: bbArg " +
                       std::to_string(bbArg.getArgNumber()) + "]";
    bbArg.getOwner()->getParentOp()->setAttr(attr, b.getUnitAttr());
  }
}

/// Return true if bufferizing `operand` inplace would create a write to a
/// non-writable buffer.
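///
/// Example (illustrative): constants are not writable, so an in-place write
/// into %t below would have to be materialized as a write into a copy:
///
///   %t = arith.constant dense<0.0> : tensor<4xf32>   // not writable
///   %0 = "writing_op"(%t) ...                        // must copy %t first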
static bool
wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
                                    OneShotAnalysisState &state,
                                    bool checkConsistencyOnly = false) {
  bool foundWrite =
      !checkConsistencyOnly && state.bufferizesToMemoryWrite(operand);

  if (!foundWrite) {
    // Collect writes of all aliases of OpOperand and OpResult.
    DenseSet<OpOperand *> usesWrite;
    getAliasingInplaceWrites(usesWrite, operand.get(), state);
    for (AliasingValue alias : state.getAliasingValues(operand))
      getAliasingInplaceWrites(usesWrite, alias.value, state);
    foundWrite = !usesWrite.empty();
  }

  if (!foundWrite)
    return false;

  // Look for a read-only tensor among all aliases.
  bool foundReadOnly = false;
  auto checkReadOnly = [&](Value v) {
    if (!state.isWritable(v)) {
      foundReadOnly = true;
      if (state.getOptions().printConflicts)
        annotateNonWritableTensor(v);
    }
  };
  state.applyOnAliases(operand.get(), checkReadOnly);
  for (AliasingValue alias : state.getAliasingValues(operand))
    state.applyOnAliases(alias.value, checkReadOnly);
  if (foundReadOnly) {
    LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n");
    return true;
  }

  return false;
}

//===----------------------------------------------------------------------===//
// Bufferization analyses.
//===----------------------------------------------------------------------===//

// Find the values that define the contents of the given operand's value.
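//
// Example (illustrative):
//
//   %0 = "writing_op"(%t) ...            // bufferizes to a memory write
//   %1 = tensor.extract_slice %0 [...]   // aliases %0, but does not write
//   "reading_op"(%1) ...
//
// The definitions of the reading_op operand are {%0}: the extract_slice is
// skipped because it does not bufferize to a memory write.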
const llvm::SetVector<Value> &
OneShotAnalysisState::findDefinitionsCached(OpOperand *opOperand) {
  Value value = opOperand->get();
  if (!cachedDefinitions.count(value))
    cachedDefinitions[value] = findDefinitions(opOperand);
  return cachedDefinitions[value];
}

void OneShotAnalysisState::resetCache() {
  AnalysisState::resetCache();
  cachedDefinitions.clear();
}

/// Determine if `operand` can be bufferized in-place.
static LogicalResult
bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state,
                                const DominanceInfo &domInfo) {
  LLVM_DEBUG(
      llvm::dbgs() << "//===-------------------------------------------===//\n"
                   << "Analyzing operand #" << operand.getOperandNumber()
                   << " of " << *operand.getOwner() << "\n");

  bool foundInterference =
      wouldCreateWriteToNonWritableBuffer(operand, state) ||
      wouldCreateReadAfterWriteInterference(operand, domInfo, state);

  if (foundInterference)
    state.bufferizeOutOfPlace(operand);
  else
    state.bufferizeInPlace(operand);

  LLVM_DEBUG(llvm::dbgs()
             << "//===-------------------------------------------===//\n");
  return success();
}

LogicalResult
OneShotAnalysisState::analyzeSingleOp(Operation *op,
                                      const DominanceInfo &domInfo) {
  for (OpOperand &opOperand : op->getOpOperands())
    if (isa<TensorType>(opOperand.get().getType()))
      if (failed(bufferizableInPlaceAnalysisImpl(opOperand, *this, domInfo)))
        return failure();
  return success();
}

/// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
static void equivalenceAnalysis(SmallVector<Operation *> &ops,
                                OneShotAnalysisState &state) {
  for (Operation *op : ops) {
    if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) {
      for (OpResult opResult : op->getOpResults()) {
        if (!isa<TensorType>(opResult.getType()))
          continue;
        AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
        if (aliases.getNumAliases() == 0)
          // Nothing to do if there are no aliasing OpOperands.
          continue;

        Value firstOperand = aliases.begin()->opOperand->get();
        bool allEquivalent = true;
        for (AliasingOpOperand alias : aliases) {
          bool isEquiv = alias.relation == BufferRelation::Equivalent;
          bool isInPlace = state.isInPlace(*alias.opOperand);
          Value operand = alias.opOperand->get();
          if (isEquiv && isInPlace && alias.isDefinite) {
            // Found a definite, equivalent alias. Merge equivalence sets.
            // There can only be one definite alias, so we can stop here.
            state.unionEquivalenceClasses(opResult, operand);
            allEquivalent = false;
            break;
          }
          if (!isEquiv || !isInPlace)
            allEquivalent = false;
          if (!state.areEquivalentBufferizedValues(operand, firstOperand))
            allEquivalent = false;
        }

        // If all "maybe" aliases are equivalent and the OpResult is not a new
        // allocation, it is a definite, equivalent alias. E.g.:
        //
        // aliasingOpOperands(%r) = {(%t0, EQUIV, MAYBE), (%t1, EQUIV, MAYBE)}
        // aliasingValues(%t0) = {(%r, EQUIV, MAYBE)}
        // aliasingValues(%t1) = {(%r, EQUIV, MAYBE)}
        // %r = arith.select %c, %t0, %t1 : tensor<?xf32>
        //
        // If %t0 and %t1 are equivalent, it is safe to union the equivalence
        // classes of %r, %t0 and %t1.
        if (allEquivalent && !bufferizableOp.bufferizesToAllocation(opResult))
          state.unionEquivalenceClasses(opResult, firstOperand);
      }
    }
  }
}

/// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained
/// in `op`.
static void equivalenceAnalysis(Operation *op, OneShotAnalysisState &state) {
  // Traverse ops in PostOrder: Nested ops first, then enclosing ops.
  SmallVector<Operation *> ops;
  op->walk<WalkOrder::PostOrder>([&](Operation *op) {
    // No tensors => no buffers.
    if (none_of(op->getResultTypes(), isaTensor))
      return;
    ops.push_back(op);
  });

  equivalenceAnalysis(ops, state);
}

/// "Bottom-up from terminators" heuristic: ops that feed into region
/// terminators are analyzed first, following the reverse SSA use-def chains of
/// the yielded values; all remaining ops with tensor semantics are analyzed
/// afterwards.
static SmallVector<Operation *>
bottomUpFromTerminatorsHeuristic(Operation *op,
                                 const OneShotAnalysisState &state) {
  SetVector<Operation *> traversedOps;

  // Find region terminators.
  op->walk<WalkOrder::PostOrder>([&](RegionBranchTerminatorOpInterface term) {
    if (!traversedOps.insert(term))
      return;
    // Follow the reverse SSA use-def chain from each yielded value as long as
    // we stay within the same region.
    SmallVector<OpResult> worklist;
    for (Value v : term->getOperands()) {
      if (!isa<TensorType>(v.getType()))
        continue;
      auto opResult = dyn_cast<OpResult>(v);
      if (!opResult)
        continue;
      worklist.push_back(opResult);
    }
    while (!worklist.empty()) {
      OpResult opResult = worklist.pop_back_val();
      Operation *defOp = opResult.getDefiningOp();
      if (!traversedOps.insert(defOp))
        continue;
      if (!term->getParentRegion()->findAncestorOpInRegion(*defOp))
        continue;
      AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
      for (auto alias : aliases) {
        Value v = alias.opOperand->get();
        if (!isa<TensorType>(v.getType()))
          continue;
        auto opResult = dyn_cast<OpResult>(v);
        if (!opResult)
          continue;
        worklist.push_back(opResult);
      }
    }
  });

  // Analyze traversed ops, then all remaining ops.
  SmallVector<Operation *> result(traversedOps.begin(), traversedOps.end());
  op->walk<WalkOrder::PostOrder, ReverseIterator>([&](Operation *op) {
    if (!traversedOps.contains(op) && hasTensorSemantics(op))
      result.push_back(op);
  });
  return result;
}

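// Usage sketch (illustrative; `rootOp` is a placeholder name): the analysis
// heuristic used below is selected via `OneShotBufferizationOptions` before
// constructing the analysis state, e.g.:
//
//   OneShotBufferizationOptions options;
//   options.analysisHeuristic =
//       OneShotBufferizationOptions::AnalysisHeuristic::BottomUpFromTerminators;
//   OneShotAnalysisState state(rootOp, options);
//   DominanceInfo domInfo(rootOp);
//   if (failed(state.analyzeOp(rootOp, domInfo)))
//     return failure();
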
LogicalResult OneShotAnalysisState::analyzeOp(Operation *op,
                                              const DominanceInfo &domInfo) {
  OneShotBufferizationOptions::AnalysisHeuristic heuristic =
      getOptions().analysisHeuristic;

  SmallVector<Operation *> orderedOps;
  if (heuristic ==
      OneShotBufferizationOptions::AnalysisHeuristic::BottomUpFromTerminators) {
    orderedOps = bottomUpFromTerminatorsHeuristic(op, *this);
  } else {
    op->walk([&](Operation *op) {
      // No tensors => no buffers.
      if (!hasTensorSemantics(op))
        return;
      orderedOps.push_back(op);
    });
    switch (heuristic) {
    case OneShotBufferizationOptions::AnalysisHeuristic::BottomUp: {
      // Default: Walk ops in reverse for better interference analysis.
      std::reverse(orderedOps.begin(), orderedOps.end());
      break;
    }
    case OneShotBufferizationOptions::AnalysisHeuristic::TopDown: {
      // Ops are already sorted top-down in `orderedOps`.
      break;
    }
    case OneShotBufferizationOptions::AnalysisHeuristic::Fuzzer: {
      assert(getOptions().analysisFuzzerSeed &&
             "expected that fuzzer seed is set");
      // This is a fuzzer. For testing purposes only. Randomize the order in
      // which operations are analyzed. The bufferization quality is likely
      // worse, but we want to make sure that no assertions are triggered
      // anywhere.
      std::mt19937 g(getOptions().analysisFuzzerSeed);
      llvm::shuffle(orderedOps.begin(), orderedOps.end(), g);
      break;
    }
    default: {
      llvm_unreachable("unsupported heuristic");
    }
    }
  }

  // Analyze ops in the computed order.
  for (Operation *op : orderedOps)
    if (failed(analyzeSingleOp(op, domInfo)))
      return failure();

  equivalenceAnalysis(op, *this);
  return success();
}

/// Perform various checks on the input IR to see if it contains IR constructs
/// that are unsupported by One-Shot Bufferize.
static LogicalResult
checkPreBufferizationAssumptions(Operation *op, const DominanceInfo &domInfo,
                                 OneShotAnalysisState &state) {
  const BufferizationOptions &options = state.getOptions();

  // Note: This walk cannot be combined with the one below because interface
  // methods of invalid/unsupported ops may be called during the second walk.
  // (On ops different from `op`.)
  WalkResult walkResult = op->walk([&](BufferizableOpInterface op) {
    // Skip ops that are not in the filter.
    if (!options.isOpAllowed(op.getOperation()))
      return WalkResult::advance();

    // Check for unsupported unstructured control flow.
    if (!op.supportsUnstructuredControlFlow()) {
      for (Region &r : op->getRegions()) {
        if (r.getBlocks().size() > 1) {
          op->emitOpError("op or BufferizableOpInterface implementation does "
                          "not support unstructured control flow, but at least "
                          "one region has multiple blocks");
          return WalkResult::interrupt();
        }
      }
    }

    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted())
    return failure();

  walkResult = op->walk([&](BufferizableOpInterface op) {
    // Skip ops that are not in the filter.
    if (!options.isOpAllowed(op.getOperation()))
      return WalkResult::advance();

    // Input IR may not contain any ToTensorOps without the "restrict"
    // attribute. Such tensors may alias any other tensor, which is currently
    // not handled in the analysis.
    if (auto toTensorOp = dyn_cast<ToTensorOp>(op.getOperation())) {
      if (!toTensorOp.getRestrict() && !toTensorOp->getUses().empty()) {
        op->emitOpError("to_tensor ops without `restrict` are not supported by "
                        "One-Shot Analysis");
        return WalkResult::interrupt();
      }
    }

    for (OpOperand &opOperand : op->getOpOperands()) {
      if (isa<TensorType>(opOperand.get().getType())) {
        if (wouldCreateReadAfterWriteInterference(
                opOperand, domInfo, state,
                /*checkConsistencyOnly=*/true)) {
          // This error can happen if certain "mustBufferizeInPlace" interface
          // methods are implemented incorrectly, such that the IR already has
          // a RaW conflict before making any bufferization decisions. It can
          // also happen if bufferization.materialize_in_destination is used in
          // such a way that a RaW conflict is not avoidable.
          op->emitOpError("not bufferizable under the given constraints: "
                          "cannot avoid RaW conflict");
          return WalkResult::interrupt();
        }

        if (state.isInPlace(opOperand) &&
            wouldCreateWriteToNonWritableBuffer(
                opOperand, state, /*checkConsistencyOnly=*/true)) {
          op->emitOpError("not bufferizable under the given constraints: would "
                          "write to read-only buffer");
          return WalkResult::interrupt();
        }
      }
    }

    return WalkResult::advance();
  });

  return success(!walkResult.wasInterrupted());
}

/// Annotate the IR with the result of the analysis. For testing/debugging only.
static void
annotateOpsWithBufferizationMarkers(Operation *op,
                                    const OneShotAnalysisState &state) {
  // Add __inplace_operands_attr__.
  op->walk([&](Operation *op) {
    for (OpOperand &opOperand : op->getOpOperands())
      if (isa<TensorType>(opOperand.get().getType()))
        setInPlaceOpOperand(opOperand, state.isInPlace(opOperand));
  });
}

static void annotateOpsWithAliasSets(Operation *op,
                                     const OneShotAnalysisState &state) {
  AsmState asmState(op);
  Builder b(op->getContext());
  // Helper function to build an array attribute of aliasing SSA value strings.
  auto buildAliasesArray = [&](Value v) {
    SmallVector<Attribute> aliases;
    state.applyOnAliases(v, [&](Value alias) {
      std::string buffer;
      llvm::raw_string_ostream stream(buffer);
      alias.printAsOperand(stream, asmState);
      aliases.push_back(b.getStringAttr(buffer));
    });
    return b.getArrayAttr(aliases);
  };
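
  // The resulting annotations look roughly as follows (illustrative): an op
  // with a single tensor result whose buffer may alias %t would carry
  //   __opresult_alias_set_attr__ = [["%0", "%t"]]
  // and ops with tensor block arguments carry a nested
  // `__bbarg_alias_set_attr__` array (one entry per region, block, and bbArg).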

  op->walk([&](Operation *op) {
    // Build alias set array for every OpResult.
    SmallVector<Attribute> opResultAliasSets;
    for (OpResult opResult : op->getOpResults()) {
      if (llvm::isa<TensorType>(opResult.getType())) {
        opResultAliasSets.push_back(buildAliasesArray(opResult));
      }
    }
    if (!opResultAliasSets.empty())
      op->setAttr(kOpResultAliasSetAttrName, b.getArrayAttr(opResultAliasSets));

    // Build alias set array for every BlockArgument.
    SmallVector<Attribute> regionAliasSets;
    bool hasTensorBbArg = false;
    for (Region &r : op->getRegions()) {
      SmallVector<Attribute> blockAliasSets;
      for (Block &block : r.getBlocks()) {
        SmallVector<Attribute> bbArgAliasSets;
        for (BlockArgument bbArg : block.getArguments()) {
          if (llvm::isa<TensorType>(bbArg.getType())) {
            bbArgAliasSets.push_back(buildAliasesArray(bbArg));
            hasTensorBbArg = true;
          }
        }
        blockAliasSets.push_back(b.getArrayAttr(bbArgAliasSets));
      }
      regionAliasSets.push_back(b.getArrayAttr(blockAliasSets));
    }
    if (hasTensorBbArg)
      op->setAttr(kBbArgAliasSetAttrName, b.getArrayAttr(regionAliasSets));
  });
}

LogicalResult bufferization::analyzeOp(Operation *op,
                                       OneShotAnalysisState &state,
                                       BufferizationStatistics *statistics) {
  DominanceInfo domInfo(op);
  const OneShotBufferizationOptions &options = state.getOptions();

  if (failed(checkPreBufferizationAssumptions(op, domInfo, state)))
    return failure();

  // If the analysis fails, just return.
  if (failed(state.analyzeOp(op, domInfo)))
    return failure();

  if (statistics) {
    statistics->numTensorInPlace = state.getStatNumTensorInPlace();
    statistics->numTensorOutOfPlace = state.getStatNumTensorOutOfPlace();
  }

  bool failedAnalysis = false;

  // Gather some extra analysis data.
  state.gatherUndefinedTensorUses(op);

  // Analysis verification: After setting up alias/equivalence sets, each op
  // can check for expected invariants/limitations and fail the analysis if
  // necessary.
  op->walk([&](Operation *op) {
    if (BufferizableOpInterface bufferizableOp =
            options.dynCastBufferizableOp(op))
      failedAnalysis |= failed(bufferizableOp.verifyAnalysis(state));
  });

  // Annotate operations if we only want to report the analysis.
  if (options.testAnalysisOnly)
    annotateOpsWithBufferizationMarkers(op, state);
  if (options.dumpAliasSets)
    annotateOpsWithAliasSets(op, state);

  return success(!failedAnalysis);
}

LogicalResult bufferization::runOneShotBufferize(
    Operation *op, const OneShotBufferizationOptions &options,
    BufferizationState &state, BufferizationStatistics *statistics) {
  // Copy-before-write deactivates the analysis. It cannot be used together
  // with test-analysis-only.
  assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
         "invalid combination of bufferization flags");

  if (options.copyBeforeWrite) {
    // Copy buffer before each write. No analysis is needed.
  } else {
    // Run One-Shot Analysis and insert buffer copies (on the tensor level)
    // only where needed. This is the default and much more efficient than
    // copy-before-write.
    if (failed(insertTensorCopies(op, options, state, statistics)))
      return failure();

    // If test-analysis-only is set, the IR was annotated with RaW conflict
    // markers (attributes) during One-Shot Analysis.
    if (options.testAnalysisOnly)
      return success();
  }

  // Bufferize the op and its nested ops. If options.copyBeforeWrite is set,
  // a new buffer copy is allocated every time a buffer is written to.
  return bufferizeOp(op, options, state, statistics);
}
