1 | //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" |
10 | #include "mlir/Dialect/Arith/IR/Arith.h" |
11 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
12 | #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" |
13 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
14 | #include "mlir/IR/Attributes.h" |
15 | #include "mlir/IR/Dialect.h" |
16 | #include "mlir/IR/Operation.h" |
17 | |
18 | using namespace mlir; |
19 | using namespace mlir::bufferization; |
20 | |
21 | namespace { |
22 | /// Bufferization of arith.constant. Replace with memref.get_global. |
23 | struct ConstantOpInterface |
24 | : public BufferizableOpInterface::ExternalModel<ConstantOpInterface, |
25 | arith::ConstantOp> { |
26 | LogicalResult bufferize(Operation *op, RewriterBase &rewriter, |
27 | const BufferizationOptions &options) const { |
28 | auto constantOp = cast<arith::ConstantOp>(op); |
29 | auto type = dyn_cast<RankedTensorType>(constantOp.getType()); |
30 | |
31 | // Only ranked tensors are supported. |
32 | if (!type) |
33 | return failure(); |
34 | |
35 | Attribute memorySpace; |
36 | if (auto memSpace = options.defaultMemorySpaceFn(type)) |
37 | memorySpace = *memSpace; |
38 | else |
39 | return constantOp->emitError("could not infer memory space" ); |
40 | |
41 | // Only constants inside a module are supported. |
42 | auto moduleOp = constantOp->getParentOfType<ModuleOp>(); |
43 | if (!moduleOp) |
44 | return failure(); |
45 | |
46 | // Create global memory segment and replace tensor with memref pointing to |
47 | // that memory segment. |
48 | FailureOr<memref::GlobalOp> globalOp = |
49 | getGlobalFor(constantOp, options.bufferAlignment, memorySpace); |
50 | if (failed(globalOp)) |
51 | return failure(); |
52 | memref::GlobalOp globalMemref = *globalOp; |
53 | replaceOpWithNewBufferizedOp<memref::GetGlobalOp>( |
54 | rewriter, op, globalMemref.getType(), globalMemref.getName()); |
55 | |
56 | return success(); |
57 | } |
58 | |
59 | bool isWritable(Operation *op, Value value, |
60 | const AnalysisState &state) const { |
61 | // Memory locations returned by memref::GetGlobalOp may not be written to. |
62 | assert(isa<OpResult>(value)); |
63 | return false; |
64 | } |
65 | }; |
66 | |
67 | struct IndexCastOpInterface |
68 | : public BufferizableOpInterface::ExternalModel<IndexCastOpInterface, |
69 | arith::IndexCastOp> { |
70 | bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, |
71 | const AnalysisState &state) const { |
72 | return false; |
73 | } |
74 | |
75 | bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, |
76 | const AnalysisState &state) const { |
77 | return false; |
78 | } |
79 | |
80 | AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, |
81 | const AnalysisState &state) const { |
82 | return {{op->getResult(idx: 0), BufferRelation::Equivalent}}; |
83 | } |
84 | |
85 | LogicalResult bufferize(Operation *op, RewriterBase &rewriter, |
86 | const BufferizationOptions &options) const { |
87 | auto castOp = cast<arith::IndexCastOp>(op); |
88 | auto resultTensorType = cast<TensorType>(castOp.getType()); |
89 | |
90 | FailureOr<Value> source = getBuffer(rewriter, castOp.getIn(), options); |
91 | if (failed(result: source)) |
92 | return failure(); |
93 | auto sourceType = cast<BaseMemRefType>(Val: source->getType()); |
94 | |
95 | // Result type should have same layout and address space as the source type. |
96 | BaseMemRefType resultType; |
97 | if (auto rankedMemRefType = dyn_cast<MemRefType>(sourceType)) { |
98 | resultType = MemRefType::get( |
99 | rankedMemRefType.getShape(), resultTensorType.getElementType(), |
100 | rankedMemRefType.getLayout(), rankedMemRefType.getMemorySpace()); |
101 | } else { |
102 | auto unrankedMemrefType = cast<UnrankedMemRefType>(sourceType); |
103 | resultType = UnrankedMemRefType::get(resultTensorType.getElementType(), |
104 | unrankedMemrefType.getMemorySpace()); |
105 | } |
106 | |
107 | replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType, |
108 | *source); |
109 | return success(); |
110 | } |
111 | }; |
112 | |
113 | /// Bufferization of arith.select. Just replace the operands. |
114 | struct SelectOpInterface |
115 | : public BufferizableOpInterface::ExternalModel<SelectOpInterface, |
116 | arith::SelectOp> { |
117 | bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, |
118 | const AnalysisState &state) const { |
119 | return false; |
120 | } |
121 | |
122 | bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, |
123 | const AnalysisState &state) const { |
124 | return false; |
125 | } |
126 | |
127 | AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, |
128 | const AnalysisState &state) const { |
129 | return {{op->getOpResult(idx: 0) /*result*/, BufferRelation::Equivalent, |
130 | /*isDefinite=*/false}}; |
131 | } |
132 | |
133 | LogicalResult bufferize(Operation *op, RewriterBase &rewriter, |
134 | const BufferizationOptions &options) const { |
135 | auto selectOp = cast<arith::SelectOp>(op); |
136 | Location loc = selectOp.getLoc(); |
137 | |
138 | // Elementwise conditions are not supported yet. To bufferize such an op, |
139 | // it could be lowered to an elementwise "linalg.generic" with a new |
140 | // "tensor.empty" out tensor, followed by "empty tensor elimination". Such |
141 | // IR will bufferize. |
142 | if (!selectOp.getCondition().getType().isInteger(1)) |
143 | return op->emitOpError(message: "only i1 condition values are supported" ); |
144 | |
145 | // TODO: It would be more efficient to copy the result of the `select` op |
146 | // instead of its OpOperands. In the worst case, 2 copies are inserted at |
147 | // the moment (one for each tensor). When copying the op result, only one |
148 | // copy would be needed. |
149 | FailureOr<Value> maybeTrueBuffer = |
150 | getBuffer(rewriter, selectOp.getTrueValue(), options); |
151 | FailureOr<Value> maybeFalseBuffer = |
152 | getBuffer(rewriter, selectOp.getFalseValue(), options); |
153 | if (failed(result: maybeTrueBuffer) || failed(result: maybeFalseBuffer)) |
154 | return failure(); |
155 | Value trueBuffer = *maybeTrueBuffer; |
156 | Value falseBuffer = *maybeFalseBuffer; |
157 | |
158 | // The "true" and the "false" operands must have the same type. If the |
159 | // buffers have different types, they differ only in their layout map. Cast |
160 | // both of them to the most dynamic MemRef type. |
161 | if (trueBuffer.getType() != falseBuffer.getType()) { |
162 | auto targetType = |
163 | bufferization::getBufferType(value: selectOp.getResult(), options); |
164 | if (failed(targetType)) |
165 | return failure(); |
166 | if (trueBuffer.getType() != *targetType) |
167 | trueBuffer = |
168 | rewriter.create<memref::CastOp>(loc, *targetType, trueBuffer); |
169 | if (falseBuffer.getType() != *targetType) |
170 | falseBuffer = |
171 | rewriter.create<memref::CastOp>(loc, *targetType, falseBuffer); |
172 | } |
173 | |
174 | replaceOpWithNewBufferizedOp<arith::SelectOp>( |
175 | rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer); |
176 | return success(); |
177 | } |
178 | |
179 | FailureOr<BaseMemRefType> |
180 | getBufferType(Operation *op, Value value, const BufferizationOptions &options, |
181 | SmallVector<Value> &invocationStack) const { |
182 | auto selectOp = cast<arith::SelectOp>(op); |
183 | assert(value == selectOp.getResult() && "invalid value" ); |
184 | auto trueType = bufferization::getBufferType(value: selectOp.getTrueValue(), |
185 | options, invocationStack); |
186 | auto falseType = bufferization::getBufferType(value: selectOp.getFalseValue(), |
187 | options, invocationStack); |
188 | if (failed(trueType) || failed(falseType)) |
189 | return failure(); |
190 | if (*trueType == *falseType) |
191 | return *trueType; |
192 | if (trueType->getMemorySpace() != falseType->getMemorySpace()) |
193 | return op->emitError(message: "inconsistent memory space on true/false operands" ); |
194 | |
195 | // If the buffers have different types, they differ only in their layout |
196 | // map. |
197 | auto memrefType = llvm::cast<MemRefType>(*trueType); |
198 | return getMemRefTypeWithFullyDynamicLayout( |
199 | RankedTensorType::get(memrefType.getShape(), |
200 | memrefType.getElementType()), |
201 | memrefType.getMemorySpace()); |
202 | } |
203 | }; |
204 | |
205 | } // namespace |
206 | |
207 | void mlir::arith::registerBufferizableOpInterfaceExternalModels( |
208 | DialectRegistry ®istry) { |
209 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, ArithDialect *dialect) { |
210 | ConstantOp::attachInterface<ConstantOpInterface>(*ctx); |
211 | IndexCastOp::attachInterface<IndexCastOpInterface>(*ctx); |
212 | SelectOp::attachInterface<SelectOpInterface>(*ctx); |
213 | }); |
214 | } |
215 | |