1//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
10
11#include "mlir/Dialect/Affine/IR/AffineOps.h"
12#include "mlir/Dialect/Arith/IR/Arith.h"
13#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
14#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
15#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
16#include "mlir/Dialect/Linalg/IR/Linalg.h"
17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Dialect/Tensor/IR/Tensor.h"
19#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
20#include "mlir/Dialect/Utils/StaticValueUtils.h"
21#include "mlir/IR/BuiltinTypeInterfaces.h"
22#include "mlir/IR/Dialect.h"
23#include "mlir/IR/Operation.h"
24
25using namespace mlir;
26using namespace mlir::bufferization;
27using namespace mlir::tensor;
28
29namespace mlir {
30namespace tensor {
31namespace {
32
33struct CastOpInterface
34 : public BufferizableOpInterface::ExternalModel<CastOpInterface,
35 tensor::CastOp> {
36 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
37 const AnalysisState &state) const {
38 return false;
39 }
40
41 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
42 const AnalysisState &state) const {
43 return false;
44 }
45
46 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
47 const AnalysisState &state) const {
48 return {{op->getResult(idx: 0), BufferRelation::Equivalent}};
49 }
50
51 FailureOr<BufferLikeType>
52 getBufferType(Operation *op, Value value, const BufferizationOptions &options,
53 const BufferizationState &state,
54 SmallVector<Value> &invocationStack) const {
55 auto castOp = cast<tensor::CastOp>(Val: op);
56 auto maybeSrcBufferType =
57 bufferization::detail::asMemRefType(bufferType: bufferization::getBufferType(
58 value: castOp.getSource(), options, state, invocationStack));
59 if (failed(Result: maybeSrcBufferType))
60 return failure();
61 Attribute memorySpace = maybeSrcBufferType->getMemorySpace();
62
63 // Note: `getMemRefTypeWithFullyDynamicLayout` returns an unranked memref
64 // type in case the input is an unranked tensor type.
65
66 // Case 1: Casting an unranked tensor
67 if (isa<UnrankedTensorType>(Val: castOp.getSource().getType())) {
68 // When casting to a ranked tensor, we cannot infer any static offset or
69 // strides from the source. Assume fully dynamic.
70 return cast<BufferLikeType>(
71 Val: getMemRefTypeWithFullyDynamicLayout(tensorType: castOp.getType(), memorySpace));
72 }
73
74 // Case 2: Casting to an unranked tensor type
75 if (isa<UnrankedTensorType>(Val: castOp.getType())) {
76 return cast<BufferLikeType>(
77 Val: getMemRefTypeWithFullyDynamicLayout(tensorType: castOp.getType(), memorySpace));
78 }
79
80 // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
81 // change.
82 auto rankedResultType = cast<RankedTensorType>(Val: castOp.getType());
83 return cast<BufferLikeType>(Val: MemRefType::get(
84 shape: rankedResultType.getShape(), elementType: rankedResultType.getElementType(),
85 layout: llvm::cast<MemRefType>(Val&: *maybeSrcBufferType).getLayout(), memorySpace));
86 }
87
88 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
89 const BufferizationOptions &options,
90 BufferizationState &state) const {
91 auto castOp = cast<tensor::CastOp>(Val: op);
92
93 // The result buffer still has the old (pre-cast) type.
94 FailureOr<Value> resultBuffer =
95 getBuffer(rewriter, value: castOp.getSource(), options, state);
96 if (failed(Result: resultBuffer))
97 return failure();
98
99 // Compute the new type.
100 auto resultMemRefType =
101 bufferization::getBufferType(value: castOp.getResult(), options, state);
102 if (failed(Result: resultMemRefType))
103 return failure();
104 if (resultBuffer->getType() == *resultMemRefType) {
105 // This cast is a no-op.
106 replaceOpWithBufferizedValues(rewriter, op, values: *resultBuffer);
107 return success();
108 }
109
110 // Replace the op with a memref.cast.
111 assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
112 *resultMemRefType) &&
113 "CallOp::bufferize: cast incompatible");
114 replaceOpWithNewBufferizedOp<memref::CastOp>(
115 rewriter, op, args&: *resultMemRefType, args&: *resultBuffer);
116
117 return success();
118 }
119};
120
121/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
122struct CollapseShapeOpInterface
123 : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
124 tensor::CollapseShapeOp> {
125 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
126 const AnalysisState &state) const {
127 // tensor.collapse_shape may reallocate, at which point the source buffer is
128 // copied. I.e., there will be a memory read side effect on the bufferized
129 // source. This function conservatively returns "true" because whether a
130 // copy will be created or not is not known at this point.
131 return true;
132 }
133
134 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
135 const AnalysisState &state) const {
136 return false;
137 }
138
139 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
140 const AnalysisState &state) const {
141 // TODO: CollapseShapeOp may allocate at runtime.
142 return {{op->getOpResult(idx: 0), BufferRelation::Equivalent}};
143 }
144
145 FailureOr<BufferLikeType>
146 getBufferType(Operation *op, Value value, const BufferizationOptions &options,
147 const BufferizationState &state,
148 SmallVector<Value> &invocationStack) const {
149 auto collapseShapeOp = cast<tensor::CollapseShapeOp>(Val: op);
150 auto maybeSrcBufferType = bufferization::getBufferType(
151 value: collapseShapeOp.getSrc(), options, state, invocationStack);
152 if (failed(Result: maybeSrcBufferType))
153 return failure();
154 auto srcBufferType = llvm::cast<MemRefType>(Val&: *maybeSrcBufferType);
155 bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
156 srcType: srcBufferType, reassociation: collapseShapeOp.getReassociationIndices());
157
158 if (!canBeCollapsed) {
159 // If dims cannot be collapsed, this op bufferizes to a new allocation.
160 RankedTensorType tensorResultType = collapseShapeOp.getResultType();
161 return cast<BufferLikeType>(
162 Val: bufferization::getMemRefTypeWithStaticIdentityLayout(
163 tensorType: tensorResultType, memorySpace: srcBufferType.getMemorySpace()));
164 }
165
166 return cast<BufferLikeType>(Val: memref::CollapseShapeOp::computeCollapsedType(
167 srcType: srcBufferType, reassociation: collapseShapeOp.getReassociationIndices()));
168 }
169
170 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
171 const BufferizationOptions &options,
172 BufferizationState &state) const {
173 auto collapseShapeOp = cast<tensor::CollapseShapeOp>(Val: op);
174 RankedTensorType tensorResultType = collapseShapeOp.getResultType();
175 FailureOr<Value> maybeBuffer =
176 getBuffer(rewriter, value: collapseShapeOp.getSrc(), options, state);
177 if (failed(Result: maybeBuffer))
178 return failure();
179 Value buffer = *maybeBuffer;
180 auto bufferType = cast<MemRefType>(Val: buffer.getType());
181
182 if (tensorResultType.getRank() == 0) {
183 // 0-d collapses must go through a different op builder.
184 MemRefType resultType;
185
186 if (bufferType.getLayout().isIdentity()) {
187 // Standard layout: result type has no offset.
188 MemRefLayoutAttrInterface layout;
189 resultType = MemRefType::get(shape: {}, elementType: tensorResultType.getElementType(),
190 layout, memorySpace: bufferType.getMemorySpace());
191 } else {
192 // Source memref has a layout map: result type has the same offset as
193 // the source type.
194 SmallVector<int64_t> strides;
195 int64_t offset;
196 if (failed(Result: bufferType.getStridesAndOffset(strides, offset)))
197 return failure();
198 resultType = MemRefType::get(
199 shape: {}, elementType: tensorResultType.getElementType(),
200 layout: StridedLayoutAttr::get(context: op->getContext(), offset, strides: {}),
201 memorySpace: bufferType.getMemorySpace());
202 }
203
204 replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
205 rewriter, op, args&: resultType, args&: buffer, args: collapseShapeOp.getReassociation());
206 return success();
207 }
208
209 // If the dims are not collapsible (due to an incompatible source layout
210 // map), force an out-of-place bufferization, i.e., a buffer copy. This
211 // newly allocated buffer will have no layout map and thus be collapsible.
212 bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
213 srcType: bufferType, reassociation: collapseShapeOp.getReassociationIndices());
214 if (!canBeCollapsed) {
215 // TODO: Create alloc_tensor ops during TensorCopyInsertion.
216 AnalysisState analysisState(options);
217 FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
218 b&: rewriter, loc: op->getLoc(), shapedValue: collapseShapeOp.getSrc(), options, state);
219 if (failed(Result: tensorAlloc))
220 return failure();
221 auto memrefType =
222 MemRefType::get(shape: collapseShapeOp.getSrcType().getShape(),
223 elementType: collapseShapeOp.getSrcType().getElementType(),
224 map: AffineMap(), memorySpace: bufferType.getMemorySpace());
225 buffer = rewriter.create<bufferization::ToBufferOp>(
226 location: op->getLoc(), args&: memrefType, args&: *tensorAlloc);
227 }
228
229 // Result type is inferred by the builder.
230 replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
231 rewriter, op, args&: buffer, args: collapseShapeOp.getReassociationIndices());
232 return success();
233 }
234};
235
236/// Bufferization of tensor.dim. Replace with memref.dim.
237struct DimOpInterface
238 : public BufferizableOpInterface::ExternalModel<DimOpInterface,
239 tensor::DimOp> {
240 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
241 const AnalysisState &state) const {
242 // The op reads the tensor's metadata but not its contents.
243 return false;
244 }
245
246 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
247 const AnalysisState &state) const {
248 return false;
249 }
250
251 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
252 const AnalysisState &state) const {
253 return {};
254 }
255
256 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
257 const BufferizationOptions &options,
258 BufferizationState &state) const {
259 auto dimOp = cast<tensor::DimOp>(Val: op);
260 FailureOr<Value> v = getBuffer(rewriter, value: dimOp.getSource(), options, state);
261 if (failed(Result: v))
262 return failure();
263 replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, args&: *v,
264 args: dimOp.getIndex());
265 return success();
266 }
267};
268
269/// Bufferization of "tensor.empty". Replace with "bufferization.alloc_tensor".
270struct EmptyOpInterface
271 : public BufferizableOpInterface::ExternalModel<EmptyOpInterface,
272 tensor::EmptyOp> {
273 bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
274
275 bool resultBufferizesToMemoryWrite(Operation *op, OpResult opResult,
276 const AnalysisState &state) const {
277 // The returned tensor does not have specified contents.
278 return false;
279 }
280
281 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
282 const BufferizationOptions &options,
283 BufferizationState &state) const {
284 auto emptyOp = cast<tensor::EmptyOp>(Val: op);
285
286 // Optimization: Fold away the op if it has no uses.
287 if (op->getUses().empty()) {
288 rewriter.eraseOp(op);
289 return success();
290 }
291
292 // Allocate a tensor. This emits a "bufferization.alloc_tensor" op.
293 FailureOr<Value> allocTensor = allocateTensorForShapedValue(
294 b&: rewriter, loc: op->getLoc(), shapedValue: emptyOp.getResult(), options, state,
295 /*copy=*/false);
296 if (failed(Result: allocTensor))
297 return failure();
298 rewriter.replaceOp(op, newValues: *allocTensor);
299 return success();
300 }
301};
302
303/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
304struct ExpandShapeOpInterface
305 : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
306 tensor::ExpandShapeOp> {
307 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
308 const AnalysisState &state) const {
309 // In contrast to tensor.collapse_shape, this op can always be bufferized
310 // without a copy.
311 return false;
312 }
313
314 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
315 const AnalysisState &state) const {
316 return false;
317 }
318
319 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
320 const AnalysisState &state) const {
321 return {{op->getOpResult(idx: 0), BufferRelation::Equivalent}};
322 }
323
324 FailureOr<BufferLikeType>
325 getBufferType(Operation *op, Value value, const BufferizationOptions &options,
326 const BufferizationState &state,
327 SmallVector<Value> &invocationStack) const {
328 auto expandShapeOp = cast<tensor::ExpandShapeOp>(Val: op);
329 auto maybeSrcBufferType = bufferization::getBufferType(
330 value: expandShapeOp.getSrc(), options, state, invocationStack);
331 if (failed(Result: maybeSrcBufferType))
332 return failure();
333 auto srcBufferType = llvm::cast<MemRefType>(Val&: *maybeSrcBufferType);
334 auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
335 srcType: srcBufferType, resultShape: expandShapeOp.getResultType().getShape(),
336 reassociation: expandShapeOp.getReassociationIndices());
337 if (failed(Result: maybeResultType))
338 return failure();
339 return cast<BufferLikeType>(Val&: *maybeResultType);
340 }
341
342 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
343 const BufferizationOptions &options,
344 BufferizationState &state) const {
345 auto expandShapeOp = cast<tensor::ExpandShapeOp>(Val: op);
346 auto tensorResultType = expandShapeOp.getResultType();
347 FailureOr<Value> buffer =
348 getBuffer(rewriter, value: expandShapeOp.getSrc(), options, state);
349 if (failed(Result: buffer))
350 return failure();
351
352 auto memrefExpandShape = rewriter.create<memref::ExpandShapeOp>(
353 location: op->getLoc(), args: tensorResultType.getShape(), args&: *buffer,
354 args: expandShapeOp.getReassociationIndices(),
355 args: expandShapeOp.getMixedOutputShape());
356 replaceOpWithBufferizedValues(rewriter, op,
357 values: memrefExpandShape->getResults());
358 return success();
359 }
360};
361
362/// Bufferization of tensor.extract_slice. Replace with memref.subview.
363struct ExtractSliceOpInterface
364 : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
365 tensor::ExtractSliceOp> {
366 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
367 const AnalysisState &state) const {
368 return false;
369 }
370
371 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
372 const AnalysisState &state) const {
373 return false;
374 }
375
376 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
377 const AnalysisState &state) const {
378 return {{op->getOpResult(idx: 0), BufferRelation::Unknown}};
379 }
380
381 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
382 const BufferizationOptions &options,
383 BufferizationState &state) const {
384 auto extractSliceOp = cast<tensor::ExtractSliceOp>(Val: op);
385 SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
386 SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
387 SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
388 Location loc = extractSliceOp.getLoc();
389
390 // Get source buffer.
391 FailureOr<Value> srcMemref =
392 getBuffer(rewriter, value: extractSliceOp.getSource(), options, state);
393 if (failed(Result: srcMemref))
394 return failure();
395
396 // Take a subview of the source buffer.
397 auto resultMemrefType = bufferization::getBufferType(
398 value: extractSliceOp.getResult(), options, state);
399 if (failed(Result: resultMemrefType))
400 return failure();
401 Value subView = rewriter.create<memref::SubViewOp>(
402 location: loc, args: llvm::cast<MemRefType>(Val&: *resultMemrefType), args&: *srcMemref,
403 args&: mixedOffsets, args&: mixedSizes, args&: mixedStrides);
404
405 replaceOpWithBufferizedValues(rewriter, op, values: subView);
406 return success();
407 }
408
409 FailureOr<BufferLikeType>
410 getBufferType(Operation *op, Value value, const BufferizationOptions &options,
411 const BufferizationState &state,
412 SmallVector<Value> &invocationStack) const {
413 auto extractSliceOp = cast<tensor::ExtractSliceOp>(Val: op);
414 assert(value == extractSliceOp.getResult() && "invalid value");
415 auto srcMemrefType = bufferization::getBufferType(
416 value: extractSliceOp.getSource(), options, state, invocationStack);
417 if (failed(Result: srcMemrefType))
418 return failure();
419 SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
420 SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
421 SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
422 return cast<BufferLikeType>(Val: memref::SubViewOp::inferRankReducedResultType(
423 resultShape: extractSliceOp.getType().getShape(),
424 sourceMemRefType: llvm::cast<MemRefType>(Val&: *srcMemrefType), staticOffsets: mixedOffsets, staticSizes: mixedSizes,
425 staticStrides: mixedStrides));
426 }
427};
428
429/// Bufferization of tensor.extract. Replace with memref.load.
430struct ExtractOpInterface
431 : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
432 tensor::ExtractOp> {
433 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
434 const AnalysisState &state) const {
435 return true;
436 }
437
438 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
439 const AnalysisState &state) const {
440 return false;
441 }
442
443 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
444 const AnalysisState &state) const {
445 return {};
446 }
447
448 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
449 const BufferizationOptions &options,
450 BufferizationState &state) const {
451 auto extractOp = cast<tensor::ExtractOp>(Val: op);
452 FailureOr<Value> srcMemref =
453 getBuffer(rewriter, value: extractOp.getTensor(), options, state);
454 if (failed(Result: srcMemref))
455 return failure();
456 replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, args&: *srcMemref,
457 args: extractOp.getIndices());
458 return success();
459 }
460};
461
462// Implements backtracking to traverse indices of the output buffer while
463// iterating over op.elements().
464static void createStores(RewriterBase &rewriter, Location loc, int dim,
465 Value buffer, ArrayRef<int64_t> shape,
466 ArrayRef<Value> constants,
467 OperandRange::iterator &elementIt,
468 SmallVectorImpl<Value> &indices) {
469 if (dim == static_cast<int>(shape.size()) - 1) {
470 for (int i = 0; i < shape.back(); ++i) {
471 indices.back() = constants[i];
472 rewriter.create<memref::StoreOp>(location: loc, args: *elementIt, args&: buffer, args&: indices);
473 ++elementIt;
474 }
475 return;
476 }
477 for (int i = 0; i < shape[dim]; ++i) {
478 indices[dim] = constants[i];
479 createStores(rewriter, loc, dim: dim + 1, buffer, shape, constants, elementIt,
480 indices);
481 }
482}
483
484/// Bufferization of tensor.from_elements.
485struct FromElementsOpInterface
486 : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
487 tensor::FromElementsOp> {
488
489 bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
490
491 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
492 const BufferizationOptions &options,
493 BufferizationState &state) const {
494 auto fromElementsOp = cast<tensor::FromElementsOp>(Val: op);
495 auto tensorType = cast<RankedTensorType>(Val: fromElementsOp.getType());
496
497 // Allocate a buffer for the result.
498 Location loc = op->getLoc();
499 auto shape = tensorType.getShape();
500 // TODO: Create alloc_tensor ops during TensorCopyInsertion.
501 FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
502 b&: rewriter, loc, shapedValue: fromElementsOp.getResult(), options, state,
503 /*copy=*/false);
504 if (failed(Result: tensorAlloc))
505 return failure();
506 FailureOr<BufferLikeType> memrefType =
507 bufferization::getBufferType(value: *tensorAlloc, options, state);
508 if (failed(Result: memrefType))
509 return failure();
510 Value buffer = rewriter.create<bufferization::ToBufferOp>(
511 location: op->getLoc(), args&: *memrefType, args&: *tensorAlloc);
512
513 // Case: tensor<0xelem_type>.
514 if (fromElementsOp.getElements().empty()) {
515 replaceOpWithBufferizedValues(rewriter, op, values: buffer);
516 return success();
517 }
518
519 // Case: tensor<elem_type>.
520 if (shape.empty()) {
521 rewriter.create<memref::StoreOp>(
522 location: loc, args: fromElementsOp.getElements().front(), args&: buffer);
523 replaceOpWithBufferizedValues(rewriter, op, values: buffer);
524 return success();
525 }
526
527 // Create constants for the range of possible indices [0, max{shape_i}).
528 auto maxDim = *llvm::max_element(Range&: shape);
529 SmallVector<Value, 2> constants;
530 constants.reserve(N: maxDim);
531 for (int i = 0; i < maxDim; ++i)
532 constants.push_back(Elt: rewriter.create<arith::ConstantIndexOp>(location: loc, args&: i));
533
534 // Traverse all `elements` and create `memref.store` ops.
535 auto elementIt = fromElementsOp.getElements().begin();
536 SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
537 createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
538 indices);
539
540 replaceOpWithBufferizedValues(rewriter, op, values: buffer);
541
542 return success();
543 }
544};
545
546/// Lower the body of a tensor.generate like op (one index-typed bbArg per dim).
547/// Such ops are lowered to linalg.map with the given tensor as a destination.
548///
549/// Example:
550/// ```
551/// %r = tensor.generate %x, %y {
552/// ^bb0(%arg0: index, %arg1: index):
553/// %0 = "some_op"(%arg0, %arg1) : (index, index) -> (index)
554/// tensor.yield %0 : index
555/// } : tensor<?x?xindex>
556/// ```
557///
558/// Is lowered to:
559/// ```
560/// linalg.map ins() outs(%dest) {
561/// %d0 = linalg.index 0 : index
562/// %d1 = linalg.index 1 : index
563/// %0 = "some_op"(%d0, %d1) : (index, index) -> (index)
564/// linalg.yield %0 : index
565/// }
566/// ```
567static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
568 Value tensorDestination,
569 ValueRange dynamicSizes,
570 Region &generateBody) {
571 assert(generateBody.hasOneBlock() && "expected body with single block");
572 auto tensorType = cast<RankedTensorType>(Val: tensorDestination.getType());
573 assert(generateBody.getNumArguments() == tensorType.getRank() &&
574 "rank mismatch");
575
576 // Create linalg::MapOp.
577 OpBuilder::InsertionGuard g(rewriter);
578 auto linalgOp =
579 rewriter.create<linalg::MapOp>(location: loc, args&: tensorType, /*inputs=*/args: ValueRange(),
580 /*init=*/args&: tensorDestination);
581 Block &linalgBody = linalgOp.getMapper().emplaceBlock();
582
583 // Create linalg::IndexOps.
584 rewriter.setInsertionPointToStart(&linalgBody);
585 SmallVector<Value> indices;
586 for (int64_t dim = 0; dim < tensorType.getRank(); ++dim)
587 indices.push_back(Elt: rewriter.create<linalg::IndexOp>(location: loc, args&: dim));
588
589 // Move over body.
590 rewriter.mergeBlocks(source: &generateBody.front(), dest: &linalgBody, argValues: indices);
591 auto yieldOp = cast<tensor::YieldOp>(Val: linalgBody.getTerminator());
592 rewriter.replaceOpWithNewOp<linalg::YieldOp>(op: yieldOp, args: yieldOp.getValue());
593
594 return linalgOp.getResult()[0];
595}
596
597/// Bufferization of tensor.generate.
598struct GenerateOpInterface
599 : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
600 tensor::GenerateOp> {
601
602 bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
603
604 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
605 const BufferizationOptions &options,
606 BufferizationState &state) const {
607 auto generateOp = cast<tensor::GenerateOp>(Val: op);
608
609 auto type = generateOp.getResult().getType();
610
611 // TODO: Implement memory space for this op.
612 if (options.defaultMemorySpaceFn(type) != Attribute())
613 return op->emitError(message: "memory space not implemented yet");
614
615 // Allocate memory.
616 Location loc = op->getLoc();
617 FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
618 b&: rewriter, loc, shapedValue: generateOp.getResult(), options, state,
619 /*copy=*/false);
620 if (failed(Result: tensorAlloc))
621 return failure();
622
623 Value result = lowerGenerateLikeOpBody(rewriter, loc, tensorDestination: *tensorAlloc,
624 dynamicSizes: generateOp.getDynamicExtents(),
625 generateBody&: generateOp.getBody());
626 rewriter.replaceOp(op: generateOp, newValues: result);
627
628 return success();
629 }
630};
631
632/// Bufferization of tensor.insert. Replace with memref.store.
633///
634/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
635/// implementations for DestinationStyle ops.
636struct InsertOpInterface
637 : public DstBufferizableOpInterfaceExternalModel<InsertOpInterface,
638 tensor::InsertOp> {
639 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
640 const BufferizationOptions &options,
641 BufferizationState &state) const {
642 auto insertOp = cast<tensor::InsertOp>(Val: op);
643 FailureOr<Value> destMemref =
644 getBuffer(rewriter, value: insertOp.getDest(), options, state);
645 if (failed(Result: destMemref))
646 return failure();
647 rewriter.create<memref::StoreOp>(location: insertOp.getLoc(), args: insertOp.getScalar(),
648 args&: *destMemref, args: insertOp.getIndices());
649 replaceOpWithBufferizedValues(rewriter, op, values: *destMemref);
650 return success();
651 }
652};
653
654template <typename InsertOpTy>
655static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
656 OpOperand &opOperand) {
657 // The source is always read.
658 if (opOperand == insertSliceOp.getSourceMutable())
659 return true;
660
661 // For the destination, it depends...
662 assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
663
664 // Dest is not read if it is entirely overwritten. E.g.:
665 // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
666 bool allOffsetsZero =
667 llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroInteger);
668 RankedTensorType destType = insertSliceOp.getDestType();
669 bool sizesMatchDestSizes =
670 areConstantIntValues(insertSliceOp.getMixedSizes(), destType.getShape());
671 bool allStridesOne =
672 areAllConstantIntValue(insertSliceOp.getMixedStrides(), 1);
673 return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
674}
675
676/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
677/// certain circumstances, this op can also be a no-op.
678///
679/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
680/// implementations for DestinationStyle ops.
681struct InsertSliceOpInterface
682 : public DstBufferizableOpInterfaceExternalModel<InsertSliceOpInterface,
683 tensor::InsertSliceOp> {
684 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
685 const AnalysisState &state) const {
686 return insertSliceOpRequiresRead(insertSliceOp: cast<tensor::InsertSliceOp>(Val: op),
687 opOperand);
688 }
689
690 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
691 const BufferizationOptions &options,
692 BufferizationState &state) const {
693 // insert_slice ops arise from tiling and bufferizing them out-of-place is
694 // generally a deal breaker. When used with loops, this ends up cloning the
695 // whole tensor on every single iteration and is a symptom of a
696 // catastrophically bad scheduling decision.
697 // TODO: be very loud about it or even consider failing the pass.
698 auto insertSliceOp = cast<tensor::InsertSliceOp>(Val: op);
699 SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
700 SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
701 SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
702 Location loc = insertSliceOp.getLoc();
703
704 // Get destination buffer.
705 FailureOr<Value> dstMemref =
706 getBuffer(rewriter, value: insertSliceOp.getDest(), options, state);
707 if (failed(Result: dstMemref))
708 return failure();
709
710 // Take a subview of the destination buffer.
711 auto dstMemrefType = cast<MemRefType>(Val: dstMemref->getType());
712 MemRefType subviewMemRefType =
713 memref::SubViewOp::inferRankReducedResultType(
714 resultShape: insertSliceOp.getSourceType().getShape(), sourceMemRefType: dstMemrefType,
715 staticOffsets: mixedOffsets, staticSizes: mixedSizes, staticStrides: mixedStrides);
716 Value subView = rewriter.create<memref::SubViewOp>(
717 location: loc, args&: subviewMemRefType, args&: *dstMemref, args&: mixedOffsets, args&: mixedSizes,
718 args&: mixedStrides);
719
720 // Copy tensor. If this tensor.insert_slice has a matching
721 // tensor.extract_slice, the copy operation will eventually fold away.
722 FailureOr<Value> srcMemref =
723 getBuffer(rewriter, value: insertSliceOp.getSource(), options, state);
724 if (failed(Result: srcMemref))
725 return failure();
726 if (failed(Result: options.createMemCpy(b&: rewriter, loc, from: *srcMemref, to: subView)))
727 return failure();
728
729 replaceOpWithBufferizedValues(rewriter, op, values: *dstMemref);
730 return success();
731 }
732};
733
/// Bufferization of tensor.pad. Replace with bufferization.alloc_tensor +
/// linalg.map + insert_slice.
/// For best performance, vectorize before bufferization (better performance in
/// case of padding with a constant).
struct PadOpInterface
    : public BufferizableOpInterface::ExternalModel<PadOpInterface,
                                                    tensor::PadOp> {
  /// tensor.pad always materializes its result in a newly allocated buffer.
  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  /// All tensor operands are read (their contents are copied/used to fill
  /// the new allocation).
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  /// No operand buffer is written in place; the result goes into a fresh
  /// allocation instead.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  /// The result is a new allocation, so it does not alias any operand.
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  /// Compute the buffer type of the pad result: same shape/element type as the
  /// result tensor, no layout (identity), memory space inherited from the
  /// source buffer.
  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    // Infer memory space from the source tensor.
    auto padOp = cast<tensor::PadOp>(Val: op);
    auto maybeSrcBufferType =
        bufferization::detail::asMemRefType(bufferType: bufferization::getBufferType(
            value: padOp.getSource(), options, state, invocationStack));
    if (failed(Result: maybeSrcBufferType))
      return failure();
    // Fresh allocations get an identity layout; only the memory space is
    // propagated from the source.
    MemRefLayoutAttrInterface layout;
    return cast<BufferLikeType>(
        Val: MemRefType::get(shape: padOp.getResultType().getShape(),
                        elementType: padOp.getResultType().getElementType(), layout,
                        memorySpace: maybeSrcBufferType->getMemorySpace()));
  }

  /// Lower tensor.pad to: allocate result tensor, fill it via the pad body
  /// (treated like a tensor.generate body), then insert the original source
  /// into the filled buffer at the low-pad offsets.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto padOp = cast<tensor::PadOp>(Val: op);
    Location loc = padOp.getLoc();
    RankedTensorType resultType = padOp.getResultType();
    RankedTensorType srcType = padOp.getSourceType();

    // Materialize an OpFoldResult as an index Value (emitting a constant op
    // for the static case).
    auto toValue = [&](OpFoldResult ofr) {
      if (auto value = dyn_cast<Value>(Val&: ofr))
        return value;
      return rewriter
          .create<arith::ConstantIndexOp>(location: loc, args: *getConstantIntValue(ofr))
          .getResult();
    };

    // Compute dynamic result dimensions: for each dynamic result dim, the
    // size is srcDim + lowPad + highPad, computed with an affine.apply.
    SmallVector<OpFoldResult> mixedLowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> mixedHighPad = padOp.getMixedHighPad();
    SmallVector<Value> dynamicSizes;
    for (int64_t i = 0; i < resultType.getRank(); ++i) {
      if (!resultType.isDynamicDim(idx: i))
        continue;
      Value srcDim = rewriter.create<tensor::DimOp>(location: loc, args: padOp.getSource(), args&: i);
      Value lowPad = toValue(mixedLowPad[i]);
      Value highPad = toValue(mixedHighPad[i]);
      AffineExpr s0, s1, s2;
      bindSymbols(ctx: op->getContext(), exprs&: s0, exprs&: s1, exprs&: s2);
      AffineExpr sumExpr = s0 + s1 + s2;
      Value sum = rewriter.create<affine::AffineApplyOp>(
          location: loc, args&: sumExpr, args: ValueRange{srcDim, lowPad, highPad});
      dynamicSizes.push_back(Elt: sum);
    }

    // Allocate a buffer for the padded result.
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        b&: rewriter, loc, shapedValue: padOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(Result: tensorAlloc))
      return failure();

    // tensor::PadOp is like tensor::GenerateOp: The only difference is that
    // only a part of the generated tensor is needed. For simplicity, we reuse
    // the same functionality here.
    Value filledBuffer = lowerGenerateLikeOpBody(
        rewriter, loc, tensorDestination: *tensorAlloc, dynamicSizes, generateBody&: padOp.getBodyRegion());

    // Create tensor::InsertSliceOp: place the source into the filled buffer
    // at the low-pad offsets with unit strides.
    SmallVector<OpFoldResult> sliceSizes =
        getMixedSizes(builder&: rewriter, loc, value: padOp.getSource());
    SmallVector<OpFoldResult> sliceStrides(srcType.getRank(),
                                           rewriter.getIndexAttr(value: 1));
    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        op: padOp, args: padOp.getSource(), args&: filledBuffer,
        /*offsets=*/args: padOp.getMixedLowPad(), args&: sliceSizes, args&: sliceStrides);

    return success();
  }
};
835
836/// Bufferization of tensor.rank. Replace with memref.rank.
837struct RankOpInterface
838 : public BufferizableOpInterface::ExternalModel<RankOpInterface,
839 tensor::RankOp> {
840 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
841 const AnalysisState &state) const {
842 // The op reads the tensor's metadata but not its contents.
843 return false;
844 }
845
846 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
847 const AnalysisState &state) const {
848 return false;
849 }
850
851 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
852 const AnalysisState &state) const {
853 return {};
854 }
855
856 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
857 const BufferizationOptions &options,
858 BufferizationState &state) const {
859 auto rankOp = cast<tensor::RankOp>(Val: op);
860 FailureOr<Value> v =
861 getBuffer(rewriter, value: rankOp.getTensor(), options, state);
862 if (failed(Result: v))
863 return failure();
864 replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, args: rankOp.getType(),
865 args&: *v);
866 return success();
867 }
868};
869
/// Bufferization of tensor.reshape. Replace with memref.reshape.
struct ReshapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
                                                    tensor::ReshapeOp> {
  /// Only the 'shape' operand's values are read; the source contents are not
  /// read by the reshape itself.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Depending on the layout map, the source buffer may have to be copied.
    auto reshapeOp = cast<tensor::ReshapeOp>(Val: op);
    return opOperand == reshapeOp.getShapeMutable();
  }

  /// tensor.reshape never writes to any operand buffer.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    // Only the 'source' operand aliases the result.
    auto reshapeOp = cast<tensor::ReshapeOp>(Val: op);
    if (reshapeOp.getSourceMutable() != opOperand)
      return {};
    return {{op->getOpResult(idx: 0), BufferRelation::Equivalent}};
  }

  /// Replace tensor.reshape with memref.reshape, copying the source into an
  /// identity-layout buffer first if its layout is not an identity.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(Val: op);
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, value: reshapeOp.getSource(), options, state);
    FailureOr<Value> shapeBuffer =
        getBuffer(rewriter, value: reshapeOp.getShape(), options, state);
    if (failed(Result: srcBuffer) || failed(Result: shapeBuffer))
      return failure();
    auto maybeResultMemRefType =
        bufferization::getBufferType(value: reshapeOp.getResult(), options, state);
    if (failed(Result: maybeResultMemRefType))
      return failure();

    // memref.reshape requires the source buffer to have an identity layout.
    // If the source memref does not have an identity layout, copy the source
    // into a new buffer with an identity layout.
    auto srcType = llvm::dyn_cast<MemRefType>(Val: srcBuffer->getType());
    if (srcType && !srcType.getLayout().isIdentity()) {
      FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
          b&: rewriter, loc: op->getLoc(), shapedValue: reshapeOp.getSource(), options, state);
      if (failed(Result: tensorAlloc))
        return failure();
      // Default-constructed AffineMap yields an identity layout; keep the
      // original memory space.
      auto memrefType = MemRefType::get(
          shape: srcType.getShape(), elementType: srcType.getElementType(), map: AffineMap(),
          memorySpace: cast<BaseMemRefType>(Val: srcBuffer->getType()).getMemorySpace());
      srcBuffer = rewriter
                      .create<bufferization::ToBufferOp>(
                          location: op->getLoc(), args&: memrefType, args&: *tensorAlloc)
                      .getResult();
    }

    replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
        rewriter, op, args&: maybeResultMemRefType.value(), args&: *srcBuffer, args&: *shapeBuffer);
    return success();
  }

  /// The result buffer type has a static identity layout and inherits the
  /// source buffer's memory space.
  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(Val: op);
    assert(value == reshapeOp.getResult() && "unexpected value provided");
    auto maybeSourceBufferType = bufferization::getBufferType(
        value: reshapeOp.getSource(), options, state, invocationStack);
    if (failed(Result: maybeSourceBufferType))
      return failure();
    return cast<BufferLikeType>(Val: getMemRefTypeWithStaticIdentityLayout(
        tensorType: reshapeOp.getResult().getType(),
        memorySpace: cast<BaseMemRefType>(Val&: maybeSourceBufferType.value()).getMemorySpace()));
  }
};
948
/// Analysis of ParallelInsertSliceOp.
struct ParallelInsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<
          ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
  /// The op has no results, so no operand aliases a value.
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  /// Only the 'source' operand's contents are read.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return opOperand == cast<ParallelInsertSliceOp>(Val: op).getSourceMutable();
  }

  /// Only the 'dest' operand's buffer is written.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(Val: op);
    return opOperand == parallelInsertSliceOp.getDestMutable();
  }

  /// Lower to a memref.subview of the destination buffer plus a memcpy from
  /// the source buffer, emitted before the parallel combining terminator.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(Val: op);
    ParallelCombiningOpInterface parallelCombiningParent =
        parallelInsertSliceOp.getParallelCombiningParent();

    // Bufferize the op outside of the parallel combining terminator.
    rewriter.setInsertionPoint(parallelCombiningParent);

    // Get source and destination buffers.
    FailureOr<Value> destBuffer =
        getBuffer(rewriter, value: parallelInsertSliceOp.getDest(), options, state);
    if (failed(Result: destBuffer))
      return failure();
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, value: parallelInsertSliceOp.getSource(), options, state);
    if (failed(Result: srcBuffer))
      return failure();

    // Take a subview of the destination buffer. The subview type is
    // rank-reduced to match the source's shape.
    auto destBufferType = cast<MemRefType>(Val: destBuffer->getType());
    MemRefType subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            resultShape: parallelInsertSliceOp.getSourceType().getShape(), sourceMemRefType: destBufferType,
            staticOffsets: parallelInsertSliceOp.getMixedOffsets(),
            staticSizes: parallelInsertSliceOp.getMixedSizes(),
            staticStrides: parallelInsertSliceOp.getMixedStrides());
    Value subview = rewriter.create<memref::SubViewOp>(
        location: parallelInsertSliceOp.getLoc(), args&: subviewMemRefType, args&: *destBuffer,
        args: parallelInsertSliceOp.getMixedOffsets(),
        args: parallelInsertSliceOp.getMixedSizes(),
        args: parallelInsertSliceOp.getMixedStrides());

    // This memcpy will fold away if everything bufferizes in-place.
    if (failed(Result: options.createMemCpy(b&: rewriter, loc: parallelInsertSliceOp.getLoc(),
                                   from: *srcBuffer, to: subview)))
      return failure();

    // In case the source was allocated in the same block, make sure that the
    // deallocation op (if any) appears after the memcpy. By default, deallocs
    // are placed before the terminator, but this does not work for ForallOp
    // because the terminator does more than just yielding a value.
    //
    // Note: This is not a problem for the destination buffer because these are
    // assumed to always bufferize in-place.
    for (Operation *user : srcBuffer->getUsers()) {
      if (hasEffect<MemoryEffects::Free>(op: user)) {
        if (user->getBlock() == parallelCombiningParent->getBlock())
          rewriter.moveOpBefore(op: user, existingOp: user->getBlock()->getTerminator());
        break;
      }
    }

    // Delete the op.
    rewriter.eraseOp(op);
    return success();
  }

  /// tensor.parallel_insert_slice op has implicit inplace behavior. We
  /// shouldn't create copy to resolve conflict.
  LogicalResult
  resolveConflicts(Operation *op, RewriterBase &rewriter,
                   const AnalysisState &analysisState,
                   const BufferizationState &bufferizationState) const {
    return success();
  }
};
1038
/// Bufferization of tensor.splat. Bufferizes to a new allocation that is filled
/// with a linalg.map. Similar to tensor.generate.
struct SplatOpInterface
    : public BufferizableOpInterface::ExternalModel<SplatOpInterface,
                                                    tensor::SplatOp> {

  /// tensor.splat always materializes its result in a new allocation.
  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  /// Lower to an allocation filled by a linalg.map whose body simply yields
  /// the splat value.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto splatOp = cast<tensor::SplatOp>(Val: op);

    // Allocate memory.
    Location loc = op->getLoc();
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        b&: rewriter, loc, shapedValue: splatOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(Result: tensorAlloc))
      return failure();

    // Create linalg::MapOp.
    auto tensorType = cast<RankedTensorType>(Val: tensorAlloc->getType());

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
      return op->emitError(message: "memory space not implemented yet");

    auto linalgOp =
        rewriter.create<linalg::MapOp>(location: loc, args&: tensorType, /*inputs=*/args: ValueRange(),
                                       /*init=*/args&: *tensorAlloc);
    Block &linalgBody = linalgOp.getMapper().emplaceBlock();

    // The map body takes no inputs and yields the splat value for every
    // element of the init tensor.
    rewriter.setInsertionPointToStart(&linalgBody);
    rewriter.create<linalg::YieldOp>(location: loc, args: splatOp.getInput());
    rewriter.replaceOp(op: splatOp, newValues: linalgOp.getResult()[0]);

    return success();
  }
};
1081
/// Bufferization of tensor.concat. Bufferizes to a new allocation that is
/// filled with copy ops. Similar to tensor.from_elements, but using memref.copy
/// on subviews instead of memref.store.
struct ConcatOpInterface
    : public BufferizableOpInterface::ExternalModel<ConcatOpInterface,
                                                    tensor::ConcatOp> {

  /// tensor.concat always materializes its result in a new allocation.
  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  /// No operand buffer is written in place; data is copied into the fresh
  /// allocation instead.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  /// Every input tensor is read (copied into the result buffer).
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  /// The result is a fresh allocation, so it does not alias any operand.
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  /// Allocate a result buffer, then copy each input into the matching subview
  /// along the concat dimension, accumulating the offset per operand.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto concatOp = cast<tensor::ConcatOp>(Val: op);

    // Allocate memory.
    Location loc = op->getLoc();
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        b&: rewriter, loc, shapedValue: concatOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(Result: tensorAlloc))
      return failure();
    auto tensorType = cast<RankedTensorType>(Val: tensorAlloc->getType());

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
      return op->emitError(message: "memory space not implemented yet");

    MemRefLayoutAttrInterface layout;
    MemRefType memrefType =
        MemRefType::get(shape: concatOp.getResultType().getShape(),
                        elementType: concatOp.getResultType().getElementType(), layout);
    Value dstBuffer = rewriter.create<bufferization::ToBufferOp>(
        location: op->getLoc(), args&: memrefType, args&: *tensorAlloc);

    // Extract the dimension for the concat op
    uint64_t concatDim = concatOp.getDim();
    bool dynamicConcatDim = false;

    // Every per-operand copy uses zero offsets and unit strides, except along
    // the concat dimension, which is updated in the loop below.
    SmallVector<OpFoldResult> offsets(tensorType.getRank(),
                                      rewriter.getIndexAttr(value: 0));
    SmallVector<OpFoldResult> strides(tensorType.getRank(),
                                      rewriter.getIndexAttr(value: 1));
    SmallVector<OpFoldResult> sizes;

    // Collect the result sizes, emitting memref.dim for dynamic dimensions
    // and noting whether the concat dimension itself is dynamic.
    for (const auto &[dimIdx, dimSize] :
         llvm::enumerate(First: tensorType.getShape())) {
      if (dimSize == ShapedType::kDynamic) {
        auto dimOp = rewriter.create<memref::DimOp>(location: loc, args&: dstBuffer, args&: dimIdx);
        sizes.push_back(Elt: dimOp.getResult());
        if (dimIdx == concatDim)
          dynamicConcatDim = true;
      } else {
        sizes.push_back(Elt: rewriter.getIndexAttr(value: dimSize));
      }
    }

    // Running offset along the concat dimension: a static int64_t when all
    // operand sizes are static, an SSA value otherwise.
    int64_t concatDimOffset = 0;
    std::optional<Value> dynamicOffset;
    std::optional<Value> dynamicSize;
    if (dynamicConcatDim) {
      // One or more operands have dynamic size, so we must accumulate the
      // offset with arith ops.
      dynamicOffset = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
    }

    for (auto operand : concatOp.getInputs()) {
      // Get the buffer for the operand.
      FailureOr<Value> srcBuffer = getBuffer(rewriter, value: operand, options, state);
      if (failed(Result: srcBuffer))
        return failure();

      // Each operand may have a different size along the concat dimension,
      // so the offset on that axis must accumulate through the loop, and the
      // size must change to the size of the current operand.
      auto operandTensorType = cast<RankedTensorType>(Val: operand.getType());
      int64_t operandConcatDimSize = operandTensorType.getDimSize(idx: concatDim);

      if (dynamicConcatDim) {
        offsets[concatDim] = dynamicOffset.value();
        dynamicSize = rewriter.create<memref::DimOp>(location: loc, args&: *srcBuffer, args&: concatDim)
                          .getResult();
        sizes[concatDim] = dynamicSize.value();
      } else {
        sizes[concatDim] = rewriter.getIndexAttr(value: operandConcatDimSize);
        offsets[concatDim] = rewriter.getIndexAttr(value: concatDimOffset);
      }

      // Create a subview of the destination buffer.
      auto dstMemrefType = cast<MemRefType>(Val&: memrefType);
      MemRefType subviewMemRefType =
          memref::SubViewOp::inferRankReducedResultType(
              resultShape: operandTensorType.getShape(), sourceMemRefType: dstMemrefType, staticOffsets: offsets, staticSizes: sizes,
              staticStrides: strides);
      Value subview = rewriter.create<memref::SubViewOp>(
          location: loc, args&: subviewMemRefType, args&: dstBuffer, args&: offsets, args&: sizes, args&: strides);

      // Copy the source buffer into the destination subview.
      if (failed(Result: options.createMemCpy(b&: rewriter, loc, from: *srcBuffer, to: subview)))
        return failure();

      // Advance the running offset by this operand's concat-dimension size.
      if (dynamicConcatDim) {
        dynamicOffset = rewriter.create<arith::AddIOp>(
            location: loc, args&: dynamicOffset.value(), args&: dynamicSize.value());
      } else {
        concatDimOffset += operandConcatDimSize;
      }
    }

    replaceOpWithBufferizedValues(rewriter, op, values: dstBuffer);
    return success();
  }
};
1210
1211} // namespace
1212} // namespace tensor
1213} // namespace mlir
1214
// Registers the BufferizableOpInterface external models defined above for the
// tensor dialect ops. The extension runs lazily when the tensor dialect is
// loaded into a context.
void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(extensionFn: +[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(context&: *ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(context&: *ctx);
    ConcatOp::attachInterface<ConcatOpInterface>(context&: *ctx);
    DimOp::attachInterface<DimOpInterface>(context&: *ctx);
    EmptyOp::attachInterface<EmptyOpInterface>(context&: *ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(context&: *ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(context&: *ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(context&: *ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(context&: *ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(context&: *ctx);
    InsertOp::attachInterface<InsertOpInterface>(context&: *ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(context&: *ctx);
    PadOp::attachInterface<PadOpInterface>(context&: *ctx);
    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
        context&: *ctx);
    RankOp::attachInterface<RankOpInterface>(context&: *ctx);
    ReshapeOp::attachInterface<ReshapeOpInterface>(context&: *ctx);
    SplatOp::attachInterface<SplatOpInterface>(context&: *ctx);

    // Load additional dialects of which ops may get created.
    ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
  });

  // Bufferization requires SubsetInsertionOpInterface models. Make sure that
  // they are registered.
  tensor::registerSubsetOpInterfaceExternalModels(registry);
}
1245

// source code of mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp