//===- OneShotAnalysis.cpp - One-Shot (Single Pass) Analysis -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// One-Shot Analysis analyzes function bodies. By default, function boundaries
// (FuncOp bbArgs, CallOps, ReturnOps) are treated as "unknown" ops.
// OneShotModuleBufferization.cpp is an extension of One-Shot Analysis for
// simple call graphs without loops.
//
// One-Shot Bufferize consists of three phases.
//
// 1. Analyze ops to decide which OpOperands can bufferize in-place, i.e.,
//    without inserting buffer copies. The analysis queries op bufferization
//    semantics via `BufferizableOpInterface`.
// 2. Insert copies for OpOperands that were decided to bufferize out-of-place
//    in tensor land during `TensorCopyInsertion`.
// 3. Bufferize ops by calling `BufferizableOpInterface::bufferize`.
//
// This file contains only the analysis. For convenience, this file also
// contains a helper function `runOneShotBufferize` that analyzes an op (and
// its nested ops) and then bufferizes it.
//
// Inplace bufferization decisions are passed from the analysis to the
// `TensorCopyInsertion` phase via `AnalysisState`. They can be printed for
// debugging purposes with `testAnalysisOnly`.
//
// Ops that do not implement `BufferizableOpInterface` can be analyzed but are
// treated conservatively. E.g., the analysis has to assume that their tensor
// OpOperands bufferize to memory writes. While such ops can be analyzed, they
// are not bufferized and remain in the IR. to_tensor and to_memref ops are
// inserted at the bufferization boundary.
//
// This analysis caters to high-performance codegen where buffer reuse is
// deemed critical: the analysis should fail if the bufferized form of the
// function needs to return a buffer, unless `allowReturnAllocs` is enabled.
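//
// A minimal usage sketch (illustrative; assumes an op `moduleOp` with
// registered, bufferizable dialects and default options):
//
//   OneShotBufferizationOptions options;
//   if (failed(runOneShotBufferize(moduleOp, options, /*statistics=*/nullptr)))
//     return signalPassFailure();  // E.g., when called from a pass.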

#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"

#include <optional>
#include <random>

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"

MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::OneShotAnalysisState)

// Run mlir-opt with `-debug-only="one-shot-analysis"` for detailed debug
// output.
#define DEBUG_TYPE "one-shot-analysis"

using namespace mlir;
using namespace mlir::bufferization;

static bool isaTensor(Type t) { return isa<TensorType>(t); }

//===----------------------------------------------------------------------===//
// Bufferization-specific attribute manipulation.
// These are for testing and debugging only. Bufferization information is
// stored in `OneShotAnalysisState`. When run with `testAnalysisOnly`, the IR
// is annotated with the results of the analysis, so that they can be checked
// in tests.
//===----------------------------------------------------------------------===//

/// Attribute marker to specify op operands that bufferize in-place.
constexpr StringLiteral kInPlaceOperandsAttrName = "__inplace_operands_attr__";
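
// For example, an op annotated by the analysis (with `testAnalysisOnly`)
// might look as follows, where the tensor destination (operand #1) was
// decided to bufferize in-place (illustrative IR):
//
//   %0 = tensor.insert %f into %t[%idx]
//       {__inplace_operands_attr__ = ["none", "true", "none"]}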

constexpr StringLiteral kOpResultAliasSetAttrName =
    "__opresult_alias_set_attr__";

constexpr StringLiteral kBbArgAliasSetAttrName = "__bbarg_alias_set_attr__";

/// Mark whether OpOperand will be bufferized inplace.
static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
  Operation *op = opOperand.getOwner();
  SmallVector<StringRef> inPlaceVector;
  if (auto attr = op->getAttr(kInPlaceOperandsAttrName)) {
    inPlaceVector = SmallVector<StringRef>(llvm::to_vector<4>(
        cast<ArrayAttr>(attr).getAsValueRange<StringAttr>()));
  } else {
    inPlaceVector = SmallVector<StringRef>(op->getNumOperands(), "none");
    for (OpOperand &opOperand : op->getOpOperands())
      if (isa<TensorType>(opOperand.get().getType()))
        inPlaceVector[opOperand.getOperandNumber()] = "false";
  }
  inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false";
  op->setAttr(kInPlaceOperandsAttrName,
              OpBuilder(op).getStrArrayAttr(inPlaceVector));
}

//===----------------------------------------------------------------------===//
// OneShotAnalysisState
//===----------------------------------------------------------------------===//

OneShotAnalysisState::OneShotAnalysisState(
    Operation *op, const OneShotBufferizationOptions &options)
    : AnalysisState(options, TypeID::get<OneShotAnalysisState>()) {
  // Set up alias sets.
  op->walk([&](Operation *op) {
    for (Value v : op->getResults())
      if (isa<TensorType>(v.getType()))
        createAliasInfoEntry(v);
    for (Region &r : op->getRegions())
      for (Block &b : r.getBlocks())
        for (auto bbArg : b.getArguments())
          if (isa<TensorType>(bbArg.getType()))
            createAliasInfoEntry(bbArg);
  });

  // Mark OpOperands in-place that must bufferize in-place.
  op->walk([&](BufferizableOpInterface bufferizableOp) {
    if (!options.isOpAllowed(bufferizableOp))
      return WalkResult::skip();
    for (OpOperand &opOperand : bufferizableOp->getOpOperands())
      if (isa<TensorType>(opOperand.get().getType()))
        if (bufferizableOp.mustBufferizeInPlace(opOperand, *this))
          bufferizeInPlace(opOperand);
    return WalkResult::advance();
  });
}

void OneShotAnalysisState::applyOnEquivalenceClass(
    Value v, function_ref<void(Value)> fun) const {
  auto leaderIt = equivalentInfo.findLeader(v);
  for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
       ++mit) {
    fun(*mit);
  }
}

void OneShotAnalysisState::applyOnAliases(Value v,
                                          function_ref<void(Value)> fun) const {
  auto leaderIt = aliasInfo.findLeader(v);
  for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
    fun(*mit);
  }
}

bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1,
                                                         Value v2) const {
  return equivalentInfo.isEquivalent(v1, v2);
}

bool OneShotAnalysisState::areAliasingBufferizedValues(Value v1,
                                                       Value v2) const {
  return aliasInfo.isEquivalent(v1, v2);
}

void OneShotAnalysisState::bufferizeInPlace(OpOperand &operand) {
  if (inplaceBufferized.contains(&operand))
    return;
  inplaceBufferized.insert(&operand);
  for (AliasingValue alias : getAliasingValues(operand))
    aliasInfo.unionSets(alias.value, operand.get());
  ++statNumTensorInPlace;
}

void OneShotAnalysisState::bufferizeOutOfPlace(OpOperand &operand) {
  assert(!inplaceBufferized.contains(&operand) &&
         "OpOperand was already decided to bufferize inplace");
  ++statNumTensorOutOfPlace;
}

void OneShotAnalysisState::createAliasInfoEntry(Value v) {
  aliasInfo.insert(v);
  equivalentInfo.insert(v);
}

void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
  op->walk([&](Operation *op) {
    // Skip unknown ops.
    auto bufferizableOp = getOptions().dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return WalkResult::skip();

    // Check all tensor OpResults.
    for (OpResult opResult : op->getOpResults()) {
      if (!isa<TensorType>(opResult.getType()))
        continue;

      // If there is no preceding definition, the tensor contents are
      // undefined.
      if (findDefinitionsCached(opResult).empty())
        for (OpOperand &use : opResult.getUses())
          undefinedTensorUses.insert(&use);
    }

    return WalkResult::advance();
  });
}

bool OneShotAnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  return undefinedTensorUses.contains(opOperand);
}

bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const {
  return inplaceBufferized.contains(&opOperand);
}

bool OneShotAnalysisState::isValueWritten(Value value) const {
  bool isWritten = false;
  applyOnAliases(value, [&](Value val) {
    for (OpOperand &use : val.getUses())
      if (isInPlace(use) && bufferizesToMemoryWrite(use))
        isWritten = true;
  });
  return isWritten;
}

bool OneShotAnalysisState::isWritable(Value value) const {
  // TODO: Out-of-place bufferized value could be considered writable.
  // Query BufferizableOpInterface to see if the BlockArgument is writable.
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(getOwnerOfValue(value)))
    return bufferizableOp.isWritable(value, *this);

  // Not a bufferizable op: The conservative answer is "not writable".
  return false;
}

void OneShotAnalysisState::unionAliasSets(Value v1, Value v2) {
  aliasInfo.unionSets(v1, v2);
}

void OneShotAnalysisState::unionEquivalenceClasses(Value v1, Value v2) {
  equivalentInfo.unionSets(v1, v2);
}

OneShotAnalysisState::Extension::~Extension() = default;

//===----------------------------------------------------------------------===//
// Bufferization-specific alias analysis.
//===----------------------------------------------------------------------===//

/// Return true if `opOperand` bufferizes to a memory write and has been
/// decided to bufferize in-place.
static bool isInplaceMemoryWrite(OpOperand &opOperand,
                                 const OneShotAnalysisState &state) {
  // OpOperands that do not bufferize to a memory write do not write in-place.
  if (!state.bufferizesToMemoryWrite(opOperand))
    return false;
  // Check current bufferization decisions.
  return state.isInPlace(opOperand);
}

/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
/// properly dominates `b` and `b` is not inside `a`.
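///
/// For example (illustrative), if `a` is nested inside an scf.if op that
/// appears before `b` in the same block, `happensBefore(a, b)` is true: the
/// scf.if, an ancestor of `a`, properly dominates `b`. If instead `b` is
/// nested inside `a`, the result is false.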
static bool happensBefore(Operation *a, Operation *b,
                          const DominanceInfo &domInfo) {
  do {
    // TODO: Instead of isProperAncestor + properlyDominates, we should use
    // properlyDominatesImpl(a, b, /*enclosingOpOk=*/false)
    if (a->isProperAncestor(b))
      return false;
    if (domInfo.properlyDominates(a, b))
      return true;
  } while ((a = a->getParentOp()));
  return false;
}

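/// Return `true` if block `to` is reachable from block `from` by following
/// successor edges, without passing through any of the blocks in `except`.
///
/// For example (illustrative CFG), with edges ^bb0 -> ^bb1, ^bb1 -> ^bb0 and
/// ^bb1 -> ^bb2 (and ^bb2 terminal):
///   isReachable(^bb0, ^bb2, {})     == true   (via ^bb1)
///   isReachable(^bb2, ^bb0, {})     == false  (^bb2 has no successors)
///   isReachable(^bb0, ^bb0, {^bb1}) == false  (every cycle passes ^bb1)
/// Note: The search starts at the successors of `from`, so a block only
/// reaches itself through a cycle.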
static bool isReachable(Block *from, Block *to, ArrayRef<Block *> except) {
  DenseSet<Block *> visited;
  SmallVector<Block *> worklist;
  for (Block *succ : from->getSuccessors())
    worklist.push_back(succ);
  while (!worklist.empty()) {
    Block *next = worklist.pop_back_val();
    if (llvm::is_contained(except, next))
      continue;
    if (next == to)
      return true;
    if (visited.contains(next))
      continue;
    visited.insert(next);
    for (Block *succ : next->getSuccessors())
      worklist.push_back(succ);
  }
  return false;
}

/// Return `true` if op dominance can be used to rule out read-after-write
/// conflicts based on the ordering of ops. Returns `false` if op dominance
/// cannot be used due to region-based loops.
///
/// Generalized op dominance can often be used to rule out potential conflicts
/// due to "read happens before write". E.g., the following IR is not a RaW
/// conflict because the read happens *before* the write.
///
/// Example 1:
/// %0 = ... : tensor<?xf32>                                  // DEF
/// "reading_op"(%0) : tensor<?xf32>                          // READ
/// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>    // WRITE
///
/// This is no longer true inside loops (or repetitive regions). In such cases,
/// there may not be a meaningful `happensBefore` relationship because ops
/// could be executed multiple times. E.g.:
///
/// Example 2:
/// %0 = ... : tensor<?xf32>                                  // DEF
/// scf.for ... {
///   "reading_op"(%0) : tensor<?xf32>                        // READ
///   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>  // WRITE
///   ...
/// }
///
/// In the above example, reading_op happens before writing_op according to
/// op dominance. However, both ops may happen multiple times; in
/// particular, the second execution of reading_op happens after the first
/// execution of writing_op. This is problematic because the tensor %0 they
/// operate on (i.e., the "definition") is defined outside of the loop.
///
/// On a high level, there is a potential RaW in a program if there exists a
/// possible program execution such that there is a sequence of DEF, followed
/// by WRITE, followed by READ. Each additional DEF resets the sequence.
///
/// E.g.:
/// No conflict:        DEF, WRITE, DEF, READ
/// Potential conflict: DEF, READ, WRITE, READ, WRITE
///
/// Example 1 has no conflict:          DEF, READ, WRITE
/// Example 2 has a potential conflict: DEF, (READ, WRITE)*
///
/// Example 3:
/// scf.for ... {
///   %0 = ... : tensor<?xf32>
///   "reading_op"(%0) : tensor<?xf32>
///   %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
///   ...
/// }
/// This has no conflict: (DEF, READ, WRITE)*
///
/// Example 4:
/// %0 = ... : tensor<?xf32>
/// scf.for ... {
///   scf.for ... { "reading_op"(%0) }
///   %1 = "writing_op"(%0)
/// }
/// This has a potential conflict: DEF, ((READ)*, WRITE)*
///
/// Example 5:
/// %0 = ... : tensor<?xf32>
/// scf.for ... { %1 = "writing_op"(%0) }
/// scf.for ... { "reading_op"(%0) }
/// This has a potential conflict: DEF, WRITE*, READ*
///
/// The following rules are used to rule out RaW conflicts via ordering of ops:
///
/// 1. If the closest enclosing repetitive region of DEF is a proper ancestor
///    of a repetitive region that encloses both READ and WRITE, we cannot
///    rule out a RaW conflict due to the ordering of ops.
/// 2. Otherwise: There are no loops that interfere with our analysis; for
///    analysis purposes, we can assume that there are no loops/repetitive
///    regions. I.e., we can rule out a RaW conflict if READ happensBefore
///    WRITE or WRITE happensBefore DEF. (Checked in
///    `hasReadAfterWriteInterference`.)
///
static bool canUseOpDominanceDueToRegions(OpOperand *uRead, OpOperand *uWrite,
                                          const SetVector<Value> &definitions,
                                          AnalysisState &state) {
  const BufferizationOptions &options = state.getOptions();
  for (Value def : definitions) {
    Region *rRead =
        state.getEnclosingRepetitiveRegion(uRead->getOwner(), options);
    Region *rDef = state.getEnclosingRepetitiveRegion(def, options);

    // READ and DEF are in the same repetitive region. `happensBefore` can be
    // used to rule out RaW conflicts due to op ordering.
    if (rRead == rDef)
      continue;

    // Find the enclosing repetitive region of READ that is closest to DEF but
    // not the repetitive region of DEF itself.
    while (true) {
      Region *nextRegion = getNextEnclosingRepetitiveRegion(rRead, options);
      if (nextRegion == rDef)
        break;
      assert(nextRegion && "expected to find another repetitive region");
      rRead = nextRegion;
    }

    // We cannot use op dominance if WRITE is inside the same repetitive
    // region.
    if (rRead->getParentOp()->isAncestor(uWrite->getOwner()))
      return false;
  }

  return true;
}

/// Return `true` if op dominance can be used to rule out read-after-write
/// conflicts based on the ordering of ops. Returns `false` if op dominance
/// cannot be used due to block-based loops within a region.
///
/// Refer to the `canUseOpDominanceDueToRegions` documentation for details on
/// how op dominance is used during RaW conflict detection.
///
/// On a high level, there is a potential RaW in a program if there exists a
/// possible program execution such that there is a sequence of DEF, followed
/// by WRITE, followed by READ. Each additional DEF resets the sequence.
///
/// Op dominance cannot be used if there is a path from block(READ) to
/// block(WRITE) and a path from block(WRITE) to block(READ). block(DEF) should
/// not appear on that path.
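///
/// Example (illustrative, unstructured control flow):
///   ^bb0:
///     %0 = ... : tensor<?xf32>        // DEF
///     cf.br ^bb1
///   ^bb1:
///     "reading_op"(%0) ...            // READ
///     %1 = "writing_op"(%0) ...       // WRITE
///     cf.cond_br %c, ^bb1, ^bb2
/// Here, block(READ) == block(WRITE) == ^bb1 and ^bb1 can reach itself
/// without passing through block(DEF) == ^bb0, so op dominance cannot be
/// used to rule out a RaW conflict.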
static bool canUseOpDominanceDueToBlocks(OpOperand *uRead, OpOperand *uWrite,
                                         const SetVector<Value> &definitions,
                                         AnalysisState &state) {
  // Fast path: If READ and WRITE are in different regions, their blocks cannot
  // reach each other just via unstructured control flow. (Loops due to regions
  // are covered by `canUseOpDominanceDueToRegions`.)
  if (uRead->getOwner()->getParentRegion() !=
      uWrite->getOwner()->getParentRegion())
    return true;

  Block *readBlock = uRead->getOwner()->getBlock();
  Block *writeBlock = uWrite->getOwner()->getBlock();
  for (Value def : definitions) {
    Block *defBlock = def.getParentBlock();
    if (isReachable(readBlock, writeBlock, {defBlock}) &&
        isReachable(writeBlock, readBlock, {defBlock}))
      return false;
  }

  return true;
}

static bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite,
                              const SetVector<Value> &definitions,
                              AnalysisState &state) {
  return canUseOpDominanceDueToRegions(uRead, uWrite, definitions, state) &&
         canUseOpDominanceDueToBlocks(uRead, uWrite, definitions, state);
}

/// Annotate IR with details about the detected RaW conflict.
static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
                             Value definition) {
  static uint64_t counter = 0;
  Operation *readingOp = uRead->getOwner();
  Operation *conflictingWritingOp = uConflictingWrite->getOwner();

  OpBuilder b(conflictingWritingOp->getContext());
  std::string id = "C_" + std::to_string(counter++);

  std::string conflictingWriteAttr =
      id +
      "[CONFL-WRITE: " + std::to_string(uConflictingWrite->getOperandNumber()) +
      "]";
  conflictingWritingOp->setAttr(conflictingWriteAttr, b.getUnitAttr());

  std::string readAttr =
      id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]";
  readingOp->setAttr(readAttr, b.getUnitAttr());

  if (auto opResult = dyn_cast<OpResult>(definition)) {
    std::string defAttr =
        id + "[DEF: result " + std::to_string(opResult.getResultNumber()) + "]";
    opResult.getDefiningOp()->setAttr(defAttr, b.getUnitAttr());
  } else {
    auto bbArg = cast<BlockArgument>(definition);
    std::string defAttr =
        id + "[DEF: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
    bbArg.getOwner()->getParentOp()->setAttr(defAttr, b.getUnitAttr());
  }
}

/// Return "true" if a tensor that is equivalent to `other` can be found in the
/// reverse use-def chain of `start`. Note: If an OpOperand bufferizes out of
/// place along that use-def chain, the two tensors may not materialize as
/// equivalent buffers (but as separate allocations).
///
/// Note: This function also requires that the two tensors have equivalent
/// indexing. I.e., the tensor types do not change along the use-def chain,
/// apart from static <-> dynamic dim casts.
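///
/// For example (illustrative), starting from %2 in
///   %1 = tensor.cast %0 : tensor<10xf32> to tensor<?xf32>
///   %2 = tensor.insert %f into %1[%idx] : tensor<?xf32>
/// the traversal can reach %0: the insert's result is equivalent to its
/// destination (assuming it bufferizes in-place) and the cast merely turns a
/// static dim size into a dynamic one.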
static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
                                                   Value start, Value other) {
  TraversalConfig config;
  config.followEquivalentOnly = true;
  config.alwaysIncludeLeaves = false;
  config.followSameTypeOrCastsOnly = true;
  return !state
              .findValueInReverseUseDefChain(
                  start, [&](Value v) { return v == other; }, config)
              .empty();
}

/// Return "true" if `value` originates from a subset that is equivalent to
/// the subset that `subsetOp` inserts into.
static bool matchesInsertDestination(const AnalysisState &state, Value value,
                                     SubsetInsertionOpInterface subsetOp) {
  auto matchingSubset = [&](Value val) {
    if (auto opResult = dyn_cast<OpResult>(val))
      if (subsetOp.isEquivalentSubset(opResult, [&](Value v1, Value v2) {
            return state.areEquivalentBufferizedValues(v1, v2);
          }))
        return true;
    return false;
  };
  // There may be multiple leaves at which the reverse SSA use-def chain lookup
  // terminates. All of them must be equivalent subsets.
  SetVector<Value> backwardSlice =
      state.findValueInReverseUseDefChain(value, matchingSubset);
  return llvm::all_of(backwardSlice, matchingSubset);
}

/// Return "true" if the given "read" and potentially conflicting "write" are
/// not conflicting due to their subset relationship. The comments in this
/// function are expressed in terms of tensor.extract_slice/tensor.insert_slice
/// pairs, but apply to any subset ops that implement the
/// `SubsetInsertionOpInterface`.
static bool areNonConflictingSubsets(OpOperand *uRead,
                                     OpOperand *uConflictingWrite,
                                     const AnalysisState &state) {
  Operation *readingOp = uRead->getOwner();
  Operation *conflictingWritingOp = uConflictingWrite->getOwner();

  // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
  // uRead is an InsertSliceOp...
  if (auto subsetOp = dyn_cast<SubsetInsertionOpInterface>(readingOp)) {
    // As an example, consider the following IR.
    //
    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
    // %1 = linalg.fill %cst, %0 {inplace = [true] }
    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
    //     {inplace = [true] }

    if (uRead == &subsetOp.getDestinationOperand() &&
        matchesInsertDestination(state, uConflictingWrite->get(), subsetOp))
      // Case 1: The main insight is that InsertSliceOp reads only part of
      // the destination tensor. The overwritten area is not read. If
      // uConflictingWrite writes into exactly the memory location that is
      // being read by uRead, this is not a conflict.
      //
      // In the above example:
      // uRead             = OpOperand 1 (%t) of tensor.insert_slice
      // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
      //
      // The read of %t does not conflict with the write of the FillOp
      // (same aliases!) because the area that the FillOp operates on is
      // exactly the one that is *not* read via %t.
      return true;

    if (uRead == &subsetOp.getSourceOperand() &&
        uConflictingWrite == &subsetOp.getDestinationOperand() &&
        matchesInsertDestination(state, uRead->get(), subsetOp))
      // Case 2: The read of the source tensor and the write to the dest
      // tensor via an InsertSliceOp is not a conflict if the read is
      // reading exactly that part of an equivalent tensor that the
      // InsertSliceOp is writing.
      //
      // In the above example:
      // uRead             = OpOperand 0 (%1) of tensor.insert_slice
      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
      return true;
  }

  // If uConflictingWrite is an InsertSliceOp...
  if (auto subsetOp =
          dyn_cast<SubsetInsertionOpInterface>(conflictingWritingOp))
    // As an example, consider the following IR.
    //
    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
    // %1 = linalg.fill %cst, %0 {inplace = [true] }
    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
    //     {inplace = [true] }
    // %3 = vector.transfer_read %1, %cst
    //
    // In the above example:
    // uRead             = OpOperand 0 (%1) of vector.transfer_read
    // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
    // definition        = %1
    //
    // This is not a conflict because the InsertSliceOp overwrites the
    // memory segment of %1 with the exact same data. (Effectively, there
    // is no memory write here.)
    if (uConflictingWrite == &subsetOp.getDestinationOperand() &&
        state.areEquivalentBufferizedValues(
            uRead->get(), subsetOp.getSourceOperand().get()) &&
        matchesInsertDestination(state, subsetOp.getSourceOperand().get(),
                                 subsetOp))
      return true;

  return false;
}

/// Given sets of reads and writes, return true if there is a RaW conflict
/// under the assumption that all given reads/writes alias the same buffer and
/// that all given writes bufferize inplace.
///
/// A conflict is: According to SSA use-def chains, a read R is supposed to
/// read the result of a definition W1. But because of bufferization decisions,
/// R actually reads another definition W2.
static bool
hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
                              const DenseSet<OpOperand *> &usesWrite,
                              const DominanceInfo &domInfo,
                              OneShotAnalysisState &state) {
  const BufferizationOptions &options = state.getOptions();

  // Before going through the main RaW analysis, find cases where a buffer must
  // be privatized due to parallelism. If the result of a write is never read,
  // privatization is not necessary (and large parts of the IR are likely
  // dead).
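  //
  // Example (illustrative):
  //   %0 = ... : tensor<?xf32>            // definition outside
  //   scf.forall ... {
  //     %1 = "writing_op"(%0) ...         // write inside parallel region
  //     "reading_op"(%1) ...
  //   }
  // Assuming scf.forall counts as a parallel region here, each iteration must
  // write to a private copy of %0's buffer, so the write is forced
  // out-of-place.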
  if (!usesRead.empty()) {
    for (OpOperand *uConflictingWrite : usesWrite) {
      // Find the allocation point or last write (definition) of the buffer.
      // Note: In contrast to `findDefinitions`, this also returns results of
      // ops that do not bufferize to a memory write when no other definition
      // could be found. E.g., "bufferization.alloc_tensor" would be included,
      // even though that op just bufferizes to an allocation and does not
      // define the contents of the buffer.
      SetVector<Value> definitionsOrLeaves =
          state.findValueInReverseUseDefChain(
              uConflictingWrite->get(),
              [&](Value v) { return state.bufferizesToMemoryWrite(v); });
      assert(!definitionsOrLeaves.empty() &&
             "expected at least one definition or leaf");

      // The writing op must bufferize out-of-place if the definition is in a
      // different parallel region than this write.
      for (Value def : definitionsOrLeaves) {
        if (getParallelRegion(def.getParentRegion(), options) !=
            getParallelRegion(uConflictingWrite->getOwner()->getParentRegion(),
                              options)) {
          LLVM_DEBUG(
              llvm::dbgs()
              << "\n- bufferizes out-of-place due to parallel region:\n");
          LLVM_DEBUG(llvm::dbgs()
                     << "  uConflictingWrite = operand "
                     << uConflictingWrite->getOperandNumber() << " of "
                     << *uConflictingWrite->getOwner() << "\n");
          return true;
        }
      }
    }
  }

  for (OpOperand *uRead : usesRead) {
    Operation *readingOp = uRead->getOwner();
    LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n");
    LLVM_DEBUG(llvm::dbgs() << "  uRead = operand " << uRead->getOperandNumber()
                            << " of " << *readingOp << "\n");

    // Find the definition of uRead by following the SSA use-def chain.
    // E.g.:
    //
    // %0 = "writing_op"(%t) : tensor<?xf32> -> tensor<?xf32>
    // %1 = "aliasing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
    // %2 = "reading_op"(%1) : tensor<?xf32> -> not_a_tensor_type
    //
    // In the above example, if uRead is the OpOperand of reading_op, the
    // definition is %0. Note that operations that create an alias but do not
    // bufferize to a memory write (such as ExtractSliceOp) are skipped.
    const SetVector<Value> &definitions =
        state.findDefinitionsCached(uRead->get());
    if (definitions.empty()) {
      // Fast path: No conflict if there are no definitions.
      LLVM_DEBUG(llvm::dbgs()
                 << "  no conflict: read value has no definitions\n");
      continue;
    }

    // Look for conflicting memory writes. Potential conflicts are writes to an
    // alias that have been decided to bufferize inplace.
    for (OpOperand *uConflictingWrite : usesWrite) {
      LLVM_DEBUG(llvm::dbgs() << "  uConflictingWrite = operand "
                              << uConflictingWrite->getOperandNumber() << " of "
                              << *uConflictingWrite->getOwner() << "\n");

      // Check if op dominance can be used to rule out read-after-write
      // conflicts.
      bool useDominance =
          canUseOpDominance(uRead, uConflictingWrite, definitions, state);
      LLVM_DEBUG(llvm::dbgs() << "\n- useDominance = " << useDominance << "\n");

      // Throughout this loop, check for multiple requirements that have to be
      // met for uConflictingWrite to be an actual conflict.
      Operation *conflictingWritingOp = uConflictingWrite->getOwner();

      // Inside of repetitive regions, ops may be executed multiple times and
      // op dominance cannot be used to rule out conflicts.
      if (useDominance) {
        // No conflict if the readingOp dominates conflictingWritingOp, i.e.,
        // the write is not visible when reading.
        //
        // Note: If ops are executed multiple times (e.g., because they are
        // inside a loop), there may be no meaningful `happensBefore`
        // relationship.
        if (happensBefore(readingOp, conflictingWritingOp, domInfo)) {
          LLVM_DEBUG(llvm::dbgs()
                     << "  no conflict: read happens before write\n");
          continue;
        }

        // No conflict if the reading use equals the use of the conflicting
        // write. A use cannot conflict with itself.
        //
        // Note: Just being the same op is not enough. It has to be the same
        // use.
        // Note: If the op is executed multiple times (e.g., because it is
        // inside a loop), it may be conflicting with itself.
        if (uConflictingWrite == uRead) {
          LLVM_DEBUG(llvm::dbgs()
                     << "  no conflict: read and write are same use\n");
          continue;
        }

        // Ops are not conflicting if they are in mutually exclusive regions.
        //
        // Note: If ops are executed multiple times (e.g., because they are
        // inside a loop), mutually exclusive regions may be executed
        // multiple times.
        if (insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp)) {
          LLVM_DEBUG(llvm::dbgs() << "  no conflict: read and write are in "
                                     "mutually exclusive regions\n");
          continue;
        }
      }

      // Two equivalent operands of the same op are not conflicting if the op
      // bufferizes to element-wise access. I.e., all loads at a position
      // happen before all stores to the same position.
      if (conflictingWritingOp == readingOp) {
        if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
          if (bufferizableOp.bufferizesToElementwiseAccess(
                  state, {uRead, uConflictingWrite})) {
            if (hasEquivalentValueInReverseUseDefChain(
                    state, uRead->get(), uConflictingWrite->get()) ||
                hasEquivalentValueInReverseUseDefChain(
                    state, uConflictingWrite->get(), uRead->get())) {
              LLVM_DEBUG(
                  llvm::dbgs()
                  << "  no conflict: op bufferizes to element-wise access\n");
              continue;
            }
          }
        }
      }

      // No conflict if the operands are non-conflicting subsets.
      if (areNonConflictingSubsets(uRead, uConflictingWrite, state)) {
        LLVM_DEBUG(llvm::dbgs() << "  no conflict: non-conflicting subsets\n");
        continue;
      }

      // No conflict if the op interface says so.
      if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
        if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) {
          LLVM_DEBUG(llvm::dbgs()
                     << "  no conflict: op interface of reading op says 'no'\n");
          continue;
        }
      }

      if (conflictingWritingOp != readingOp) {
        if (auto bufferizableOp =
                options.dynCastBufferizableOp(conflictingWritingOp)) {
          if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
                                              state)) {
            LLVM_DEBUG(
                llvm::dbgs()
                << "  no conflict: op interface of writing op says 'no'\n");
            continue;
          }
        }
      }

      // Check all possible definitions.
      for (Value definition : definitions) {
        LLVM_DEBUG(llvm::dbgs() << "  * definition = " << definition << "\n");

        // No conflict if the conflicting write happens before the definition.
        if (Operation *defOp = definition.getDefiningOp()) {
          if (happensBefore(conflictingWritingOp, defOp, domInfo)) {
            // conflictingWritingOp happens before defOp. No conflict.
            LLVM_DEBUG(llvm::dbgs()
                       << "  no conflict: write happens before definition\n");
            continue;
          }
          // No conflict if conflictingWritingOp is contained in defOp.
          if (defOp->isProperAncestor(conflictingWritingOp)) {
            LLVM_DEBUG(
                llvm::dbgs()
                << "  no conflict: write is contained in definition\n");
            continue;
          }
        } else {
          auto bbArg = cast<BlockArgument>(definition);
          Block *block = bbArg.getOwner();
          if (!block->findAncestorOpInBlock(*conflictingWritingOp)) {
            LLVM_DEBUG(llvm::dbgs() << "  no conflict: definition is bbArg "
                                       "and write happens outside of block\n");
            // conflictingWritingOp happens outside of the block. No
            // conflict.
            continue;
          }
        }

        // No conflict if the conflicting write and the definition are the same
        // use.
        AliasingValueList aliases = state.getAliasingValues(*uConflictingWrite);
        if (aliases.getNumAliases() == 1 &&
            aliases.getAliases()[0].value == definition) {
          LLVM_DEBUG(llvm::dbgs()
                     << "  no conflict: definition and write are same\n");
          continue;
        }

        // All requirements are met. Conflict found!

        if (options.printConflicts)
          annotateConflict(uRead, uConflictingWrite, definition);
        LLVM_DEBUG(llvm::dbgs() << "  => RaW CONFLICT FOUND\n");
        return true;
      }
    }
  }

  return false;
}

// Helper function to iterate on aliases of `root` and capture the writes.
static void getAliasingInplaceWrites(DenseSet<OpOperand *> &res, Value root,
                                     const OneShotAnalysisState &state) {
  state.applyOnAliases(root, [&](Value alias) {
    for (auto &use : alias.getUses())
      // Inplace write to a value that aliases root.
      if (isInplaceMemoryWrite(use, state))
        res.insert(&use);
  });
}

// Helper function to iterate on aliases of `root` and capture the reads.
static void getAliasingReads(DenseSet<OpOperand *> &res, Value root,
                             const OneShotAnalysisState &state) {
  state.applyOnAliases(root, [&](Value alias) {
    for (auto &use : alias.getUses()) {
      // Read of a value that aliases root.
      if (state.bufferizesToMemoryRead(use)) {
        res.insert(&use);
        continue;
      }

      // Read of a dependent value in the SSA use-def chain. E.g.:
      //
      // %0 = ...
      // %1 = tensor.extract_slice %0 {not_analyzed_yet}
      // "read"(%1)
      //
      // In the above example, getAliasingReads(%0) includes the first
      // OpOperand of the tensor.extract_slice op. The extract_slice itself
      // does not read but its aliasing result is eventually fed into an op
      // that does.
      //
      // Note: This is considered a "read" only if the use does not bufferize
      // to a memory write. (We already ruled out memory reads. In case of a
      // memory write, the buffer would be entirely overwritten; in the above
      // example there would then be no flow of data from the extract_slice
      // operand to its result's uses.)
      if (!state.bufferizesToMemoryWrite(use)) {
        AliasingValueList aliases = state.getAliasingValues(use);
        if (llvm::any_of(aliases, [&](AliasingValue a) {
              return state.isValueRead(a.value);
            }))
          res.insert(&use);
      }
    }
  });
}

/// Return true if bufferizing `operand` inplace would create a conflict. A
/// read R and a write W of the same alias set are in conflict if inplace
/// bufferization of W changes the value read by R to a value different from
/// the one that would be expected by tracing back R's origin through SSA
/// use-def chains. A conflict can only be introduced by a new alias and/or an
/// inplace bufferization decision.
///
/// Example:
/// %0 = tensor.extract_slice %t[...][...][1, 1] {inplace?}
/// %1 = vector.transfer_write %v1, %t {inplace} : vector<5xf32>, tensor<?xf32>
/// %e = tensor.extract_slice %1
/// %2 = vector.transfer_write %v2, %0 {inplace} : vector<6xf32>, tensor<?xf32>
/// %3 = vector.transfer_read %e, %cst : tensor<?xf32>, vector<7xf32>
///
/// In the above example, the two TransferWriteOps have already been decided to
/// bufferize inplace. Bufferizing the ExtractSliceOp inplace would create a
/// conflict because:
/// * According to SSA use-def chains, we expect to read the result of %1.
/// * However, adding an alias {%0, %t} would mean that the second
///   TransferWriteOp overwrites the result of the first one. Therefore, the
///   TransferReadOp would no longer be reading the result of %1.
///
/// If `checkConsistencyOnly` is true, this function checks if there is a
/// read-after-write conflict without bufferizing `operand` inplace. This would
/// indicate a problem with the current inplace bufferization decisions.
///
/// Note: If `checkConsistencyOnly`, this function may be called with a null
/// OpResult. In that case, only the consistency of bufferization decisions
/// involving aliases of the given OpOperand are checked.
static bool wouldCreateReadAfterWriteInterference(
    OpOperand &operand, const DominanceInfo &domInfo,
    OneShotAnalysisState &state, bool checkConsistencyOnly = false) {
  // Collect reads and writes of all aliases of OpOperand and OpResult.
  DenseSet<OpOperand *> usesRead, usesWrite;
  getAliasingReads(usesRead, operand.get(), state);
  getAliasingInplaceWrites(usesWrite, operand.get(), state);
  for (AliasingValue alias : state.getAliasingValues(operand)) {
    getAliasingReads(usesRead, alias.value, state);
    getAliasingInplaceWrites(usesWrite, alias.value, state);
  }
  if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
    usesWrite.insert(&operand);

  return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state);
}

/// Annotate IR with details about the detected non-writability conflict.
static void annotateNonWritableTensor(Value value) {
  static int64_t counter = 0;
  OpBuilder b(value.getContext());
  std::string id = "W_" + std::to_string(counter++);
  if (auto opResult = dyn_cast<OpResult>(value)) {
    std::string attr = id + "[NOT-WRITABLE: result " +
                       std::to_string(opResult.getResultNumber()) + "]";
    opResult.getDefiningOp()->setAttr(attr, b.getUnitAttr());
  } else {
    auto bbArg = cast<BlockArgument>(value);
    std::string attr = id + "[NOT-WRITABLE: bbArg " +
                       std::to_string(bbArg.getArgNumber()) + "]";
    bbArg.getOwner()->getParentOp()->setAttr(attr, b.getUnitAttr());
  }
}

/// Return true if bufferizing `operand` inplace would create a write to a
/// non-writable buffer.
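///
/// Example (illustrative): If %t is non-writable (e.g., the result of an
/// `arith.constant`, or a function argument reported as read-only by the
/// interface), then in
///   %0 = tensor.extract_slice %t ...   // aliases %t
///   %1 = "writing_op"(%0) ...          // in-place write
/// the write would modify the read-only buffer of %t through the alias %0,
/// so it must bufferize out-of-place.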
static bool
wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
                                    OneShotAnalysisState &state,
                                    bool checkConsistencyOnly = false) {
  bool foundWrite =
      !checkConsistencyOnly && state.bufferizesToMemoryWrite(operand);

  if (!foundWrite) {
    // Collect writes of all aliases of OpOperand and OpResult.
    DenseSet<OpOperand *> usesWrite;
    getAliasingInplaceWrites(usesWrite, operand.get(), state);
    for (AliasingValue alias : state.getAliasingValues(operand))
      getAliasingInplaceWrites(usesWrite, alias.value, state);
    foundWrite = !usesWrite.empty();
  }

  if (!foundWrite)
    return false;

  // Look for a read-only tensor among all aliases.
  bool foundReadOnly = false;
  auto checkReadOnly = [&](Value v) {
    if (!state.isWritable(v)) {
      foundReadOnly = true;
      if (state.getOptions().printConflicts)
        annotateNonWritableTensor(v);
    }
  };
  state.applyOnAliases(operand.get(), checkReadOnly);
  for (AliasingValue alias : state.getAliasingValues(operand))
    state.applyOnAliases(alias.value, checkReadOnly);
  if (foundReadOnly) {
    LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n");
    return true;
  }

  return false;
}

//===----------------------------------------------------------------------===//
// Bufferization analyses.
//===----------------------------------------------------------------------===//

/// Find the values that define the contents of the given value.
const llvm::SetVector<Value> &
OneShotAnalysisState::findDefinitionsCached(Value value) {
  if (!cachedDefinitions.count(value))
    cachedDefinitions[value] = findDefinitions(value);
  return cachedDefinitions[value];
}

void OneShotAnalysisState::resetCache() {
  AnalysisState::resetCache();
  cachedDefinitions.clear();
}

/// Determine if `operand` can be bufferized in-place.
static LogicalResult
bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state,
                                const DominanceInfo &domInfo) {
  LLVM_DEBUG(
      llvm::dbgs() << "//===-------------------------------------------===//\n"
                   << "Analyzing operand #" << operand.getOperandNumber()
                   << " of " << *operand.getOwner() << "\n");

  bool foundInterference =
      wouldCreateWriteToNonWritableBuffer(operand, state) ||
      wouldCreateReadAfterWriteInterference(operand, domInfo, state);

  if (foundInterference)
    state.bufferizeOutOfPlace(operand);
  else
    state.bufferizeInPlace(operand);

  LLVM_DEBUG(llvm::dbgs()
             << "//===-------------------------------------------===//\n");
  return success();
}

LogicalResult
OneShotAnalysisState::analyzeSingleOp(Operation *op,
                                      const DominanceInfo &domInfo) {
  for (OpOperand &opOperand : op->getOpOperands())
    if (isa<TensorType>(opOperand.get().getType()))
      if (failed(bufferizableInPlaceAnalysisImpl(opOperand, *this, domInfo)))
        return failure();
  return success();
}

/// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
static void equivalenceAnalysis(SmallVector<Operation *> &ops,
                                OneShotAnalysisState &state) {
  for (Operation *op : ops) {
    if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) {
      for (OpResult opResult : op->getOpResults()) {
        if (!isa<TensorType>(opResult.getType()))
          continue;
        AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
        if (aliases.getNumAliases() == 0)
          // Nothing to do if there are no aliasing OpOperands.
          continue;

        Value firstOperand = aliases.begin()->opOperand->get();
        bool allEquivalent = true;
        for (AliasingOpOperand alias : aliases) {
          bool isEquiv = alias.relation == BufferRelation::Equivalent;
          bool isInPlace = state.isInPlace(*alias.opOperand);
          Value operand = alias.opOperand->get();
          if (isEquiv && isInPlace && alias.isDefinite) {
            // Found a definite, equivalent alias. Merge equivalence sets.
            // There can only be one definite alias, so we can stop here.
            state.unionEquivalenceClasses(opResult, operand);
            allEquivalent = false;
            break;
          }
          if (!isEquiv || !isInPlace)
            allEquivalent = false;
          if (!state.areEquivalentBufferizedValues(operand, firstOperand))
            allEquivalent = false;
        }

        // If all "maybe" aliases are equivalent and the OpResult is not a new
        // allocation, it is a definite, equivalent alias. E.g.:
        //
        // aliasingOpOperands(%r) = {(%t0, EQUIV, MAYBE), (%t1, EQUIV, MAYBE)}
        // aliasingValues(%t0) = {(%r, EQUIV, MAYBE)}
        // aliasingValues(%t1) = {(%r, EQUIV, MAYBE)}
        // %r = arith.select %c, %t0, %t1 : tensor<?xf32>
        //
        // If %t0 and %t1 are equivalent, it is safe to union the equivalence
        // classes of %r, %t0 and %t1.
        if (allEquivalent && !bufferizableOp.bufferizesToAllocation(opResult))
          state.unionEquivalenceClasses(opResult, firstOperand);
      }
    }
  }
}

/// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained
/// in `op`.
static void equivalenceAnalysis(Operation *op, OneShotAnalysisState &state) {
  // Traverse ops in PostOrder: Nested ops first, then enclosing ops.
  SmallVector<Operation *> ops;
  op->walk<WalkOrder::PostOrder>([&](Operation *op) {
    // No tensors => no buffers.
    if (none_of(op->getResultTypes(), isaTensor))
      return;
    ops.push_back(op);
  });

  equivalenceAnalysis(ops, state);
}

1098/// "Bottom-up from terminators" heuristic.
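///
/// Example (illustrative): In
///   %r = scf.for ... iter_args(%a = %t) -> (tensor<?xf32>) {
///     %w = "writing_op"(%a) : (tensor<?xf32>) -> tensor<?xf32>
///     scf.yield %w : tensor<?xf32>
///   }
/// the reverse use-def chain of the scf.yield operand is followed first, so
/// "writing_op" is analyzed before unrelated ops in the same region.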
static SmallVector<Operation *>
bottomUpFromTerminatorsHeuristic(Operation *op,
                                 const OneShotAnalysisState &state) {
  SetVector<Operation *> traversedOps;

  // Find region terminators.
  op->walk<WalkOrder::PostOrder>([&](RegionBranchTerminatorOpInterface term) {
    if (!traversedOps.insert(term))
      return;
    // Follow the reverse SSA use-def chain from each yielded value as long as
    // we stay within the same region.
    SmallVector<OpResult> worklist;
    for (Value v : term->getOperands()) {
      if (!isa<TensorType>(v.getType()))
        continue;
      auto opResult = dyn_cast<OpResult>(v);
      if (!opResult)
        continue;
      worklist.push_back(opResult);
    }
    while (!worklist.empty()) {
      OpResult opResult = worklist.pop_back_val();
      Operation *defOp = opResult.getDefiningOp();
      if (!traversedOps.insert(defOp))
        continue;
      if (!term->getParentRegion()->findAncestorOpInRegion(*defOp))
        continue;
      AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
      for (auto alias : aliases) {
        Value v = alias.opOperand->get();
        if (!isa<TensorType>(v.getType()))
          continue;
        auto opResult = dyn_cast<OpResult>(v);
        if (!opResult)
          continue;
        worklist.push_back(opResult);
      }
    }
  });

  // Analyze traversed ops, then all remaining ops.
  SmallVector<Operation *> result(traversedOps.begin(), traversedOps.end());
  op->walk<WalkOrder::PostOrder, ReverseIterator>([&](Operation *op) {
    if (!traversedOps.contains(op) && hasTensorSemantics(op))
      result.push_back(op);
  });
  return result;
}

LogicalResult OneShotAnalysisState::analyzeOp(Operation *op,
                                              const DominanceInfo &domInfo) {
  OneShotBufferizationOptions::AnalysisHeuristic heuristic =
      getOptions().analysisHeuristic;

  SmallVector<Operation *> orderedOps;
  if (heuristic ==
      OneShotBufferizationOptions::AnalysisHeuristic::BottomUpFromTerminators) {
    orderedOps = bottomUpFromTerminatorsHeuristic(op, *this);
  } else {
    op->walk([&](Operation *op) {
      // No tensors => no buffers.
      if (!hasTensorSemantics(op))
        return;
      orderedOps.push_back(op);
    });
    switch (heuristic) {
    case OneShotBufferizationOptions::AnalysisHeuristic::BottomUp: {
      // Default: Walk ops in reverse for better interference analysis.
      std::reverse(orderedOps.begin(), orderedOps.end());
      break;
    }
    case OneShotBufferizationOptions::AnalysisHeuristic::TopDown: {
      // Ops are already sorted top-down in `orderedOps`.
      break;
    }
    case OneShotBufferizationOptions::AnalysisHeuristic::Fuzzer: {
      assert(getOptions().analysisFuzzerSeed &&
             "expected that fuzzer seed is set");
      // This is a fuzzer. For testing purposes only. Randomize the order in
      // which operations are analyzed. The bufferization quality is likely
      // worse, but we want to make sure that no assertions are triggered
      // anywhere.
      std::mt19937 g(getOptions().analysisFuzzerSeed);
      llvm::shuffle(orderedOps.begin(), orderedOps.end(), g);
      break;
    }
    default: {
      llvm_unreachable("unsupported heuristic");
    }
    }
  }

  // Analyze ops in the computed order.
  for (Operation *op : orderedOps)
    if (failed(analyzeSingleOp(op, domInfo)))
      return failure();

  equivalenceAnalysis(op, *this);
  return success();
}

/// Perform various checks on the input IR to see if it contains IR constructs
/// that are unsupported by One-Shot Bufferize.
static LogicalResult
checkPreBufferizationAssumptions(Operation *op, const DominanceInfo &domInfo,
                                 OneShotAnalysisState &state) {
  const BufferizationOptions &options = state.getOptions();

  // Note: This walk cannot be combined with the one below because interface
  // methods of invalid/unsupported ops may be called during the second walk.
  // (On ops different from `op`.)
  WalkResult walkResult = op->walk([&](BufferizableOpInterface op) {
    // Skip ops that are not in the filter.
    if (!options.isOpAllowed(op.getOperation()))
      return WalkResult::advance();

    // Check for unsupported unstructured control flow.
    if (!op.supportsUnstructuredControlFlow()) {
      for (Region &r : op->getRegions()) {
        if (r.getBlocks().size() > 1) {
          op->emitOpError("op or BufferizableOpInterface implementation does "
                          "not support unstructured control flow, but at least "
                          "one region has multiple blocks");
          return WalkResult::interrupt();
        }
      }
    }

    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted())
    return failure();

  walkResult = op->walk([&](BufferizableOpInterface op) {
    // Skip ops that are not in the filter.
    if (!options.isOpAllowed(op.getOperation()))
      return WalkResult::advance();

    // Input IR may not contain any ToTensorOps without the "restrict"
    // attribute. Such tensors may alias any other tensor, which is currently
    // not handled in the analysis.
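    //
    // E.g. (illustrative), the first op below would be rejected (if its
    // result has uses), while the second one is accepted:
    //   %0 = bufferization.to_tensor %m : memref<?xf32>
    //   %1 = bufferization.to_tensor %m restrict : memref<?xf32>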
    if (auto toTensorOp = dyn_cast<ToTensorOp>(op.getOperation())) {
      if (!toTensorOp.getRestrict() && !toTensorOp->getUses().empty()) {
        op->emitOpError("to_tensor ops without `restrict` are not supported by "
                        "One-Shot Analysis");
        return WalkResult::interrupt();
      }
    }

    for (OpOperand &opOperand : op->getOpOperands()) {
      if (isa<TensorType>(opOperand.get().getType())) {
        if (wouldCreateReadAfterWriteInterference(
                opOperand, domInfo, state,
                /*checkConsistencyOnly=*/true)) {
          // This error can happen if certain "mustBufferizeInPlace" interface
          // methods are implemented incorrectly, such that the IR already has
          // a RaW conflict before making any bufferization decisions. It can
          // also happen if bufferization.materialize_in_destination is used in
          // such a way that a RaW conflict is not avoidable.
          op->emitOpError("not bufferizable under the given constraints: "
                          "cannot avoid RaW conflict");
          return WalkResult::interrupt();
        }

        if (state.isInPlace(opOperand) &&
            wouldCreateWriteToNonWritableBuffer(
                opOperand, state, /*checkConsistencyOnly=*/true)) {
          op->emitOpError("not bufferizable under the given constraints: would "
                          "write to read-only buffer");
          return WalkResult::interrupt();
        }
      }
    }

    return WalkResult::advance();
  });

  return success(!walkResult.wasInterrupted());
}

/// Annotate the IR with the result of the analysis. For testing/debugging
/// only.
static void
annotateOpsWithBufferizationMarkers(Operation *op,
                                    const OneShotAnalysisState &state) {
  // Add __inplace_operands_attr__.
  op->walk([&](Operation *op) {
    for (OpOperand &opOperand : op->getOpOperands())
      if (isa<TensorType>(opOperand.get().getType()))
        setInPlaceOpOperand(opOperand, state.isInPlace(opOperand));
  });
}

static void annotateOpsWithAliasSets(Operation *op,
                                     const OneShotAnalysisState &state) {
  AsmState asmState(op);
  Builder b(op->getContext());
  // Helper function to build an array attribute of aliasing SSA value strings.
  auto buildAliasesArray = [&](Value v) {
    SmallVector<Attribute> aliases;
    state.applyOnAliases(v, [&](Value alias) {
      std::string buffer;
      llvm::raw_string_ostream stream(buffer);
      alias.printAsOperand(stream, asmState);
      aliases.push_back(b.getStringAttr(stream.str()));
    });
    return b.getArrayAttr(aliases);
  };

  op->walk([&](Operation *op) {
    // Build alias set array for every OpResult.
    SmallVector<Attribute> opResultAliasSets;
    for (OpResult opResult : op->getOpResults()) {
      if (llvm::isa<TensorType>(opResult.getType())) {
        opResultAliasSets.push_back(buildAliasesArray(opResult));
      }
    }
    if (!opResultAliasSets.empty())
      op->setAttr(kOpResultAliasSetAttrName, b.getArrayAttr(opResultAliasSets));

    // Build alias set array for every BlockArgument.
    SmallVector<Attribute> regionAliasSets;
    bool hasTensorBbArg = false;
    for (Region &r : op->getRegions()) {
      SmallVector<Attribute> blockAliasSets;
      for (Block &block : r.getBlocks()) {
        SmallVector<Attribute> bbArgAliasSets;
        for (BlockArgument bbArg : block.getArguments()) {
          if (llvm::isa<TensorType>(bbArg.getType())) {
            bbArgAliasSets.push_back(buildAliasesArray(bbArg));
            hasTensorBbArg = true;
          }
        }
        blockAliasSets.push_back(b.getArrayAttr(bbArgAliasSets));
      }
      regionAliasSets.push_back(b.getArrayAttr(blockAliasSets));
    }
    if (hasTensorBbArg)
      op->setAttr(kBbArgAliasSetAttrName, b.getArrayAttr(regionAliasSets));
  });
}

LogicalResult bufferization::analyzeOp(Operation *op,
                                       OneShotAnalysisState &state,
                                       BufferizationStatistics *statistics) {
  DominanceInfo domInfo(op);
  const OneShotBufferizationOptions &options = state.getOptions();

  if (failed(checkPreBufferizationAssumptions(op, domInfo, state)))
    return failure();

  // If the analysis fails, just return.
  if (failed(state.analyzeOp(op, domInfo)))
    return failure();

  if (statistics) {
    statistics->numTensorInPlace = state.getStatNumTensorInPlace();
    statistics->numTensorOutOfPlace = state.getStatNumTensorOutOfPlace();
  }

  bool failedAnalysis = false;

  // Gather some extra analysis data.
  state.gatherUndefinedTensorUses(op);

  // Analysis verification: After setting up alias/equivalence sets, each op
  // can check for expected invariants/limitations and fail the analysis if
  // necessary.
  op->walk([&](Operation *op) {
    if (BufferizableOpInterface bufferizableOp =
            options.dynCastBufferizableOp(op))
      failedAnalysis |= failed(bufferizableOp.verifyAnalysis(state));
  });

  // Annotate operations if we only want to report the analysis.
  if (options.testAnalysisOnly)
    annotateOpsWithBufferizationMarkers(op, state);
  if (options.dumpAliasSets)
    annotateOpsWithAliasSets(op, state);

  return success(!failedAnalysis);
}

LogicalResult
bufferization::runOneShotBufferize(Operation *op,
                                   const OneShotBufferizationOptions &options,
                                   BufferizationStatistics *statistics) {
  assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
         "invalid combination of bufferization flags");
  if (!options.copyBeforeWrite) {
    // Unless a buffer is copied before every write (in which case no analysis
    // is needed), run the analysis and insert tensor copies where necessary.
    if (failed(insertTensorCopies(op, options, statistics)))
      return failure();
  }
  if (options.testAnalysisOnly)
    return success();
  return bufferizeOp(op, options, statistics);
}