1 | //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h" |
10 | |
11 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
12 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
13 | #include "mlir/Dialect/Shape/IR/Shape.h" |
14 | #include "mlir/IR/Dialect.h" |
15 | #include "mlir/IR/Operation.h" |
16 | #include "mlir/IR/PatternMatch.h" |
17 | |
18 | using namespace mlir; |
19 | using namespace mlir::bufferization; |
20 | using namespace mlir::shape; |
21 | |
22 | namespace mlir { |
23 | namespace shape { |
24 | namespace { |
25 | |
26 | /// Bufferization of shape.assuming. |
27 | struct AssumingOpInterface |
28 | : public BufferizableOpInterface::ExternalModel<AssumingOpInterface, |
29 | shape::AssumingOp> { |
30 | AliasingOpOperandList |
31 | getAliasingOpOperands(Operation *op, Value value, |
32 | const AnalysisState &state) const { |
33 | // AssumingOps do not have tensor OpOperands. The yielded value can be any |
34 | // SSA value that is in scope. To allow for use-def chain traversal through |
35 | // AssumingOps in the analysis, the corresponding yield value is considered |
36 | // to be aliasing with the result. |
37 | auto assumingOp = cast<shape::AssumingOp>(op); |
38 | size_t resultNum = std::distance(first: op->getOpResults().begin(), |
39 | last: llvm::find(Range: op->getOpResults(), Val: value)); |
40 | // TODO: Support multiple blocks. |
41 | assert(llvm::hasSingleElement(assumingOp.getDoRegion().getBlocks()) && |
42 | "expected exactly 1 block" ); |
43 | auto yieldOp = dyn_cast<shape::AssumingYieldOp>( |
44 | assumingOp.getDoRegion().front().getTerminator()); |
45 | assert(yieldOp && "expected shape.assuming_yield terminator" ); |
46 | return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}}; |
47 | } |
48 | |
49 | LogicalResult bufferize(Operation *op, RewriterBase &rewriter, |
50 | const BufferizationOptions &options, |
51 | BufferizationState &state) const { |
52 | auto assumingOp = cast<shape::AssumingOp>(op); |
53 | assert(llvm::hasSingleElement(assumingOp.getDoRegion().getBlocks()) && |
54 | "only 1 block supported" ); |
55 | auto yieldOp = cast<shape::AssumingYieldOp>( |
56 | assumingOp.getDoRegion().front().getTerminator()); |
57 | |
58 | // Create new op and move over region. |
59 | TypeRange newResultTypes(yieldOp.getOperands()); |
60 | auto newOp = rewriter.create<shape::AssumingOp>( |
61 | op->getLoc(), newResultTypes, assumingOp.getWitness()); |
62 | newOp.getDoRegion().takeBody(assumingOp.getRegion()); |
63 | |
64 | // Update all uses of the old op. |
65 | rewriter.setInsertionPointAfter(newOp); |
66 | SmallVector<Value> newResults; |
67 | for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) { |
68 | if (isa<TensorType>(it.value())) { |
69 | newResults.push_back(rewriter.create<bufferization::ToTensorOp>( |
70 | assumingOp.getLoc(), newOp->getResult(it.index()))); |
71 | } else { |
72 | newResults.push_back(newOp->getResult(it.index())); |
73 | } |
74 | } |
75 | |
76 | // Replace old op. |
77 | rewriter.replaceOp(assumingOp, newResults); |
78 | |
79 | return success(); |
80 | } |
81 | }; |
82 | |
83 | /// Bufferization of shape.assuming_yield. Bufferized as part of their enclosing |
84 | /// ops, so this is for analysis only. |
85 | struct AssumingYieldOpInterface |
86 | : public BufferizableOpInterface::ExternalModel<AssumingYieldOpInterface, |
87 | shape::AssumingYieldOp> { |
88 | bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, |
89 | const AnalysisState &state) const { |
90 | return true; |
91 | } |
92 | |
93 | bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, |
94 | const AnalysisState &state) const { |
95 | return false; |
96 | } |
97 | |
98 | AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, |
99 | const AnalysisState &state) const { |
100 | assert(isa<shape::AssumingOp>(op->getParentOp()) && |
101 | "expected that parent is an AssumingOp" ); |
102 | OpResult opResult = |
103 | op->getParentOp()->getResult(idx: opOperand.getOperandNumber()); |
104 | return {{opResult, BufferRelation::Equivalent}}; |
105 | } |
106 | |
107 | bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand, |
108 | const AnalysisState &state) const { |
109 | // Yield operands always bufferize inplace. Otherwise, an alloc + copy |
110 | // may be generated inside the block. We should not return/yield allocations |
111 | // when possible. |
112 | return true; |
113 | } |
114 | |
115 | LogicalResult bufferize(Operation *op, RewriterBase &rewriter, |
116 | const BufferizationOptions &options, |
117 | BufferizationState &state) const { |
118 | auto yieldOp = cast<shape::AssumingYieldOp>(op); |
119 | SmallVector<Value> newResults; |
120 | for (Value value : yieldOp.getOperands()) { |
121 | if (isa<TensorType>(value.getType())) { |
122 | FailureOr<Value> buffer = getBuffer(rewriter, value, options, state); |
123 | if (failed(buffer)) |
124 | return failure(); |
125 | newResults.push_back(*buffer); |
126 | } else { |
127 | newResults.push_back(value); |
128 | } |
129 | } |
130 | replaceOpWithNewBufferizedOp<shape::AssumingYieldOp>(rewriter, op, |
131 | newResults); |
132 | return success(); |
133 | } |
134 | }; |
135 | |
136 | } // namespace |
137 | } // namespace shape |
138 | } // namespace mlir |
139 | |
140 | void mlir::shape::registerBufferizableOpInterfaceExternalModels( |
141 | DialectRegistry ®istry) { |
142 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, shape::ShapeDialect *dialect) { |
143 | shape::AssumingOp::attachInterface<AssumingOpInterface>(*ctx); |
144 | shape::AssumingYieldOp::attachInterface<AssumingYieldOpInterface>(*ctx); |
145 | }); |
146 | } |
147 | |