1 | //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h" |
10 | |
11 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
12 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
13 | #include "mlir/Dialect/Shape/IR/Shape.h" |
14 | #include "mlir/IR/Dialect.h" |
15 | #include "mlir/IR/Operation.h" |
16 | #include "mlir/IR/PatternMatch.h" |
17 | |
18 | using namespace mlir; |
19 | using namespace mlir::bufferization; |
20 | using namespace mlir::shape; |
21 | |
22 | namespace mlir { |
23 | namespace shape { |
24 | namespace { |
25 | |
26 | /// Bufferization of shape.assuming. |
27 | struct AssumingOpInterface |
28 | : public BufferizableOpInterface::ExternalModel<AssumingOpInterface, |
29 | shape::AssumingOp> { |
30 | AliasingOpOperandList |
31 | getAliasingOpOperands(Operation *op, Value value, |
32 | const AnalysisState &state) const { |
33 | // AssumingOps do not have tensor OpOperands. The yielded value can be any |
34 | // SSA value that is in scope. To allow for use-def chain traversal through |
35 | // AssumingOps in the analysis, the corresponding yield value is considered |
36 | // to be aliasing with the result. |
37 | auto assumingOp = cast<shape::AssumingOp>(op); |
38 | size_t resultNum = std::distance(first: op->getOpResults().begin(), |
39 | last: llvm::find(Range: op->getOpResults(), Val: value)); |
40 | // TODO: Support multiple blocks. |
41 | assert(assumingOp.getDoRegion().getBlocks().size() == 1 && |
42 | "expected exactly 1 block" ); |
43 | auto yieldOp = dyn_cast<shape::AssumingYieldOp>( |
44 | assumingOp.getDoRegion().front().getTerminator()); |
45 | assert(yieldOp && "expected shape.assuming_yield terminator" ); |
46 | return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}}; |
47 | } |
48 | |
49 | LogicalResult bufferize(Operation *op, RewriterBase &rewriter, |
50 | const BufferizationOptions &options) const { |
51 | auto assumingOp = cast<shape::AssumingOp>(op); |
52 | assert(assumingOp.getDoRegion().getBlocks().size() == 1 && |
53 | "only 1 block supported" ); |
54 | auto yieldOp = cast<shape::AssumingYieldOp>( |
55 | assumingOp.getDoRegion().front().getTerminator()); |
56 | |
57 | // Create new op and move over region. |
58 | TypeRange newResultTypes(yieldOp.getOperands()); |
59 | auto newOp = rewriter.create<shape::AssumingOp>( |
60 | op->getLoc(), newResultTypes, assumingOp.getWitness()); |
61 | newOp.getDoRegion().takeBody(assumingOp.getRegion()); |
62 | |
63 | // Update all uses of the old op. |
64 | rewriter.setInsertionPointAfter(newOp); |
65 | SmallVector<Value> newResults; |
66 | for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) { |
67 | if (isa<TensorType>(it.value())) { |
68 | newResults.push_back(rewriter.create<bufferization::ToTensorOp>( |
69 | assumingOp.getLoc(), newOp->getResult(it.index()))); |
70 | } else { |
71 | newResults.push_back(newOp->getResult(it.index())); |
72 | } |
73 | } |
74 | |
75 | // Replace old op. |
76 | rewriter.replaceOp(assumingOp, newResults); |
77 | |
78 | return success(); |
79 | } |
80 | }; |
81 | |
82 | /// Bufferization of shape.assuming_yield. Bufferized as part of their enclosing |
83 | /// ops, so this is for analysis only. |
84 | struct AssumingYieldOpInterface |
85 | : public BufferizableOpInterface::ExternalModel<AssumingYieldOpInterface, |
86 | shape::AssumingYieldOp> { |
87 | bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, |
88 | const AnalysisState &state) const { |
89 | return true; |
90 | } |
91 | |
92 | bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, |
93 | const AnalysisState &state) const { |
94 | return false; |
95 | } |
96 | |
97 | AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, |
98 | const AnalysisState &state) const { |
99 | assert(isa<shape::AssumingOp>(op->getParentOp()) && |
100 | "expected that parent is an AssumingOp" ); |
101 | OpResult opResult = |
102 | op->getParentOp()->getResult(idx: opOperand.getOperandNumber()); |
103 | return {{opResult, BufferRelation::Equivalent}}; |
104 | } |
105 | |
106 | bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand, |
107 | const AnalysisState &state) const { |
108 | // Yield operands always bufferize inplace. Otherwise, an alloc + copy |
109 | // may be generated inside the block. We should not return/yield allocations |
110 | // when possible. |
111 | return true; |
112 | } |
113 | |
114 | LogicalResult bufferize(Operation *op, RewriterBase &rewriter, |
115 | const BufferizationOptions &options) const { |
116 | auto yieldOp = cast<shape::AssumingYieldOp>(op); |
117 | SmallVector<Value> newResults; |
118 | for (Value value : yieldOp.getOperands()) { |
119 | if (isa<TensorType>(value.getType())) { |
120 | FailureOr<Value> buffer = getBuffer(rewriter, value, options); |
121 | if (failed(buffer)) |
122 | return failure(); |
123 | newResults.push_back(*buffer); |
124 | } else { |
125 | newResults.push_back(value); |
126 | } |
127 | } |
128 | replaceOpWithNewBufferizedOp<shape::AssumingYieldOp>(rewriter, op, |
129 | newResults); |
130 | return success(); |
131 | } |
132 | }; |
133 | |
134 | } // namespace |
135 | } // namespace shape |
136 | } // namespace mlir |
137 | |
138 | void mlir::shape::registerBufferizableOpInterfaceExternalModels( |
139 | DialectRegistry ®istry) { |
140 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, shape::ShapeDialect *dialect) { |
141 | shape::AssumingOp::attachInterface<AssumingOpInterface>(*ctx); |
142 | shape::AssumingYieldOp::attachInterface<AssumingYieldOpInterface>(*ctx); |
143 | }); |
144 | } |
145 | |