//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

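/// Bufferization of tensor.cast. Replace with memref.cast.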
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getResult(0), BufferRelation::Equivalent}};
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto castOp = cast<tensor::CastOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        castOp.getSource(), options, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    Attribute memorySpace = maybeSrcBufferType->getMemorySpace();

    // Note: `getMemRefTypeWithFullyDynamicLayout` returns an unranked memref
    // type in case the input is an unranked tensor type.

    // Case 1: Casting an unranked tensor.
    if (isa<UnrankedTensorType>(castOp.getSource().getType())) {
      // When casting to a ranked tensor, we cannot infer any static offset or
      // strides from the source. Assume fully dynamic.
      return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
    }

    // Case 2: Casting to an unranked tensor type.
    if (isa<UnrankedTensorType>(castOp.getType())) {
      return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
    }

    // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
    // change.
    auto rankedResultType = cast<RankedTensorType>(castOp.getType());
    return MemRefType::get(
        rankedResultType.getShape(), rankedResultType.getElementType(),
        llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        getBuffer(rewriter, castOp.getSource(), options);
    if (failed(resultBuffer))
      return failure();

    // Compute the new type.
    auto resultMemRefType =
        bufferization::getBufferType(castOp.getResult(), options);
    if (failed(resultMemRefType))
      return failure();
    if (resultBuffer->getType() == *resultMemRefType) {
      // This cast is a no-op.
      replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
      return success();
    }

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             *resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(
        rewriter, op, *resultMemRefType, *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // tensor.collapse_shape may reallocate, at which point the source buffer
    // is copied. I.e., there will be a memory read side effect on the
    // bufferized source. This function conservatively returns "true" because
    // whether a copy will be created or not is not known at this point.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    // TODO: CollapseShapeOp may allocate at runtime.
    return {{op->getOpResult(0), BufferRelation::Equivalent}};
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        collapseShapeOp.getSrc(), options, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        srcBufferType, collapseShapeOp.getReassociationIndices());

    if (!canBeCollapsed) {
      // If dims cannot be collapsed, this op bufferizes to a new allocation.
      RankedTensorType tensorResultType = collapseShapeOp.getResultType();
      return bufferization::getMemRefTypeWithStaticIdentityLayout(
          tensorResultType, srcBufferType.getMemorySpace());
    }

    return memref::CollapseShapeOp::computeCollapsedType(
        srcBufferType, collapseShapeOp.getReassociationIndices());
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    FailureOr<Value> maybeBuffer =
        getBuffer(rewriter, collapseShapeOp.getSrc(), options);
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;
    auto bufferType = cast<MemRefType>(buffer.getType());

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(bufferType, strides, offset)))
          return failure();
        resultType = MemRefType::get(
            {}, tensorResultType.getElementType(),
            StridedLayoutAttr::get(op->getContext(), offset, {}),
            bufferType.getMemorySpace());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
      return success();
    }

    // If the dims are not collapsible (due to an incompatible source layout
    // map), force an out-of-place bufferization, i.e., a buffer copy. This
    // newly allocated buffer will have no layout map and thus be collapsible.
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
    if (!canBeCollapsed) {
      // TODO: Create alloc_tensor ops during TensorCopyInsertion.
      AnalysisState analysisState(options);
      FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), collapseShapeOp.getSrc(), options);
      if (failed(tensorAlloc))
        return failure();
      auto memrefType =
          MemRefType::get(collapseShapeOp.getSrcType().getShape(),
                          collapseShapeOp.getSrcType().getElementType(),
                          AffineMap(), bufferType.getMemorySpace());
      buffer = rewriter.create<bufferization::ToMemrefOp>(
          op->getLoc(), memrefType, *tensorAlloc);
    }

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // The op reads the tensor's metadata but not its contents.
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto dimOp = cast<tensor::DimOp>(op);
    FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
                                                dimOp.getIndex());
    return success();
  }
};

/// Bufferization of "tensor.empty". Replace with "bufferization.alloc_tensor".
struct EmptyOpInterface
    : public BufferizableOpInterface::ExternalModel<EmptyOpInterface,
                                                    tensor::EmptyOp> {
  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  bool resultBufferizesToMemoryWrite(Operation *op, OpResult opResult,
                                     const AnalysisState &state) const {
    // The returned tensor does not have specified contents.
    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto emptyOp = cast<tensor::EmptyOp>(op);

    // Optimization: Fold away the op if it has no uses.
    if (op->getUses().empty()) {
      rewriter.eraseOp(op);
      return success();
    }

    // Allocate a tensor. This emits a "bufferization.alloc_tensor" op.
    FailureOr<Value> allocTensor = allocateTensorForShapedValue(
        rewriter, op->getLoc(), emptyOp.getResult(), options, /*copy=*/false);
    if (failed(allocTensor))
      return failure();
    rewriter.replaceOp(op, *allocTensor);
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // In contrast to tensor.collapse_shape, this op can always be bufferized
    // without a copy.
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getOpResult(0), BufferRelation::Equivalent}};
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        expandShapeOp.getSrc(), options, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
    auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
        srcBufferType, expandShapeOp.getResultType().getShape(),
        expandShapeOp.getReassociationIndices());
    if (failed(maybeResultType))
      return failure();
    return *maybeResultType;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    FailureOr<Value> buffer =
        getBuffer(rewriter, expandShapeOp.getSrc(), options);
    if (failed(buffer))
      return failure();

    // Memref result type is inferred by the builder based on reassociation
    // indices and result shape.
    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
        rewriter, op, tensorResultType.getShape(), *buffer,
        expandShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getOpResult(0), BufferRelation::Unknown}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    Location loc = extractSliceOp.getLoc();

    // Get source buffer.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractSliceOp.getSource(), options);
    if (failed(srcMemref))
      return failure();

    // Take a subview of the source buffer.
    auto resultMemrefType =
        bufferization::getBufferType(extractSliceOp.getResult(), options);
    if (failed(resultMemrefType))
      return failure();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref,
        mixedOffsets, mixedSizes, mixedStrides);

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    assert(value == extractSliceOp.getResult() && "invalid value");
    auto srcMemrefType = bufferization::getBufferType(
        extractSliceOp.getSource(), options, invocationStack);
    if (failed(srcMemrefType))
      return failure();
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
        extractSliceOp.getType().getShape(),
        llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
        mixedStrides));
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractOp.getTensor(), options);
    if (failed(srcMemref))
      return failure();
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
                                                 extractOp.getIndices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.getElements().
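// For example, for a 2x3 tensor the stores are emitted at indices (0,0),
// (0,1), (0,2), (1,0), ... (row-major order), consuming `elementIt` in the
// same order as the `elements` operand list.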
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {

  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);
    auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
      return op->emitError("memory space not implemented yet");

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto shape = tensorType.getShape();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, fromElementsOp.getResult(), options,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    auto memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
        op->getLoc(), memrefType, *tensorAlloc);

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.getElements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(
          loc, fromElementsOp.getElements().front(), buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *llvm::max_element(shape);
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.getElements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);

    return success();
  }
};

/// Lower the body of a tensor.generate-like op (one index-typed bbArg per
/// dim). Such ops are lowered to linalg.map with the given tensor as a
/// destination.
///
/// Example:
/// ```
/// %r = tensor.generate %x, %y {
///   ^bb0(%arg0: index, %arg1: index):
///     %0 = "some_op"(%arg0, %arg1) : (index, index) -> (index)
///     tensor.yield %0 : index
/// } : tensor<?x?xindex>
/// ```
///
/// Is lowered to:
/// ```
/// linalg.map ins() outs(%dest) {
///   %d0 = linalg.index 0 : index
///   %d1 = linalg.index 1 : index
///   %0 = "some_op"(%d0, %d1) : (index, index) -> (index)
///   linalg.yield %0 : index
/// }
/// ```
static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
                                     Value tensorDestination,
                                     ValueRange dynamicSizes,
                                     Region &generateBody) {
  assert(generateBody.hasOneBlock() && "expected body with single block");
  auto tensorType = cast<RankedTensorType>(tensorDestination.getType());
  assert(generateBody.getNumArguments() == tensorType.getRank() &&
         "rank mismatch");

  // Create linalg::MapOp.
  OpBuilder::InsertionGuard g(rewriter);
  auto linalgOp =
      rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
                                     /*init=*/tensorDestination);
  Block &linalgBody = linalgOp.getMapper().emplaceBlock();

  // Create linalg::IndexOps.
  rewriter.setInsertionPointToStart(&linalgBody);
  SmallVector<Value> indices;
  for (int64_t dim = 0; dim < tensorType.getRank(); ++dim)
    indices.push_back(rewriter.create<linalg::IndexOp>(loc, dim));

  // Move over body.
  rewriter.mergeBlocks(&generateBody.front(), &linalgBody, indices);
  auto yieldOp = cast<tensor::YieldOp>(linalgBody.getTerminator());
  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());

  return linalgOp.getResult()[0];
}

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {

  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    auto type = generateOp.getResult().getType();

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpaceFn(type) != Attribute())
      return op->emitError("memory space not implemented yet");

    // Allocate memory.
    Location loc = op->getLoc();
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, generateOp.getResult(), options,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();

    Value result = lowerGenerateLikeOpBody(rewriter, loc, *tensorAlloc,
                                           generateOp.getDynamicExtents(),
                                           generateOp.getBody());
    rewriter.replaceOp(generateOp, result);

    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
///
/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
/// implementations for DestinationStyle ops.
struct InsertOpInterface
    : public DstBufferizableOpInterfaceExternalModel<InsertOpInterface,
                                                     tensor::InsertOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        getBuffer(rewriter, insertOp.getDest(), options);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
                                     *destMemref, insertOp.getIndices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }
};

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
///
/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
/// implementations for DestinationStyle ops.
struct InsertSliceOpInterface
    : public DstBufferizableOpInterfaceExternalModel<InsertSliceOpInterface,
                                                     tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    RankedTensorType destType = insertSliceOp.getDestType();

    // The source is always read.
    if (opOperand == insertSliceOp.getSourceMutable())
      return true;

    // For the destination, it depends...
    assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");

    // Dest is not read if it is entirely overwritten. E.g.:
    // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
    bool allOffsetsZero =
        llvm::all_of(insertSliceOp.getMixedOffsets(), [](OpFoldResult ofr) {
          return isConstantIntValue(ofr, 0);
        });
    bool sizesMatchDestSizes = llvm::all_of(
        llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) {
          return getConstantIntValue(it.value()) ==
                 destType.getDimSize(it.index());
        });
    bool allStridesOne =
        llvm::all_of(insertSliceOp.getMixedStrides(), [](OpFoldResult ofr) {
          return isConstantIntValue(ofr, 1);
        });
    return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    Location loc = insertSliceOp.getLoc();

    // Get destination buffer.
    FailureOr<Value> dstMemref =
        getBuffer(rewriter, insertSliceOp.getDest(), options);
    if (failed(dstMemref))
      return failure();

    // Take a subview of the destination buffer.
    auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
    auto subviewMemRefType =
        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getShape(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides));
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, insertSliceOp.getSource(), options);
    if (failed(srcMemref))
      return failure();
    if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.pad. Replace with bufferization.alloc_tensor +
/// linalg.map + insert_slice.
/// For best performance, vectorize before bufferization (better performance
/// when padding with a constant).
struct PadOpInterface
    : public BufferizableOpInterface::ExternalModel<PadOpInterface,
                                                    tensor::PadOp> {
  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    // Infer memory space from the source tensor.
    auto padOp = cast<tensor::PadOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        padOp.getSource(), options, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    MemRefLayoutAttrInterface layout;
    return MemRefType::get(padOp.getResultType().getShape(),
                           padOp.getResultType().getElementType(), layout,
                           maybeSrcBufferType->getMemorySpace());
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto padOp = cast<tensor::PadOp>(op);
    Location loc = padOp.getLoc();
    RankedTensorType resultType = padOp.getResultType();
    RankedTensorType srcType = padOp.getSourceType();

    auto toValue = [&](OpFoldResult ofr) {
      if (ofr.is<Value>())
        return ofr.get<Value>();
      return rewriter
          .create<arith::ConstantIndexOp>(loc, *getConstantIntValue(ofr))
          .getResult();
    };

    // Compute dynamic result dimensions.
    SmallVector<OpFoldResult> mixedLowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> mixedHighPad = padOp.getMixedHighPad();
    SmallVector<Value> dynamicSizes;
    for (int64_t i = 0; i < resultType.getRank(); ++i) {
      if (!resultType.isDynamicDim(i))
        continue;
      Value srcDim = rewriter.create<tensor::DimOp>(loc, padOp.getSource(), i);
      Value lowPad = toValue(mixedLowPad[i]);
      Value highPad = toValue(mixedHighPad[i]);
      AffineExpr s0, s1, s2;
      bindSymbols(op->getContext(), s0, s1, s2);
      AffineExpr sumExpr = s0 + s1 + s2;
      Value sum = rewriter.create<affine::AffineApplyOp>(
          loc, sumExpr, ValueRange{srcDim, lowPad, highPad});
      dynamicSizes.push_back(sum);
    }

    // Allocate a buffer for the padded result.
    FailureOr<Value> tensorAlloc =
        allocateTensorForShapedValue(rewriter, loc, padOp.getResult(), options,
                                     /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();

    // tensor::PadOp is like tensor::GenerateOp: The only difference is that
    // only a part of the generated tensor is needed. For simplicity, we reuse
    // the same functionality here.
    Value filledBuffer = lowerGenerateLikeOpBody(
        rewriter, loc, *tensorAlloc, dynamicSizes, padOp.getBodyRegion());

    // Create tensor::InsertSliceOp.
    SmallVector<OpFoldResult> sliceSizes =
        getMixedSizes(rewriter, loc, padOp.getSource());
    SmallVector<OpFoldResult> sliceStrides(srcType.getRank(),
                                           rewriter.getIndexAttr(1));
    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        padOp, padOp.getSource(), filledBuffer,
        /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);

    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // The op reads the tensor's metadata but not its contents.
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto rankOp = cast<tensor::RankOp>(op);
    FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 *v);
    return success();
  }
};

/// Bufferization of tensor.reshape. Replace with memref.reshape.
struct ReshapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
                                                    tensor::ReshapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Depending on the layout map, the source buffer may have to be copied.
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    return opOperand == reshapeOp.getShapeMutable();
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getOpResult(0), BufferRelation::Equivalent}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, reshapeOp.getSource(), options);
    FailureOr<Value> shapeBuffer =
        getBuffer(rewriter, reshapeOp.getShape(), options);
    if (failed(srcBuffer) || failed(shapeBuffer))
      return failure();
    auto maybeResultMemRefType =
        bufferization::getBufferType(reshapeOp.getResult(), options);
    if (failed(maybeResultMemRefType))
      return failure();

    // memref.reshape requires the source buffer to have an identity layout.
    // If the source memref does not have an identity layout, copy the source
    // into a new buffer with an identity layout.
    auto srcType = llvm::dyn_cast<MemRefType>(srcBuffer->getType());
    if (srcType && !srcType.getLayout().isIdentity()) {
      FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), reshapeOp.getSource(), options);
      if (failed(tensorAlloc))
        return failure();
      auto memrefType = MemRefType::get(
          srcType.getShape(), srcType.getElementType(), AffineMap(),
          cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace());
      srcBuffer = rewriter
                      .create<bufferization::ToMemrefOp>(
                          op->getLoc(), memrefType, *tensorAlloc)
                      .getResult();
    }

    replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
        rewriter, op, maybeResultMemRefType.value(), *srcBuffer, *shapeBuffer);
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    assert(value == reshapeOp.getResult() && "unexpected value provided");
    auto maybeSourceBufferType = bufferization::getBufferType(
        reshapeOp.getSource(), options, invocationStack);
    if (failed(maybeSourceBufferType))
      return failure();
    return getMemRefTypeWithStaticIdentityLayout(
        reshapeOp.getResult().getType(),
        cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace());
  }
};

/// Analysis of ParallelInsertSliceOp.
struct ParallelInsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<
          ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
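    // tensor.parallel_insert_slice has no results; the inserted value becomes
    // visible only through the results of the parent op (e.g., scf.forall).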
    return {};
  }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
    return opOperand == parallelInsertSliceOp.getDestMutable();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
    ParallelCombiningOpInterface parallelCombiningParent =
        parallelInsertSliceOp.getParallelCombiningParent();

    // Bufferize the op outside of the parallel combining terminator.
    rewriter.setInsertionPoint(parallelCombiningParent);

    // Get source and destination buffers.
    FailureOr<Value> destBuffer =
        getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
    if (failed(destBuffer))
      return failure();
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
    if (failed(srcBuffer))
      return failure();

    // Take a subview of the destination buffer.
    auto destBufferType = cast<MemRefType>(destBuffer->getType());
    auto subviewMemRefType =
        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
            parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
            parallelInsertSliceOp.getMixedOffsets(),
            parallelInsertSliceOp.getMixedSizes(),
            parallelInsertSliceOp.getMixedStrides()));
    Value subview = rewriter.create<memref::SubViewOp>(
        parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
        parallelInsertSliceOp.getMixedOffsets(),
        parallelInsertSliceOp.getMixedSizes(),
        parallelInsertSliceOp.getMixedStrides());

    // This memcpy will fold away if everything bufferizes in-place.
    if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
                                    *srcBuffer, subview)))
      return failure();

    // In case the source was allocated in the same block, make sure that the
    // deallocation op (if any) appears after the memcpy. By default, deallocs
    // are placed before the terminator, but this does not work for ForallOp
    // because the terminator does more than just yielding a value.
    //
    // Note: This is not a problem for the destination buffer because these are
    // assumed to always bufferize in-place.
    for (Operation *user : srcBuffer->getUsers()) {
      if (hasEffect<MemoryEffects::Free>(user)) {
        if (user->getBlock() == parallelCombiningParent->getBlock())
          rewriter.moveOpBefore(user, user->getBlock()->getTerminator());
        break;
      }
    }

    // Delete the op.
    rewriter.eraseOp(op);
    return success();
  }
};

/// Bufferization of tensor.splat. Bufferizes to a new allocation that is
/// filled with a linalg.map. Similar to tensor.generate.
struct SplatOpInterface
    : public BufferizableOpInterface::ExternalModel<SplatOpInterface,
                                                    tensor::SplatOp> {

  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto splatOp = cast<tensor::SplatOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, splatOp.getResult(), options,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();

    // Create linalg::MapOp.
    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
      return op->emitError("memory space not implemented yet");

    auto linalgOp =
        rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
                                       /*init=*/*tensorAlloc);
    Block &linalgBody = linalgOp.getMapper().emplaceBlock();

    // Fill the body of the linalg.map: simply yield the splat value.
    rewriter.setInsertionPointToStart(&linalgBody);
    rewriter.create<linalg::YieldOp>(loc, splatOp.getInput());
    rewriter.replaceOp(splatOp, linalgOp.getResult()[0]);

    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    PadOp::attachInterface<PadOpInterface>(*ctx);
    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
        *ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
    ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
    SplatOp::attachInterface<SplatOpInterface>(*ctx);

    // Load additional dialects of which ops may get created.
    ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
  });

  // Bufferization requires SubsetInsertionOpInterface models. Make sure that
  // they are registered.
  tensor::registerSubsetOpInterfaceExternalModels(registry);
}