1//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
10
11#include "mlir/Dialect/Affine/IR/AffineOps.h"
12#include "mlir/Dialect/Arith/IR/Arith.h"
13#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
14#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
15#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
16#include "mlir/Dialect/Linalg/IR/Linalg.h"
17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Dialect/SCF/IR/SCF.h"
19#include "mlir/Dialect/Tensor/IR/Tensor.h"
20#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
21#include "mlir/Dialect/Utils/StaticValueUtils.h"
22#include "mlir/IR/Dialect.h"
23#include "mlir/IR/Operation.h"
24
25using namespace mlir;
26using namespace mlir::bufferization;
27using namespace mlir::tensor;
28
29namespace mlir {
30namespace tensor {
31namespace {
32
33struct CastOpInterface
34 : public BufferizableOpInterface::ExternalModel<CastOpInterface,
35 tensor::CastOp> {
36 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
37 const AnalysisState &state) const {
38 return false;
39 }
40
41 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
42 const AnalysisState &state) const {
43 return false;
44 }
45
46 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
47 const AnalysisState &state) const {
48 return {{op->getResult(idx: 0), BufferRelation::Equivalent}};
49 }
50
51 FailureOr<BaseMemRefType>
52 getBufferType(Operation *op, Value value, const BufferizationOptions &options,
53 SmallVector<Value> &invocationStack) const {
54 auto castOp = cast<tensor::CastOp>(op);
55 auto maybeSrcBufferType = bufferization::getBufferType(
56 value: castOp.getSource(), options, invocationStack);
57 if (failed(maybeSrcBufferType))
58 return failure();
59 Attribute memorySpace = maybeSrcBufferType->getMemorySpace();
60
61 // Note: `getMemRefTypeWithFullyDynamicLayout` returns an unranked memref
62 // type in case the input is an unranked tensor type.
63
64 // Case 1: Casting an unranked tensor
65 if (isa<UnrankedTensorType>(castOp.getSource().getType())) {
66 // When casting to a ranked tensor, we cannot infer any static offset or
67 // strides from the source. Assume fully dynamic.
68 return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
69 }
70
71 // Case 2: Casting to an unranked tensor type
72 if (isa<UnrankedTensorType>(castOp.getType())) {
73 return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
74 }
75
76 // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
77 // change.
78 auto rankedResultType = cast<RankedTensorType>(castOp.getType());
79 return MemRefType::get(
80 rankedResultType.getShape(), rankedResultType.getElementType(),
81 llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace);
82 }
83
84 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
85 const BufferizationOptions &options) const {
86 auto castOp = cast<tensor::CastOp>(op);
87
88 // The result buffer still has the old (pre-cast) type.
89 FailureOr<Value> resultBuffer =
90 getBuffer(rewriter, castOp.getSource(), options);
91 if (failed(result: resultBuffer))
92 return failure();
93
94 // Compute the new type.
95 auto resultMemRefType =
96 bufferization::getBufferType(value: castOp.getResult(), options);
97 if (failed(resultMemRefType))
98 return failure();
99 if (resultBuffer->getType() == *resultMemRefType) {
100 // This cast is a no-op.
101 replaceOpWithBufferizedValues(rewriter, op, values: *resultBuffer);
102 return success();
103 }
104
105 // Replace the op with a memref.cast.
106 assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
107 *resultMemRefType) &&
108 "CallOp::bufferize: cast incompatible");
109 replaceOpWithNewBufferizedOp<memref::CastOp>(
110 rewriter, op, *resultMemRefType, *resultBuffer);
111
112 return success();
113 }
114};
115
116/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
117struct CollapseShapeOpInterface
118 : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
119 tensor::CollapseShapeOp> {
120 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
121 const AnalysisState &state) const {
122 // tensor.collapse_shape may reallocate, at which point the source buffer is
123 // copied. I.e., there will be a memory read side effect on the bufferized
124 // source. This function conservatively returns "true" because whether a
125 // copy will be created or not is not known at this point.
126 return true;
127 }
128
129 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
130 const AnalysisState &state) const {
131 return false;
132 }
133
134 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
135 const AnalysisState &state) const {
136 // TODO: CollapseShapeOp may allocate at runtime.
137 return {{op->getOpResult(idx: 0), BufferRelation::Equivalent}};
138 }
139
140 FailureOr<BaseMemRefType>
141 getBufferType(Operation *op, Value value, const BufferizationOptions &options,
142 SmallVector<Value> &invocationStack) const {
143 auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
144 auto maybeSrcBufferType = bufferization::getBufferType(
145 value: collapseShapeOp.getSrc(), options, invocationStack);
146 if (failed(maybeSrcBufferType))
147 return failure();
148 auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
149 bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
150 srcBufferType, collapseShapeOp.getReassociationIndices());
151
152 if (!canBeCollapsed) {
153 // If dims cannot be collapsed, this op bufferizes to a new allocation.
154 RankedTensorType tensorResultType = collapseShapeOp.getResultType();
155 return bufferization::getMemRefTypeWithStaticIdentityLayout(
156 tensorType: tensorResultType, memorySpace: srcBufferType.getMemorySpace());
157 }
158
159 return memref::CollapseShapeOp::computeCollapsedType(
160 srcBufferType, collapseShapeOp.getReassociationIndices());
161 }
162
163 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
164 const BufferizationOptions &options) const {
165 auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
166 RankedTensorType tensorResultType = collapseShapeOp.getResultType();
167 FailureOr<Value> maybeBuffer =
168 getBuffer(rewriter, collapseShapeOp.getSrc(), options);
169 if (failed(result: maybeBuffer))
170 return failure();
171 Value buffer = *maybeBuffer;
172 auto bufferType = cast<MemRefType>(buffer.getType());
173
174 if (tensorResultType.getRank() == 0) {
175 // 0-d collapses must go through a different op builder.
176 MemRefType resultType;
177
178 if (bufferType.getLayout().isIdentity()) {
179 // Standard layout: result type has no offset.
180 MemRefLayoutAttrInterface layout;
181 resultType = MemRefType::get({}, tensorResultType.getElementType(),
182 layout, bufferType.getMemorySpace());
183 } else {
184 // Source memref has a layout map: result type has the same offset as
185 // the source type.
186 SmallVector<int64_t> strides;
187 int64_t offset;
188 if (failed(getStridesAndOffset(bufferType, strides, offset)))
189 return failure();
190 resultType = MemRefType::get(
191 {}, tensorResultType.getElementType(),
192 StridedLayoutAttr::get(op->getContext(), offset, {}),
193 bufferType.getMemorySpace());
194 }
195
196 replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
197 rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
198 return success();
199 }
200
201 // If the dims are not collapsible (due to an incompatible source layout
202 // map), force an out-of-place bufferization, i.e., a buffer copy. This
203 // newly allocated buffer will have no layout map and thus be collapsible.
204 bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
205 bufferType, collapseShapeOp.getReassociationIndices());
206 if (!canBeCollapsed) {
207 // TODO: Create alloc_tensor ops during TensorCopyInsertion.
208 AnalysisState analysisState(options);
209 FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
210 rewriter, op->getLoc(), collapseShapeOp.getSrc(), options);
211 if (failed(result: tensorAlloc))
212 return failure();
213 auto memrefType =
214 MemRefType::get(collapseShapeOp.getSrcType().getShape(),
215 collapseShapeOp.getSrcType().getElementType(),
216 AffineMap(), bufferType.getMemorySpace());
217 buffer = rewriter.create<bufferization::ToMemrefOp>(
218 op->getLoc(), memrefType, *tensorAlloc);
219 }
220
221 // Result type is inferred by the builder.
222 replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
223 rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
224 return success();
225 }
226};
227
228/// Bufferization of tensor.dim. Replace with memref.dim.
229struct DimOpInterface
230 : public BufferizableOpInterface::ExternalModel<DimOpInterface,
231 tensor::DimOp> {
232 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
233 const AnalysisState &state) const {
234 // The op reads the tensor's metadata but not its contents.
235 return false;
236 }
237
238 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
239 const AnalysisState &state) const {
240 return false;
241 }
242
243 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
244 const AnalysisState &state) const {
245 return {};
246 }
247
248 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
249 const BufferizationOptions &options) const {
250 auto dimOp = cast<tensor::DimOp>(op);
251 FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options);
252 if (failed(result: v))
253 return failure();
254 replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
255 dimOp.getIndex());
256 return success();
257 }
258};
259
260/// Bufferization of "tensor.empty". Replace with "bufferization.alloc_tensor".
261struct EmptyOpInterface
262 : public BufferizableOpInterface::ExternalModel<EmptyOpInterface,
263 tensor::EmptyOp> {
264 bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
265
266 bool resultBufferizesToMemoryWrite(Operation *op, OpResult opResult,
267 const AnalysisState &state) const {
268 // The returned tensor does not have specified contents.
269 return false;
270 }
271
272 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
273 const BufferizationOptions &options) const {
274 auto emptyOp = cast<tensor::EmptyOp>(op);
275
276 // Optimization: Fold away the op if it has no uses.
277 if (op->getUses().empty()) {
278 rewriter.eraseOp(op);
279 return success();
280 }
281
282 // Allocate a tensor. This emits a "bufferization.alloc_tensor" op.
283 FailureOr<Value> allocTensor = allocateTensorForShapedValue(
284 rewriter, op->getLoc(), emptyOp.getResult(), options, /*copy=*/false);
285 if (failed(result: allocTensor))
286 return failure();
287 rewriter.replaceOp(op, newValues: *allocTensor);
288 return success();
289 }
290};
291
292/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
293struct ExpandShapeOpInterface
294 : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
295 tensor::ExpandShapeOp> {
296 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
297 const AnalysisState &state) const {
298 // In contrast to tensor.collapse_shape, this op can always be bufferized
299 // without a copy.
300 return false;
301 }
302
303 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
304 const AnalysisState &state) const {
305 return false;
306 }
307
308 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
309 const AnalysisState &state) const {
310 return {{op->getOpResult(idx: 0), BufferRelation::Equivalent}};
311 }
312
313 FailureOr<BaseMemRefType>
314 getBufferType(Operation *op, Value value, const BufferizationOptions &options,
315 SmallVector<Value> &invocationStack) const {
316 auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
317 auto maybeSrcBufferType = bufferization::getBufferType(
318 value: expandShapeOp.getSrc(), options, invocationStack);
319 if (failed(maybeSrcBufferType))
320 return failure();
321 auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
322 auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
323 srcBufferType, expandShapeOp.getResultType().getShape(),
324 expandShapeOp.getReassociationIndices());
325 if (failed(maybeResultType))
326 return failure();
327 return *maybeResultType;
328 }
329
330 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
331 const BufferizationOptions &options) const {
332 auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
333 auto tensorResultType = expandShapeOp.getResultType();
334 FailureOr<Value> buffer =
335 getBuffer(rewriter, expandShapeOp.getSrc(), options);
336 if (failed(result: buffer))
337 return failure();
338
339 // Memref result type is inferred by the builder based on reassociation
340 // indices and result shape.
341 replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
342 rewriter, op, tensorResultType.getShape(), *buffer,
343 expandShapeOp.getReassociationIndices());
344 return success();
345 }
346};
347
348/// Bufferization of tensor.extract_slice. Replace with memref.subview.
349struct ExtractSliceOpInterface
350 : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
351 tensor::ExtractSliceOp> {
352 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
353 const AnalysisState &state) const {
354 return false;
355 }
356
357 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
358 const AnalysisState &state) const {
359 return false;
360 }
361
362 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
363 const AnalysisState &state) const {
364 return {{op->getOpResult(idx: 0), BufferRelation::Unknown}};
365 }
366
367 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
368 const BufferizationOptions &options) const {
369 auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
370 SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
371 SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
372 SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
373 Location loc = extractSliceOp.getLoc();
374
375 // Get source buffer.
376 FailureOr<Value> srcMemref =
377 getBuffer(rewriter, extractSliceOp.getSource(), options);
378 if (failed(result: srcMemref))
379 return failure();
380
381 // Take a subview of the source buffer.
382 auto resultMemrefType =
383 bufferization::getBufferType(value: extractSliceOp.getResult(), options);
384 if (failed(resultMemrefType))
385 return failure();
386 Value subView = rewriter.create<memref::SubViewOp>(
387 loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref, mixedOffsets,
388 mixedSizes, mixedStrides);
389
390 replaceOpWithBufferizedValues(rewriter, op, values: subView);
391 return success();
392 }
393
394 FailureOr<BaseMemRefType>
395 getBufferType(Operation *op, Value value, const BufferizationOptions &options,
396 SmallVector<Value> &invocationStack) const {
397 auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
398 assert(value == extractSliceOp.getResult() && "invalid value");
399 auto srcMemrefType = bufferization::getBufferType(
400 value: extractSliceOp.getSource(), options, invocationStack);
401 if (failed(srcMemrefType))
402 return failure();
403 SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
404 SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
405 SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
406 return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
407 extractSliceOp.getType().getShape(), llvm::cast<MemRefType>(*srcMemrefType),
408 mixedOffsets, mixedSizes, mixedStrides));
409 }
410};
411
412/// Bufferization of tensor.extract. Replace with memref.load.
413struct ExtractOpInterface
414 : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
415 tensor::ExtractOp> {
416 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
417 const AnalysisState &state) const {
418 return true;
419 }
420
421 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
422 const AnalysisState &state) const {
423 return false;
424 }
425
426 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
427 const AnalysisState &state) const {
428 return {};
429 }
430
431 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
432 const BufferizationOptions &options) const {
433 auto extractOp = cast<tensor::ExtractOp>(op);
434 FailureOr<Value> srcMemref =
435 getBuffer(rewriter, extractOp.getTensor(), options);
436 if (failed(result: srcMemref))
437 return failure();
438 replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
439 extractOp.getIndices());
440 return success();
441 }
442};
443
444// Implements backtracking to traverse indices of the output buffer while
445// iterating over op.elements().
446static void createStores(RewriterBase &rewriter, Location loc, int dim,
447 Value buffer, ArrayRef<int64_t> shape,
448 ArrayRef<Value> constants,
449 OperandRange::iterator &elementIt,
450 SmallVectorImpl<Value> &indices) {
451 if (dim == static_cast<int>(shape.size()) - 1) {
452 for (int i = 0; i < shape.back(); ++i) {
453 indices.back() = constants[i];
454 rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
455 ++elementIt;
456 }
457 return;
458 }
459 for (int i = 0; i < shape[dim]; ++i) {
460 indices[dim] = constants[i];
461 createStores(rewriter, loc, dim: dim + 1, buffer, shape, constants, elementIt,
462 indices);
463 }
464}
465
466/// Bufferization of tensor.from_elements.
467struct FromElementsOpInterface
468 : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
469 tensor::FromElementsOp> {
470
471 bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
472
473 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
474 const BufferizationOptions &options) const {
475 auto fromElementsOp = cast<tensor::FromElementsOp>(op);
476 auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
477
478 // TODO: Implement memory space for this op.
479 if (options.defaultMemorySpaceFn(tensorType) != Attribute())
480 return op->emitError(message: "memory space not implemented yet");
481
482 // Allocate a buffer for the result.
483 Location loc = op->getLoc();
484 auto shape = tensorType.getShape();
485 // TODO: Create alloc_tensor ops during TensorCopyInsertion.
486 FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
487 rewriter, loc, fromElementsOp.getResult(), options,
488 /*copy=*/false);
489 if (failed(result: tensorAlloc))
490 return failure();
491 auto memrefType =
492 MemRefType::get(tensorType.getShape(), tensorType.getElementType());
493 Value buffer = rewriter.create<bufferization::ToMemrefOp>(
494 op->getLoc(), memrefType, *tensorAlloc);
495
496 // Case: tensor<0xelem_type>.
497 if (fromElementsOp.getElements().empty()) {
498 replaceOpWithBufferizedValues(rewriter, op, values: buffer);
499 return success();
500 }
501
502 // Case: tensor<elem_type>.
503 if (shape.empty()) {
504 rewriter.create<memref::StoreOp>(
505 loc, fromElementsOp.getElements().front(), buffer);
506 replaceOpWithBufferizedValues(rewriter, op, values: buffer);
507 return success();
508 }
509
510 // Create constants for the range of possible indices [0, max{shape_i}).
511 auto maxDim = *llvm::max_element(shape);
512 SmallVector<Value, 2> constants;
513 constants.reserve(N: maxDim);
514 for (int i = 0; i < maxDim; ++i)
515 constants.push_back(rewriter.create<arith::ConstantIndexOp>(location: loc, args&: i));
516
517 // Traverse all `elements` and create `memref.store` ops.
518 auto elementIt = fromElementsOp.getElements().begin();
519 SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
520 createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
521 indices);
522
523 replaceOpWithBufferizedValues(rewriter, op, values: buffer);
524
525 return success();
526 }
527};
528
529/// Lower the body of a tensor.generate like op (one index-typed bbArg per dim).
530/// Such ops are lowered to linalg.map with the given tensor as a destination.
531///
532/// Example:
533/// ```
534/// %r = tensor.generate %x, %y {
535/// ^bb0(%arg0: index, %arg1: index):
536/// %0 = "some_op"(%arg0, %arg1) : (index, index) -> (index)
537/// tensor.yield %0 : index
538/// } : tensor<?x?xindex>
539/// ```
540///
541/// Is lowered to:
542/// ```
543/// linalg.map ins() outs(%dest) {
544/// %d0 = linalg.index 0 : index
545/// %d1 = linalg.index 1 : index
546/// %0 = "some_op"(%d0, %d1) : (index, index) -> (index)
547/// linalg.yield %0 : index
548/// }
549/// ```
550static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
551 Value tensorDestination,
552 ValueRange dynamicSizes,
553 Region &generateBody) {
554 assert(generateBody.hasOneBlock() && "expected body with single block");
555 auto tensorType = cast<RankedTensorType>(tensorDestination.getType());
556 assert(generateBody.getNumArguments() == tensorType.getRank() &&
557 "rank mismatch");
558
559 // Create linalg::MapOp.
560 OpBuilder::InsertionGuard g(rewriter);
561 auto linalgOp =
562 rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
563 /*init=*/tensorDestination);
564 Block &linalgBody = linalgOp.getMapper().emplaceBlock();
565
566 // Create linalg::IndexOps.
567 rewriter.setInsertionPointToStart(&linalgBody);
568 SmallVector<Value> indices;
569 for (int64_t dim = 0; dim < tensorType.getRank(); ++dim)
570 indices.push_back(rewriter.create<linalg::IndexOp>(loc, dim));
571
572 // Move over body.
573 rewriter.mergeBlocks(source: &generateBody.front(), dest: &linalgBody, argValues: indices);
574 auto yieldOp = cast<tensor::YieldOp>(linalgBody.getTerminator());
575 rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
576
577 return linalgOp.getResult()[0];
578}
579
580/// Bufferization of tensor.generate.
581struct GenerateOpInterface
582 : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
583 tensor::GenerateOp> {
584
585 bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
586
587 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
588 const BufferizationOptions &options) const {
589 auto generateOp = cast<tensor::GenerateOp>(op);
590
591 auto type = generateOp.getResult().getType();
592
593 // TODO: Implement memory space for this op.
594 if (options.defaultMemorySpaceFn(type) != Attribute())
595 return op->emitError(message: "memory space not implemented yet");
596
597 // Allocate memory.
598 Location loc = op->getLoc();
599 FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
600 rewriter, loc, generateOp.getResult(), options,
601 /*copy=*/false);
602 if (failed(result: tensorAlloc))
603 return failure();
604
605 Value result = lowerGenerateLikeOpBody(rewriter, loc, *tensorAlloc,
606 generateOp.getDynamicExtents(),
607 generateOp.getBody());
608 rewriter.replaceOp(generateOp, result);
609
610 return success();
611 }
612};
613
614/// Bufferization of tensor.insert. Replace with memref.store.
615///
616/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
617/// implementations for DestinationStyle ops.
618struct InsertOpInterface
619 : public DstBufferizableOpInterfaceExternalModel<InsertOpInterface,
620 tensor::InsertOp> {
621 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
622 const BufferizationOptions &options) const {
623 auto insertOp = cast<tensor::InsertOp>(op);
624 FailureOr<Value> destMemref =
625 getBuffer(rewriter, insertOp.getDest(), options);
626 if (failed(result: destMemref))
627 return failure();
628 rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
629 *destMemref, insertOp.getIndices());
630 replaceOpWithBufferizedValues(rewriter, op, values: *destMemref);
631 return success();
632 }
633};
634
635/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
636/// certain circumstances, this op can also be a no-op.
637///
638/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
639/// implementations for DestinationStyle ops.
640struct InsertSliceOpInterface
641 : public DstBufferizableOpInterfaceExternalModel<InsertSliceOpInterface,
642 tensor::InsertSliceOp> {
643 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
644 const AnalysisState &state) const {
645 auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
646 RankedTensorType destType = insertSliceOp.getDestType();
647
648 // The source is always read.
649 if (opOperand == insertSliceOp.getSourceMutable())
650 return true;
651
652 // For the destination, it depends...
653 assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
654
655 // Dest is not read if it is entirely overwritten. E.g.:
656 // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
657 bool allOffsetsZero =
658 llvm::all_of(insertSliceOp.getMixedOffsets(), [](OpFoldResult ofr) {
659 return isConstantIntValue(ofr, value: 0);
660 });
661 bool sizesMatchDestSizes = llvm::all_of(
662 llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) {
663 return getConstantIntValue(it.value()) ==
664 destType.getDimSize(it.index());
665 });
666 bool allStridesOne =
667 llvm::all_of(insertSliceOp.getMixedStrides(), [](OpFoldResult ofr) {
668 return isConstantIntValue(ofr, value: 1);
669 });
670 return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
671 }
672
673 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
674 const BufferizationOptions &options) const {
675 // insert_slice ops arise from tiling and bufferizing them out-of-place is
676 // generally a deal breaker. When used with loops, this ends up cloning the
677 // whole tensor on every single iteration and is a symptom of a
678 // catastrophically bad scheduling decision.
679 // TODO: be very loud about it or even consider failing the pass.
680 auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
681 SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
682 SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
683 SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
684 Location loc = insertSliceOp.getLoc();
685
686 // Get destination buffer.
687 FailureOr<Value> dstMemref =
688 getBuffer(rewriter, insertSliceOp.getDest(), options);
689 if (failed(result: dstMemref))
690 return failure();
691
692 // Take a subview of the destination buffer.
693 auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
694 auto subviewMemRefType =
695 cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
696 insertSliceOp.getSourceType().getShape(), dstMemrefType,
697 mixedOffsets, mixedSizes, mixedStrides));
698 Value subView = rewriter.create<memref::SubViewOp>(
699 loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
700 mixedStrides);
701
702 // Copy tensor. If this tensor.insert_slice has a matching
703 // tensor.extract_slice, the copy operation will eventually fold away.
704 FailureOr<Value> srcMemref =
705 getBuffer(rewriter, insertSliceOp.getSource(), options);
706 if (failed(result: srcMemref))
707 return failure();
708 if (failed(result: options.createMemCpy(b&: rewriter, loc, from: *srcMemref, to: subView)))
709 return failure();
710
711 replaceOpWithBufferizedValues(rewriter, op, values: *dstMemref);
712 return success();
713 }
714};
715
716/// Bufferization of tensor.pad. Replace with bufferization.alloc_tensor +
717/// linalg.map + insert_slice.
718/// For best performance, vectorize before bufferization (better performance in
719/// case of padding with a constant).
720struct PadOpInterface
721 : public BufferizableOpInterface::ExternalModel<PadOpInterface,
722 tensor::PadOp> {
723 bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
724
725 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
726 const AnalysisState &state) const {
727 return true;
728 }
729
730 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
731 const AnalysisState &state) const {
732 return false;
733 }
734
735 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
736 const AnalysisState &state) const {
737 return {};
738 }
739
740 FailureOr<BaseMemRefType>
741 getBufferType(Operation *op, Value value, const BufferizationOptions &options,
742 SmallVector<Value> &invocationStack) const {
743 // Infer memory space from the source tensor.
744 auto padOp = cast<tensor::PadOp>(op);
745 auto maybeSrcBufferType = bufferization::getBufferType(
746 value: padOp.getSource(), options, invocationStack);
747 if (failed(maybeSrcBufferType))
748 return failure();
749 MemRefLayoutAttrInterface layout;
750 return MemRefType::get(padOp.getResultType().getShape(),
751 padOp.getResultType().getElementType(), layout,
752 maybeSrcBufferType->getMemorySpace());
753 }
754
755 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
756 const BufferizationOptions &options) const {
757 auto padOp = cast<tensor::PadOp>(op);
758 Location loc = padOp.getLoc();
759 RankedTensorType resultType = padOp.getResultType();
760 RankedTensorType srcType = padOp.getSourceType();
761
762 auto toValue = [&](OpFoldResult ofr) {
763 if (ofr.is<Value>())
764 return ofr.get<Value>();
765 return rewriter
766 .create<arith::ConstantIndexOp>(loc, *getConstantIntValue(ofr))
767 .getResult();
768 };
769
770 // Compute dynamic result dimensions.
771 SmallVector<OpFoldResult> mixedLowPad = padOp.getMixedLowPad();
772 SmallVector<OpFoldResult> mixedHighPad = padOp.getMixedHighPad();
773 SmallVector<Value> dynamicSizes;
774 for (int64_t i = 0; i < resultType.getRank(); ++i) {
775 if (!resultType.isDynamicDim(i))
776 continue;
777 Value srcDim = rewriter.create<tensor::DimOp>(loc, padOp.getSource(), i);
778 Value lowPad = toValue(mixedLowPad[i]);
779 Value highPad = toValue(mixedHighPad[i]);
780 AffineExpr s0, s1, s2;
781 bindSymbols(ctx: op->getContext(), exprs&: s0, exprs&: s1, exprs&: s2);
782 AffineExpr sumExpr = s0 + s1 + s2;
783 Value sum = rewriter.create<affine::AffineApplyOp>(
784 loc, sumExpr, ValueRange{srcDim, lowPad, highPad});
785 dynamicSizes.push_back(Elt: sum);
786 }
787
788 // Allocate a buffer for the padded result.
789 FailureOr<Value> tensorAlloc =
790 allocateTensorForShapedValue(rewriter, loc, padOp.getResult(), options,
791 /*copy=*/false);
792 if (failed(result: tensorAlloc))
793 return failure();
794
795 // tensor::PadOp is like tensor::GenerateOp: The only difference is that
796 // only a part of the generated tensor is needed. For simplicity, we reuse
797 // the same functionality here.
798 Value filledBuffer = lowerGenerateLikeOpBody(
799 rewriter, loc, *tensorAlloc, dynamicSizes, padOp.getBodyRegion());
800
801 // Create tensor::InsertSliceOp.
802 SmallVector<OpFoldResult> sliceSizes =
803 getMixedSizes(rewriter, loc, padOp.getSource());
804 SmallVector<OpFoldResult> sliceStrides(srcType.getRank(),
805 rewriter.getIndexAttr(1));
806 rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
807 padOp, padOp.getSource(), filledBuffer,
808 /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
809
810 return success();
811 }
812};
813
814/// Bufferization of tensor.rank. Replace with memref.rank.
815struct RankOpInterface
816 : public BufferizableOpInterface::ExternalModel<RankOpInterface,
817 tensor::RankOp> {
818 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
819 const AnalysisState &state) const {
820 // The op reads the tensor's metadata but not its contents.
821 return false;
822 }
823
824 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
825 const AnalysisState &state) const {
826 return false;
827 }
828
829 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
830 const AnalysisState &state) const {
831 return {};
832 }
833
834 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
835 const BufferizationOptions &options) const {
836 auto rankOp = cast<tensor::RankOp>(op);
837 FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options);
838 if (failed(result: v))
839 return failure();
840 replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
841 *v);
842 return success();
843 }
844};
845
846/// Bufferization of tensor.reshape. Replace with memref.reshape.
847struct ReshapeOpInterface
848 : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
849 tensor::ReshapeOp> {
850 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
851 const AnalysisState &state) const {
852 // Depending on the layout map, the source buffer may have to be copied.
853 auto reshapeOp = cast<tensor::ReshapeOp>(op);
854 return opOperand == reshapeOp.getShapeMutable();
855 }
856
857 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
858 const AnalysisState &state) const {
859 return false;
860 }
861
862 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
863 const AnalysisState &state) const {
864 return {{op->getOpResult(idx: 0), BufferRelation::Equivalent}};
865 }
866
867 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
868 const BufferizationOptions &options) const {
869 auto reshapeOp = cast<tensor::ReshapeOp>(op);
870 FailureOr<Value> srcBuffer =
871 getBuffer(rewriter, reshapeOp.getSource(), options);
872 FailureOr<Value> shapeBuffer =
873 getBuffer(rewriter, reshapeOp.getShape(), options);
874 if (failed(result: srcBuffer) || failed(result: shapeBuffer))
875 return failure();
876 auto maybeResultMemRefType =
877 bufferization::getBufferType(value: reshapeOp.getResult(), options);
878 if (failed(maybeResultMemRefType))
879 return failure();
880
881 // memref.reshape requires the source buffer to have an identity layout.
882 // If the source memref does not have an identity layout, copy the source
883 // into a new buffer with an identity layout.
884 auto srcType = llvm::dyn_cast<MemRefType>(srcBuffer->getType());
885 if (srcType && !srcType.getLayout().isIdentity()) {
886 FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
887 rewriter, op->getLoc(), reshapeOp.getSource(), options);
888 if (failed(result: tensorAlloc))
889 return failure();
890 auto memrefType = MemRefType::get(
891 srcType.getShape(), srcType.getElementType(), AffineMap(),
892 cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace());
893 srcBuffer = rewriter
894 .create<bufferization::ToMemrefOp>(
895 op->getLoc(), memrefType, *tensorAlloc)
896 .getResult();
897 }
898
899 replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
900 rewriter, op, maybeResultMemRefType.value(), *srcBuffer, *shapeBuffer);
901 return success();
902 }
903
904 FailureOr<BaseMemRefType>
905 getBufferType(Operation *op, Value value, const BufferizationOptions &options,
906 SmallVector<Value> &invocationStack) const {
907 auto reshapeOp = cast<tensor::ReshapeOp>(op);
908 assert(value == reshapeOp.getResult() && "unexpected value provided");
909 auto maybeSourceBufferType = bufferization::getBufferType(
910 value: reshapeOp.getSource(), options, invocationStack);
911 if (failed(maybeSourceBufferType))
912 return failure();
913 return getMemRefTypeWithStaticIdentityLayout(
914 reshapeOp.getResult().getType(),
915 cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace());
916 }
917};
918
919/// Analysis of ParallelInsertSliceOp.
920struct ParallelInsertSliceOpInterface
921 : public BufferizableOpInterface::ExternalModel<
922 ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
923 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
924 const AnalysisState &state) const {
925 return {};
926 }
927
928 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
929 const AnalysisState &state) const {
930 return true;
931 }
932
933 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
934 const AnalysisState &state) const {
935 auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
936 return opOperand == parallelInsertSliceOp.getDestMutable();
937 }
938
939 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
940 const BufferizationOptions &options) const {
941 OpBuilder::InsertionGuard g(rewriter);
942 auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
943 ParallelCombiningOpInterface parallelCombiningParent =
944 parallelInsertSliceOp.getParallelCombiningParent();
945
946 // Bufferize the op outside of the parallel combining terminator.
947 rewriter.setInsertionPoint(parallelCombiningParent);
948
949 // Get source and destination buffers.
950 FailureOr<Value> destBuffer =
951 getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
952 if (failed(result: destBuffer))
953 return failure();
954 FailureOr<Value> srcBuffer =
955 getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
956 if (failed(result: srcBuffer))
957 return failure();
958
959 // Take a subview of the destination buffer.
960 auto destBufferType = cast<MemRefType>(destBuffer->getType());
961 auto subviewMemRefType =
962 cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
963 parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
964 parallelInsertSliceOp.getMixedOffsets(),
965 parallelInsertSliceOp.getMixedSizes(),
966 parallelInsertSliceOp.getMixedStrides()));
967 Value subview = rewriter.create<memref::SubViewOp>(
968 parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
969 parallelInsertSliceOp.getMixedOffsets(),
970 parallelInsertSliceOp.getMixedSizes(),
971 parallelInsertSliceOp.getMixedStrides());
972
973 // This memcpy will fold away if everything bufferizes in-place.
974 if (failed(options.createMemCpy(b&: rewriter, loc: parallelInsertSliceOp.getLoc(),
975 from: *srcBuffer, to: subview)))
976 return failure();
977
978 // In case the source was allocated in the same block, make sure that the
979 // deallocation op (if any) appears after the memcpy. By default, deallocs
980 // are placed before the terminator, but this does not work for ForallOp
981 // because the terminator does more than just yielding a value.
982 //
983 // Note: This is not a problem for the destination buffer because these are
984 // assumed to always bufferize in-place.
985 for (Operation *user : srcBuffer->getUsers()) {
986 if (hasEffect<MemoryEffects::Free>(user)) {
987 if (user->getBlock() == parallelCombiningParent->getBlock())
988 rewriter.moveOpBefore(user, user->getBlock()->getTerminator());
989 break;
990 }
991 }
992
993 // Delete the op.
994 rewriter.eraseOp(op);
995 return success();
996 }
997};
998
999/// Bufferization of tensor.splat. Bufferizes to a new allocation that is filled
1000/// with a linalg.map. Similar to tensor.generate.
1001struct SplatOpInterface
1002 : public BufferizableOpInterface::ExternalModel<SplatOpInterface,
1003 tensor::SplatOp> {
1004
1005 bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
1006
1007 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1008 const BufferizationOptions &options) const {
1009 OpBuilder::InsertionGuard g(rewriter);
1010 auto splatOp = cast<tensor::SplatOp>(op);
1011
1012 // Allocate memory.
1013 Location loc = op->getLoc();
1014 FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
1015 rewriter, loc, splatOp.getResult(), options,
1016 /*copy=*/false);
1017 if (failed(result: tensorAlloc))
1018 return failure();
1019
1020 // Create linalg::MapOp.
1021 auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
1022
1023 // TODO: Implement memory space for this op.
1024 if (options.defaultMemorySpaceFn(tensorType) != Attribute())
1025 return op->emitError(message: "memory space not implemented yet");
1026
1027 auto linalgOp =
1028 rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
1029 /*init=*/*tensorAlloc);
1030 Block &linalgBody = linalgOp.getMapper().emplaceBlock();
1031
1032 // Create linalg::IndexOps.
1033 rewriter.setInsertionPointToStart(&linalgBody);
1034 rewriter.create<linalg::YieldOp>(loc, splatOp.getInput());
1035 rewriter.replaceOp(splatOp, linalgOp.getResult()[0]);
1036
1037 return success();
1038 }
1039};
1040
1041} // namespace
1042} // namespace tensor
1043} // namespace mlir
1044
1045void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
1046 DialectRegistry &registry) {
1047 registry.addExtension(extensionFn: +[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
1048 CastOp::attachInterface<CastOpInterface>(*ctx);
1049 CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
1050 DimOp::attachInterface<DimOpInterface>(*ctx);
1051 EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
1052 ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
1053 ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
1054 ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
1055 FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
1056 GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
1057 InsertOp::attachInterface<InsertOpInterface>(*ctx);
1058 InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
1059 PadOp::attachInterface<PadOpInterface>(*ctx);
1060 ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
1061 *ctx);
1062 RankOp::attachInterface<RankOpInterface>(*ctx);
1063 ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
1064 SplatOp::attachInterface<SplatOpInterface>(*ctx);
1065
1066 // Load additional dialects of which ops may get created.
1067 ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
1068 });
1069
1070 // Bufferization requires SubsetInsertionOpInterface models. Make sure that
1071 // they are registered.
1072 tensor::registerSubsetOpInterfaceExternalModels(registry);
1073}
1074

source code of mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp