//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

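/// Bufferization of tensor.cast. Replace with memref.cast, or fold away if the
/// source buffer already has the requested buffer type.
///
/// Example (illustrative sketch; the exact result layout and memory space
/// depend on the bufferization options and on the source buffer):
/// ```
/// %0 = tensor.cast %t : tensor<?x16xf32> to tensor<4x16xf32>
/// ```
/// may bufferize to roughly:
/// ```
/// %0 = memref.cast %t_buf : memref<?x16xf32> to memref<4x16xf32>
/// ```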
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getResult(0), BufferRelation::Equivalent}};
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto castOp = cast<tensor::CastOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        castOp.getSource(), options, state, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    Attribute memorySpace = maybeSrcBufferType->getMemorySpace();

    // Note: `getMemRefTypeWithFullyDynamicLayout` returns an unranked memref
    // type in case the input is an unranked tensor type.

    // Case 1: Casting an unranked tensor.
    if (isa<UnrankedTensorType>(castOp.getSource().getType())) {
      // When casting to a ranked tensor, we cannot infer any static offset or
      // strides from the source. Assume fully dynamic.
      return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
    }

    // Case 2: Casting to an unranked tensor type.
    if (isa<UnrankedTensorType>(castOp.getType())) {
      return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
    }

    // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
    // change.
    auto rankedResultType = cast<RankedTensorType>(castOp.getType());
    return MemRefType::get(
        rankedResultType.getShape(), rankedResultType.getElementType(),
        llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        getBuffer(rewriter, castOp.getSource(), options, state);
    if (failed(resultBuffer))
      return failure();

    // Compute the new type.
    auto resultMemRefType =
        bufferization::getBufferType(castOp.getResult(), options, state);
    if (failed(resultMemRefType))
      return failure();
    if (resultBuffer->getType() == *resultMemRefType) {
      // This cast is a no-op.
      replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
      return success();
    }

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             *resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(
        rewriter, op, *resultMemRefType, *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
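///
/// Example (illustrative): when the source layout is collapsible,
/// ```
/// %0 = tensor.collapse_shape %t [[0, 1]]
///     : tensor<2x4xf32> into tensor<8xf32>
/// ```
/// turns into a memref.collapse_shape on the source buffer; otherwise the op
/// bufferizes to a new identity-layout allocation that is copied into first.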
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // tensor.collapse_shape may reallocate, at which point the source buffer
    // is copied. I.e., there will be a memory read side effect on the
    // bufferized source. This function conservatively returns "true" because
    // whether a copy will be created or not is not known at this point.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    // TODO: CollapseShapeOp may allocate at runtime.
    return {{op->getOpResult(0), BufferRelation::Equivalent}};
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        collapseShapeOp.getSrc(), options, state, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        srcBufferType, collapseShapeOp.getReassociationIndices());

    if (!canBeCollapsed) {
      // If dims cannot be collapsed, this op bufferizes to a new allocation.
      RankedTensorType tensorResultType = collapseShapeOp.getResultType();
      return bufferization::getMemRefTypeWithStaticIdentityLayout(
          tensorResultType, srcBufferType.getMemorySpace());
    }

    return memref::CollapseShapeOp::computeCollapsedType(
        srcBufferType, collapseShapeOp.getReassociationIndices());
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    FailureOr<Value> maybeBuffer =
        getBuffer(rewriter, collapseShapeOp.getSrc(), options, state);
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;
    auto bufferType = cast<MemRefType>(buffer.getType());

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(bufferType.getStridesAndOffset(strides, offset)))
          return failure();
        resultType = MemRefType::get(
            {}, tensorResultType.getElementType(),
            StridedLayoutAttr::get(op->getContext(), offset, {}),
            bufferType.getMemorySpace());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
      return success();
    }

    // If the dims are not collapsible (due to an incompatible source layout
    // map), force an out-of-place bufferization, i.e., a buffer copy. This
    // newly allocated buffer will have no layout map and thus be collapsible.
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
    if (!canBeCollapsed) {
      // TODO: Create alloc_tensor ops during TensorCopyInsertion.
      AnalysisState analysisState(options);
      FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), collapseShapeOp.getSrc(), options, state);
      if (failed(tensorAlloc))
        return failure();
      auto memrefType =
          MemRefType::get(collapseShapeOp.getSrcType().getShape(),
                          collapseShapeOp.getSrcType().getElementType(),
                          AffineMap(), bufferType.getMemorySpace());
      buffer = rewriter.create<bufferization::ToBufferOp>(
          op->getLoc(), memrefType, *tensorAlloc);
    }

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
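///
/// Example (illustrative): `tensor.dim %t, %c0 : tensor<?x8xf32>` becomes
/// `memref.dim %t_buf, %c0` on whatever buffer %t bufferized to.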
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // The op reads the tensor's metadata but not its contents.
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options, state);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
                                                dimOp.getIndex());
    return success();
  }
};

/// Bufferization of "tensor.empty". Replace with "bufferization.alloc_tensor".
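///
/// Example (illustrative):
/// ```
/// %0 = tensor.empty(%sz) : tensor<?x32xf32>
/// ```
/// becomes
/// ```
/// %0 = bufferization.alloc_tensor(%sz) : tensor<?x32xf32>
/// ```
/// (unless the result has no uses, in which case the op is simply erased).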
struct EmptyOpInterface
    : public BufferizableOpInterface::ExternalModel<EmptyOpInterface,
                                                    tensor::EmptyOp> {
  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  bool resultBufferizesToMemoryWrite(Operation *op, OpResult opResult,
                                     const AnalysisState &state) const {
    // The returned tensor does not have specified contents.
    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto emptyOp = cast<tensor::EmptyOp>(op);

    // Optimization: Fold away the op if it has no uses.
    if (op->getUses().empty()) {
      rewriter.eraseOp(op);
      return success();
    }

    // Allocate a tensor. This emits a "bufferization.alloc_tensor" op.
    FailureOr<Value> allocTensor = allocateTensorForShapedValue(
        rewriter, op->getLoc(), emptyOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(allocTensor))
      return failure();
    rewriter.replaceOp(op, *allocTensor);
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
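///
/// Example (illustrative):
/// ```
/// %0 = tensor.expand_shape %t [[0, 1]] output_shape [2, 4]
///     : tensor<8xf32> into tensor<2x4xf32>
/// ```
/// bufferizes to the corresponding memref.expand_shape on the source buffer;
/// unlike tensor.collapse_shape, no copy is ever needed.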
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // In contrast to tensor.collapse_shape, this op can always be bufferized
    // without a copy.
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getOpResult(0), BufferRelation::Equivalent}};
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        expandShapeOp.getSrc(), options, state, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
    auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
        srcBufferType, expandShapeOp.getResultType().getShape(),
        expandShapeOp.getReassociationIndices());
    if (failed(maybeResultType))
      return failure();
    return *maybeResultType;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    FailureOr<Value> buffer =
        getBuffer(rewriter, expandShapeOp.getSrc(), options, state);
    if (failed(buffer))
      return failure();

    auto memrefExpandShape = rewriter.create<memref::ExpandShapeOp>(
        op->getLoc(), tensorResultType.getShape(), *buffer,
        expandShapeOp.getReassociationIndices(),
        expandShapeOp.getMixedOutputShape());
    replaceOpWithBufferizedValues(rewriter, op,
                                  memrefExpandShape->getResults());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
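///
/// Example (illustrative; the exact strided layout of the result is computed
/// by memref::SubViewOp::inferRankReducedResultType):
/// ```
/// %0 = tensor.extract_slice %t[4] [8] [1] : tensor<16xf32> to tensor<8xf32>
/// ```
/// bufferizes to roughly
/// ```
/// %0 = memref.subview %t_buf[4] [8] [1]
///     : memref<16xf32> to memref<8xf32, strided<[1], offset: 4>>
/// ```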
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getOpResult(0), BufferRelation::Unknown}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    Location loc = extractSliceOp.getLoc();

    // Get source buffer.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractSliceOp.getSource(), options, state);
    if (failed(srcMemref))
      return failure();

    // Take a subview of the source buffer.
    auto resultMemrefType = bufferization::getBufferType(
        extractSliceOp.getResult(), options, state);
    if (failed(resultMemrefType))
      return failure();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref,
        mixedOffsets, mixedSizes, mixedStrides);

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    assert(value == extractSliceOp.getResult() && "invalid value");
    auto srcMemrefType = bufferization::getBufferType(
        extractSliceOp.getSource(), options, state, invocationStack);
    if (failed(srcMemrefType))
      return failure();
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    return memref::SubViewOp::inferRankReducedResultType(
        extractSliceOp.getType().getShape(),
        llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
        mixedStrides);
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractOp.getTensor(), options, state);
    if (failed(srcMemref))
      return failure();
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
                                                 extractOp.getIndices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
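///
/// Example (illustrative): the op is replaced by an allocation plus one
/// memref.store per element, e.g.
/// ```
/// %0 = tensor.from_elements %a, %b : tensor<2xf32>
/// ```
/// becomes roughly
/// ```
/// memref.store %a, %alloc[%c0] : memref<2xf32>
/// memref.store %b, %alloc[%c1] : memref<2xf32>
/// ```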
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {

  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);
    auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto shape = tensorType.getShape();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, fromElementsOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    FailureOr<BaseMemRefType> memrefType =
        bufferization::getBufferType(*tensorAlloc, options, state);
    if (failed(memrefType))
      return failure();
    Value buffer = rewriter.create<bufferization::ToBufferOp>(
        op->getLoc(), *memrefType, *tensorAlloc);

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.getElements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(
          loc, fromElementsOp.getElements().front(), buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *llvm::max_element(shape);
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.getElements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);

    return success();
  }
};

/// Lower the body of a tensor.generate like op (one index-typed bbArg per dim).
/// Such ops are lowered to linalg.map with the given tensor as a destination.
///
/// Example:
/// ```
/// %r = tensor.generate %x, %y {
///   ^bb0(%arg0: index, %arg1: index):
///   %0 = "some_op"(%arg0, %arg1) : (index, index) -> (index)
///   tensor.yield %0 : index
/// } : tensor<?x?xindex>
/// ```
///
/// Is lowered to:
/// ```
/// linalg.map ins() outs(%dest) {
///   %d0 = linalg.index 0 : index
///   %d1 = linalg.index 1 : index
///   %0 = "some_op"(%d0, %d1) : (index, index) -> (index)
///   linalg.yield %0 : index
/// }
/// ```
static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
                                     Value tensorDestination,
                                     ValueRange dynamicSizes,
                                     Region &generateBody) {
  assert(generateBody.hasOneBlock() && "expected body with single block");
  auto tensorType = cast<RankedTensorType>(tensorDestination.getType());
  assert(generateBody.getNumArguments() == tensorType.getRank() &&
         "rank mismatch");

  // Create linalg::MapOp.
  OpBuilder::InsertionGuard g(rewriter);
  auto linalgOp =
      rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
                                     /*init=*/tensorDestination);
  Block &linalgBody = linalgOp.getMapper().emplaceBlock();

  // Create linalg::IndexOps.
  rewriter.setInsertionPointToStart(&linalgBody);
  SmallVector<Value> indices;
  for (int64_t dim = 0; dim < tensorType.getRank(); ++dim)
    indices.push_back(rewriter.create<linalg::IndexOp>(loc, dim));

  // Move over body.
  rewriter.mergeBlocks(&generateBody.front(), &linalgBody, indices);
  auto yieldOp = cast<tensor::YieldOp>(linalgBody.getTerminator());
  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());

  return linalgOp.getResult()[0];
}

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {

  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    auto type = generateOp.getResult().getType();

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpaceFn(type) != Attribute())
      return op->emitError("memory space not implemented yet");

    // Allocate memory.
    Location loc = op->getLoc();
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, generateOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();

    Value result = lowerGenerateLikeOpBody(rewriter, loc, *tensorAlloc,
                                           generateOp.getDynamicExtents(),
                                           generateOp.getBody());
    rewriter.replaceOp(generateOp, result);

    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
///
/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
/// implementations for DestinationStyle ops.
struct InsertOpInterface
    : public DstBufferizableOpInterfaceExternalModel<InsertOpInterface,
                                                     tensor::InsertOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        getBuffer(rewriter, insertOp.getDest(), options, state);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
                                     *destMemref, insertOp.getIndices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }
};

template <typename InsertOpTy>
static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
                                      OpOperand &opOperand) {
  // The source is always read.
  if (opOperand == insertSliceOp.getSourceMutable())
    return true;

  // For the destination, it depends...
  assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");

  // Dest is not read if it is entirely overwritten. E.g.:
  // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
  bool allOffsetsZero =
      llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroInteger);
  RankedTensorType destType = insertSliceOp.getDestType();
  bool sizesMatchDestSizes =
      areConstantIntValues(insertSliceOp.getMixedSizes(), destType.getShape());
  bool allStridesOne =
      areAllConstantIntValue(insertSliceOp.getMixedStrides(), 1);
  return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
///
/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
/// implementations for DestinationStyle ops.
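///
/// Example (illustrative):
/// ```
/// %0 = tensor.insert_slice %src into %dst[4] [8] [1]
///     : tensor<8xf32> into tensor<16xf32>
/// ```
/// bufferizes to a subview of the destination buffer plus a copy (emitted via
/// BufferizationOptions::createMemCpy), roughly:
/// ```
/// %sv = memref.subview %dst_buf[4] [8] [1]
///     : memref<16xf32> to memref<8xf32, strided<[1], offset: 4>>
/// memref.copy %src_buf, %sv
///     : memref<8xf32> to memref<8xf32, strided<[1], offset: 4>>
/// ```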
struct InsertSliceOpInterface
    : public DstBufferizableOpInterfaceExternalModel<InsertSliceOpInterface,
                                                     tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return insertSliceOpRequiresRead(cast<tensor::InsertSliceOp>(op),
                                     opOperand);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    Location loc = insertSliceOp.getLoc();

    // Get destination buffer.
    FailureOr<Value> dstMemref =
        getBuffer(rewriter, insertSliceOp.getDest(), options, state);
    if (failed(dstMemref))
      return failure();

    // Take a subview of the destination buffer.
    auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
    MemRefType subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getShape(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides);
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, insertSliceOp.getSource(), options, state);
    if (failed(srcMemref))
      return failure();
    if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.pad. Replace with bufferization.alloc_tensor +
/// linalg.map + insert_slice.
/// For best performance, vectorize before bufferization (better performance in
/// case of padding with a constant).
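///
/// Example (illustrative):
/// ```
/// %0 = tensor.pad %t low[1] high[1] {
/// ^bb0(%i: index):
///   tensor.yield %cst : f32
/// } : tensor<8xf32> to tensor<10xf32>
/// ```
/// is rewritten into an allocation of the padded shape, a linalg.map that
/// evaluates the pad body for every element, and a tensor.insert_slice that
/// places the source at offset [1] inside the filled tensor.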
struct PadOpInterface
    : public BufferizableOpInterface::ExternalModel<PadOpInterface,
                                                    tensor::PadOp> {
  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    // Infer memory space from the source tensor.
    auto padOp = cast<tensor::PadOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        padOp.getSource(), options, state, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    MemRefLayoutAttrInterface layout;
    return MemRefType::get(padOp.getResultType().getShape(),
                           padOp.getResultType().getElementType(), layout,
                           maybeSrcBufferType->getMemorySpace());
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto padOp = cast<tensor::PadOp>(op);
    Location loc = padOp.getLoc();
    RankedTensorType resultType = padOp.getResultType();
    RankedTensorType srcType = padOp.getSourceType();

    auto toValue = [&](OpFoldResult ofr) {
      if (auto value = dyn_cast<Value>(ofr))
        return value;
      return rewriter
          .create<arith::ConstantIndexOp>(loc, *getConstantIntValue(ofr))
          .getResult();
    };

    // Compute dynamic result dimensions.
    SmallVector<OpFoldResult> mixedLowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> mixedHighPad = padOp.getMixedHighPad();
    SmallVector<Value> dynamicSizes;
    for (int64_t i = 0; i < resultType.getRank(); ++i) {
      if (!resultType.isDynamicDim(i))
        continue;
      Value srcDim = rewriter.create<tensor::DimOp>(loc, padOp.getSource(), i);
      Value lowPad = toValue(mixedLowPad[i]);
      Value highPad = toValue(mixedHighPad[i]);
      AffineExpr s0, s1, s2;
      bindSymbols(op->getContext(), s0, s1, s2);
      AffineExpr sumExpr = s0 + s1 + s2;
      Value sum = rewriter.create<affine::AffineApplyOp>(
          loc, sumExpr, ValueRange{srcDim, lowPad, highPad});
      dynamicSizes.push_back(sum);
    }

    // Allocate a buffer for the padded result.
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, padOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();

    // tensor::PadOp is like tensor::GenerateOp: The only difference is that
    // only a part of the generated tensor is needed. For simplicity, we reuse
    // the same functionality here.
    Value filledBuffer = lowerGenerateLikeOpBody(
        rewriter, loc, *tensorAlloc, dynamicSizes, padOp.getBodyRegion());

    // Create tensor::InsertSliceOp.
    SmallVector<OpFoldResult> sliceSizes =
        getMixedSizes(rewriter, loc, padOp.getSource());
    SmallVector<OpFoldResult> sliceStrides(srcType.getRank(),
                                           rewriter.getIndexAttr(1));
    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        padOp, padOp.getSource(), filledBuffer,
        /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);

    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // The op reads the tensor's metadata but not its contents.
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    FailureOr<Value> v =
        getBuffer(rewriter, rankOp.getTensor(), options, state);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 *v);
    return success();
  }
};

/// Bufferization of tensor.reshape. Replace with memref.reshape.
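///
/// Example (illustrative): memref.reshape requires an identity-layout source
/// buffer, so a copy is inserted first if needed.
/// ```
/// %0 = tensor.reshape %t(%shape)
///     : (tensor<8xf32>, tensor<2xindex>) -> tensor<2x4xf32>
/// ```
/// bufferizes to the corresponding memref.reshape on the source buffer.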
struct ReshapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
                                                    tensor::ReshapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Depending on the layout map, the source buffer may have to be copied.
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    return opOperand == reshapeOp.getShapeMutable();
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    // Only the 'source' operand aliases the result.
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    if (reshapeOp.getSourceMutable() != opOperand)
      return {};
    return {{op->getOpResult(0), BufferRelation::Equivalent}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, reshapeOp.getSource(), options, state);
    FailureOr<Value> shapeBuffer =
        getBuffer(rewriter, reshapeOp.getShape(), options, state);
    if (failed(srcBuffer) || failed(shapeBuffer))
      return failure();
    auto maybeResultMemRefType =
        bufferization::getBufferType(reshapeOp.getResult(), options, state);
    if (failed(maybeResultMemRefType))
      return failure();

    // memref.reshape requires the source buffer to have an identity layout.
    // If the source memref does not have an identity layout, copy the source
    // into a new buffer with an identity layout.
    auto srcType = llvm::dyn_cast<MemRefType>(srcBuffer->getType());
    if (srcType && !srcType.getLayout().isIdentity()) {
      FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), reshapeOp.getSource(), options, state);
      if (failed(tensorAlloc))
        return failure();
      auto memrefType = MemRefType::get(
          srcType.getShape(), srcType.getElementType(), AffineMap(),
          cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace());
      srcBuffer = rewriter
                      .create<bufferization::ToBufferOp>(
                          op->getLoc(), memrefType, *tensorAlloc)
                      .getResult();
    }

    replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
        rewriter, op, maybeResultMemRefType.value(), *srcBuffer, *shapeBuffer);
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    assert(value == reshapeOp.getResult() && "unexpected value provided");
    auto maybeSourceBufferType = bufferization::getBufferType(
        reshapeOp.getSource(), options, state, invocationStack);
    if (failed(maybeSourceBufferType))
      return failure();
    return getMemRefTypeWithStaticIdentityLayout(
        reshapeOp.getResult().getType(),
        cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace());
  }
};

/// Analysis of ParallelInsertSliceOp.
struct ParallelInsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<
          ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return opOperand == cast<ParallelInsertSliceOp>(op).getSourceMutable();
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
    return opOperand == parallelInsertSliceOp.getDestMutable();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
    ParallelCombiningOpInterface parallelCombiningParent =
        parallelInsertSliceOp.getParallelCombiningParent();

    // Bufferize the op outside of the parallel combining terminator.
    rewriter.setInsertionPoint(parallelCombiningParent);

    // Get source and destination buffers.
    FailureOr<Value> destBuffer =
        getBuffer(rewriter, parallelInsertSliceOp.getDest(), options, state);
    if (failed(destBuffer))
      return failure();
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, parallelInsertSliceOp.getSource(), options, state);
    if (failed(srcBuffer))
      return failure();

    // Take a subview of the destination buffer.
    auto destBufferType = cast<MemRefType>(destBuffer->getType());
    MemRefType subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
            parallelInsertSliceOp.getMixedOffsets(),
            parallelInsertSliceOp.getMixedSizes(),
            parallelInsertSliceOp.getMixedStrides());
    Value subview = rewriter.create<memref::SubViewOp>(
        parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
        parallelInsertSliceOp.getMixedOffsets(),
        parallelInsertSliceOp.getMixedSizes(),
        parallelInsertSliceOp.getMixedStrides());

    // This memcpy will fold away if everything bufferizes in-place.
    if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
                                    *srcBuffer, subview)))
      return failure();

    // In case the source was allocated in the same block, make sure that the
    // deallocation op (if any) appears after the memcpy. By default, deallocs
    // are placed before the terminator, but this does not work for ForallOp
    // because the terminator does more than just yielding a value.
    //
    // Note: This is not a problem for the destination buffer because these are
    // assumed to always bufferize in-place.
    for (Operation *user : srcBuffer->getUsers()) {
      if (hasEffect<MemoryEffects::Free>(user)) {
        if (user->getBlock() == parallelCombiningParent->getBlock())
          rewriter.moveOpBefore(user, user->getBlock()->getTerminator());
        break;
      }
    }

    // Delete the op.
    rewriter.eraseOp(op);
    return success();
  }

  /// tensor.parallel_insert_slice has implicit in-place behavior. We should
  /// not create a copy to resolve a conflict.
  LogicalResult
  resolveConflicts(Operation *op, RewriterBase &rewriter,
                   const AnalysisState &analysisState,
                   const BufferizationState &bufferizationState) const {
    return success();
  }
};

/// Bufferization of tensor.splat. Bufferizes to a new allocation that is
/// filled with a linalg.map. Similar to tensor.generate.
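///
/// Example (illustrative):
/// ```
/// %0 = tensor.splat %v : tensor<4x8xf32>
/// ```
/// becomes an allocation whose linalg.map body simply yields %v for every
/// element.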
struct SplatOpInterface
    : public BufferizableOpInterface::ExternalModel<SplatOpInterface,
                                                    tensor::SplatOp> {

  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto splatOp = cast<tensor::SplatOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, splatOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();

    // Create linalg::MapOp.
    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
      return op->emitError("memory space not implemented yet");

    auto linalgOp =
        rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
                                       /*init=*/*tensorAlloc);
    Block &linalgBody = linalgOp.getMapper().emplaceBlock();

    // Fill the body: yield the splat value for every element.
    rewriter.setInsertionPointToStart(&linalgBody);
    rewriter.create<linalg::YieldOp>(loc, splatOp.getInput());
    rewriter.replaceOp(splatOp, linalgOp.getResult()[0]);

    return success();
  }
};

/// Bufferization of tensor.concat. Bufferizes to a new allocation that is
/// filled with copy ops. Similar to tensor.from_elements, but using
/// memref.copy on subviews instead of memref.store.
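///
/// Example (illustrative):
/// ```
/// %0 = tensor.concat dim(0) %a, %b
///     : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
/// ```
/// allocates a 5x4 buffer and copies %a and %b into subviews at offsets
/// [0, 0] and [2, 0] along the concatenation dimension.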
struct ConcatOpInterface
    : public BufferizableOpInterface::ExternalModel<ConcatOpInterface,
                                                    tensor::ConcatOp> {

  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto concatOp = cast<tensor::ConcatOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, concatOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
      return op->emitError("memory space not implemented yet");

    MemRefLayoutAttrInterface layout;
    MemRefType memrefType =
        MemRefType::get(concatOp.getResultType().getShape(),
                        concatOp.getResultType().getElementType(), layout);
    Value dstBuffer = rewriter.create<bufferization::ToBufferOp>(
        op->getLoc(), memrefType, *tensorAlloc);

    // Extract the dimension for the concat op.
    uint64_t concatDim = concatOp.getDim();
    bool dynamicConcatDim = false;

    SmallVector<OpFoldResult> offsets(tensorType.getRank(),
                                      rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(tensorType.getRank(),
                                      rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> sizes;

    for (const auto &[dimIdx, dimSize] :
         llvm::enumerate(tensorType.getShape())) {
      if (dimSize == ShapedType::kDynamic) {
        auto dimOp = rewriter.create<memref::DimOp>(loc, dstBuffer, dimIdx);
        sizes.push_back(dimOp.getResult());
        if (dimIdx == concatDim)
          dynamicConcatDim = true;
      } else {
        sizes.push_back(rewriter.getIndexAttr(dimSize));
      }
    }

    int64_t concatDimOffset = 0;
    std::optional<Value> dynamicOffset;
    std::optional<Value> dynamicSize;
    if (dynamicConcatDim) {
      // One or more operands have dynamic size, so we must accumulate the
      // offset with arith ops.
      dynamicOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    }

    for (auto operand : concatOp.getInputs()) {
      // Get the buffer for the operand.
      FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options, state);
      if (failed(srcBuffer))
        return failure();

      // Each operand may have a different size along the concat dimension,
      // so the offset on that axis must accumulate through the loop, and the
      // size must change to the size of the current operand.
      auto operandTensorType = cast<RankedTensorType>(operand.getType());
      int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);

      if (dynamicConcatDim) {
        offsets[concatDim] = dynamicOffset.value();
        dynamicSize = rewriter.create<memref::DimOp>(loc, *srcBuffer, concatDim)
                          .getResult();
        sizes[concatDim] = dynamicSize.value();
      } else {
        sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
        offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
      }

      // Create a subview of the destination buffer.
      auto dstMemrefType = cast<MemRefType>(memrefType);
      MemRefType subviewMemRefType =
          memref::SubViewOp::inferRankReducedResultType(
              operandTensorType.getShape(), dstMemrefType, offsets, sizes,
              strides);
      Value subview = rewriter.create<memref::SubViewOp>(
          loc, subviewMemRefType, dstBuffer, offsets, sizes, strides);

      // Copy the source buffer into the destination subview.
      if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
        return failure();

      if (dynamicConcatDim) {
        dynamicOffset = rewriter.create<arith::AddIOp>(
            loc, dynamicOffset.value(), dynamicSize.value());
      } else {
        concatDimOffset += operandConcatDimSize;
      }
    }

    replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

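// Typical usage (illustrative sketch, not prescribed by this file): clients
// populate a DialectRegistry before running one-shot bufferization, e.g.
//   DialectRegistry registry;
//   tensor::registerBufferizableOpInterfaceExternalModels(registry);
//   context.appendDialectRegistry(registry);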
void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    ConcatOp::attachInterface<ConcatOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    PadOp::attachInterface<PadOpInterface>(*ctx);
    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
        *ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
    ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
    SplatOp::attachInterface<SplatOpInterface>(*ctx);

    // Load additional dialects of which ops may get created.
    ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
  });

  // Bufferization requires SubsetInsertionOpInterface models. Make sure that
  // they are registered.
  tensor::registerSubsetOpInterfaceExternalModels(registry);
}