1//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
10
11#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
14#include "mlir/Dialect/Vector/IR/VectorOps.h"
15#include "mlir/IR/Dialect.h"
16#include "mlir/IR/Operation.h"
17
18using namespace mlir;
19using namespace mlir::bufferization;
20using namespace mlir::vector;
21
22namespace mlir {
23namespace vector {
24namespace {
25
26/// Bufferization of vector.transfer_read. Replaced with a new
27/// vector.transfer_read that operates on a memref.
28struct TransferReadOpInterface
29 : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
30 vector::TransferReadOp> {
31 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
32 const AnalysisState &state) const {
33 assert(isa<RankedTensorType>(opOperand.get().getType()) &&
34 "only tensor types expected");
35 return true;
36 }
37
38 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
39 const AnalysisState &state) const {
40 assert(isa<RankedTensorType>(opOperand.get().getType()) &&
41 "only tensor types expected");
42 return false;
43 }
44
45 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
46 const AnalysisState &state) const {
47 return {};
48 }
49
50 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
51 const BufferizationOptions &options,
52 BufferizationState &state) const {
53 auto readOp = cast<vector::TransferReadOp>(op);
54 assert(isa<TensorType>(readOp.getShapedType()) &&
55 "only tensor types expected");
56 FailureOr<Value> buffer =
57 getBuffer(rewriter, readOp.getBase(), options, state);
58 if (failed(Result: buffer))
59 return failure();
60 replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
61 rewriter, readOp, readOp.getVectorType(), *buffer, readOp.getIndices(),
62 readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(),
63 readOp.getInBoundsAttr());
64 return success();
65 }
66};
67
68/// Bufferization of vector.transfer_write. Replace with a new
69/// vector.transfer_write that operates on a memref.
70///
71/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
72/// implementations for DestinationStyle ops.
73struct TransferWriteOpInterface
74 : public DstBufferizableOpInterfaceExternalModel<TransferWriteOpInterface,
75 vector::TransferWriteOp> {
76 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
77 const AnalysisState &state) const {
78 auto writeOp = cast<vector::TransferWriteOp>(op);
79
80 // Does not bufferize to a memory read if the vector completely overwrites
81 // the buffer.
82
83 // Destination must have static shape.
84 if (!writeOp.getShapedType().hasStaticShape())
85 return true;
86
87 // All offsets must be 0.
88 for (Value offset : writeOp.getIndices()) {
89 if (getConstantIntValue(offset) != 0)
90 return true;
91 }
92
93 // There is no mask.
94 if (writeOp.isMasked())
95 return true;
96
97 // Must write at least the full dimension size.
98 for (auto [d0, d1] : llvm::zip(writeOp.getShapedType().getShape(),
99 writeOp.getVectorType().getShape())) {
100 if (d0 > d1)
101 return true;
102 }
103
104 return false;
105 }
106
107 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
108 const BufferizationOptions &options,
109 BufferizationState &state) const {
110 auto writeOp = cast<vector::TransferWriteOp>(op);
111 assert(isa<TensorType>(writeOp.getShapedType()) &&
112 "only tensor types expected");
113
114 // Create a new transfer_write on buffer that doesn't have a return value.
115 FailureOr<Value> resultBuffer =
116 getBuffer(rewriter, writeOp.getBase(), options, state);
117 if (failed(Result: resultBuffer))
118 return failure();
119 rewriter.create<vector::TransferWriteOp>(
120 writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
121 writeOp.getIndices(), writeOp.getPermutationMapAttr(),
122 writeOp.getMask(), writeOp.getInBoundsAttr());
123 replaceOpWithBufferizedValues(rewriter, op, values: *resultBuffer);
124
125 return success();
126 }
127};
128
129/// Bufferization of vector.gather. Replaced with a new vector.gather that
130/// operates on a memref.
131struct GatherOpInterface
132 : public BufferizableOpInterface::ExternalModel<GatherOpInterface,
133 vector::GatherOp> {
134 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
135 const AnalysisState &state) const {
136 assert(isa<RankedTensorType>(opOperand.get().getType()) &&
137 "only tensor types expected");
138 return true;
139 }
140
141 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
142 const AnalysisState &state) const {
143 assert(isa<RankedTensorType>(opOperand.get().getType()) &&
144 "only tensor types expected");
145 return false;
146 }
147
148 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
149 const AnalysisState &state) const {
150 return {};
151 }
152
153 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
154 const BufferizationOptions &options,
155 BufferizationState &state) const {
156 auto gatherOp = cast<vector::GatherOp>(op);
157 assert(isa<TensorType>(gatherOp.getBaseType()) &&
158 "only tensor types expected");
159 FailureOr<Value> buffer =
160 getBuffer(rewriter, gatherOp.getBase(), options, state);
161 if (failed(Result: buffer))
162 return failure();
163 replaceOpWithNewBufferizedOp<vector::GatherOp>(
164 rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
165 gatherOp.getIndices(), gatherOp.getIndexVec(), gatherOp.getMask(),
166 gatherOp.getPassThru());
167 return success();
168 }
169};
170
171/// Bufferization of vector.mask. Replaced with a new vector.mask that
172/// operates on a memref.
173struct MaskOpInterface
174 : public BufferizableOpInterface::ExternalModel<MaskOpInterface,
175 vector::MaskOp> {
176 AliasingOpOperandList
177 getAliasingOpOperands(Operation *op, Value value,
178 const AnalysisState &state) const {
179 // MaskOps do not have tensor OpOperands. The yielded values are the result
180 // of the wrapped op.
181 auto maskOp = cast<vector::MaskOp>(op);
182 size_t resultNum = std::distance(first: op->getOpResults().begin(),
183 last: llvm::find(Range: op->getOpResults(), Val: value));
184 auto yieldOp =
185 cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
186 return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
187 }
188
189 LogicalResult
190 resolveConflicts(Operation *op, RewriterBase &rewriter,
191 const AnalysisState &analysisState,
192 const BufferizationState &bufferizationState) const {
193 auto bufferizableOp = cast<BufferizableOpInterface>(op);
194 if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
195 rewriter, analysisState, bufferizationState)))
196 return failure();
197
198 // TODO: Remove this function when vector.mask bodies can bufferize
199 // out-of-place. This is currently not supported because yielding allocs
200 // from a block leads to a memory leak and because vector.mask supports only
201 // a single op in its body.
202 auto maskOp = cast<vector::MaskOp>(op);
203 if (!maskOp.getMaskRegion()
204 .front()
205 .getOps<bufferization::AllocTensorOp>()
206 .empty())
207 return op->emitOpError(message: "body must bufferize in-place");
208
209 return success();
210 }
211
212 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
213 const BufferizationOptions &options,
214 BufferizationState &state) const {
215 auto maskOp = cast<vector::MaskOp>(op);
216
217 // Do not bufferize if the masked op is not bufferizable.
218 Operation *maskedOp = maskOp.getMaskableOp();
219 if (!options.dynCastBufferizableOp(maskedOp))
220 return success();
221
222 // Update the terminator: Drop all operands that are not results of the
223 // masked op.
224 auto yieldOp =
225 cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
226 SmallVector<Value> newReturnValues(maskOp->getNumResults(), Value());
227 SmallVector<Value> newYieldedValues;
228 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
229 if (llvm::is_contained(maskedOp->getOpResults(), it.value())) {
230 newYieldedValues.push_back(it.value());
231 } else {
232 // This used to be a tensor result of the masked op, but is now a memref
233 // that is defined outside of the vector.mask op.
234 newReturnValues[it.index()] = it.value();
235 }
236 }
237 rewriter.modifyOpInPlace(yieldOp, [&]() {
238 yieldOp.getOperandsMutable().assign(newYieldedValues);
239 });
240
241 // Create a new vector.mask op.
242 ValueRange newYieldedValuesRange(newYieldedValues);
243 TypeRange newResultTypes(newYieldedValuesRange);
244 auto newOp = rewriter.create<vector::MaskOp>(
245 op->getLoc(), newResultTypes, maskOp.getMask(), maskOp.getPassthru(),
246 /*maskableOp=*/nullptr,
247 /*maskRegionBuilder=*/[](OpBuilder &b, Operation *) {});
248 newOp.getRegion().takeBody(maskOp.getMaskRegion());
249
250 // Replace all uses of the old vector.mask op.
251 int idx = 0;
252 for (int i = 0; i < static_cast<int>(maskOp->getNumResults()); ++i) {
253 if (!newReturnValues[i])
254 newReturnValues[i] = newOp->getResult(idx++);
255 }
256 replaceOpWithBufferizedValues(rewriter, maskOp, newReturnValues);
257 return success();
258 }
259};
260
261/// Bufferization of vector.yield. Replaced with a new vector.yield that
262/// operates on a memref.
263struct YieldOpInterface
264 : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
265 vector::YieldOp> {
266 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
267 const AnalysisState &state) const {
268 return true;
269 }
270
271 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
272 const AnalysisState &state) const {
273 return false;
274 }
275
276 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
277 const AnalysisState &state) const {
278 return {{op->getParentOp()->getResult(idx: opOperand.getOperandNumber()),
279 BufferRelation::Equivalent}};
280 }
281
282 bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
283 const AnalysisState &state) const {
284 // Yield operands always bufferize inplace. Otherwise, an alloc + copy
285 // may be generated inside the block. We should not return/yield allocations
286 // when possible.
287 return true;
288 }
289
290 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
291 const BufferizationOptions &options,
292 BufferizationState &state) const {
293 auto yieldOp = cast<vector::YieldOp>(op);
294
295 // Only supported as a vector.mask terminator.
296 auto maskOp = dyn_cast<vector::MaskOp>(yieldOp->getParentOp());
297 if (!maskOp)
298 return yieldOp->emitError("unsupported vector::YieldOp parent");
299
300 // Do not bufferize if the masked op is not bufferizable.
301 Operation *maskedOp = &maskOp.getMaskRegion().front().front();
302 if (!options.dynCastBufferizableOp(maskedOp))
303 return success();
304
305 // Create a new terminator with the same number of operands. Some of these
306 // may get dropped during the bufferization of vector.mask.
307 SmallVector<Value> newResults;
308 for (Value value : yieldOp.getOperands()) {
309 if (isa<TensorType>(value.getType())) {
310 FailureOr<Value> maybeBuffer =
311 getBuffer(rewriter, value, options, state);
312 if (failed(maybeBuffer))
313 return failure();
314 newResults.push_back(*maybeBuffer);
315 } else {
316 newResults.push_back(value);
317 }
318 }
319
320 replaceOpWithNewBufferizedOp<vector::YieldOp>(rewriter, op, newResults);
321 return success();
322 }
323};
324
325} // namespace
326} // namespace vector
327} // namespace mlir
328
329void mlir::vector::registerBufferizableOpInterfaceExternalModels(
330 DialectRegistry &registry) {
331 registry.addExtension(extensionFn: +[](MLIRContext *ctx, vector::VectorDialect *dialect) {
332 TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
333 TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
334 GatherOp::attachInterface<GatherOpInterface>(*ctx);
335 MaskOp::attachInterface<MaskOpInterface>(*ctx);
336 YieldOp::attachInterface<YieldOpInterface>(*ctx);
337 });
338}
339

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp