1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Arith/IR/Arith.h"
10#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
11#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12#include "mlir/Dialect/Func/IR/FuncOps.h"
13#include "mlir/Dialect/MemRef/IR/MemRef.h"
14#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
15#include "mlir/Dialect/Tensor/IR/Tensor.h"
16#include "mlir/IR/Matchers.h"
17#include <optional>
18
19using namespace mlir;
20using namespace mlir::bufferization;
21
22//===----------------------------------------------------------------------===//
23// Helper functions
24//===----------------------------------------------------------------------===//
25
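/// Try to cast `value` to `destType`. If a plain `memref.cast` is not
/// guaranteed to succeed at runtime, fall back to allocating a new buffer of
/// the destination type and copying the data into it. Element type, rank and
/// memory space of the source and destination types must match.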
26FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
27 OpBuilder &b, Value value, MemRefType destType,
28 const BufferizationOptions &options) {
29 auto srcType = llvm::cast<MemRefType>(value.getType());
30
31 // Element type, rank and memory space must match.
32 if (srcType.getElementType() != destType.getElementType())
33 return failure();
34 if (srcType.getMemorySpace() != destType.getMemorySpace())
35 return failure();
36 if (srcType.getRank() != destType.getRank())
37 return failure();
38
39 // In case the affine maps are different, we may need to use a copy if we go
40 // from dynamic to static offset or stride (the canonicalization cannot know
41 // at this point that it is really cast compatible).
42 auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
43 int64_t sourceOffset, targetOffset;
44 SmallVector<int64_t, 4> sourceStrides, targetStrides;
45 if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
46 failed(getStridesAndOffset(target, targetStrides, targetOffset)))
47 return false;
48 auto dynamicToStatic = [](int64_t a, int64_t b) {
49 return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b);
50 };
51 if (dynamicToStatic(sourceOffset, targetOffset))
52 return false;
    for (auto it : llvm::zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
55 return false;
56 return true;
57 };
58
  // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
  // ensure that we only generate casts that always succeed at runtime, we
  // check a few extra conditions in `isGuaranteedCastCompatible`.
62 if (memref::CastOp::areCastCompatible(srcType, destType) &&
63 isGuaranteedCastCompatible(srcType, destType)) {
64 Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
65 return casted;
66 }
67
68 auto loc = value.getLoc();
69 SmallVector<Value, 4> dynamicOperands;
70 for (int i = 0; i < destType.getRank(); ++i) {
71 if (destType.getShape()[i] != ShapedType::kDynamic)
72 continue;
73 Value size = b.create<memref::DimOp>(loc, value, i);
    dynamicOperands.push_back(size);
75 }
76
  FailureOr<Value> copy =
      options.createAlloc(b, loc, destType, dynamicOperands);
  if (failed(copy))
    return failure();
  if (failed(options.createMemCpy(b, loc, value, *copy)))
    return failure();
83 return copy;
84}
85
86/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
87/// to_memref op are different, a memref.cast is needed.
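/// Illustrative example (hypothetical IR; SSA names and types are arbitrary):
/// ```mlir
/// %t = bufferization.to_tensor %m : memref<4xf32>
/// %m2 = bufferization.to_memref %t : memref<?xf32>
/// ```
/// Here %m2 can be folded to a memref.cast of %m; if the two memref types were
/// identical, %m2 would simply be replaced with %m.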
88LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
89 RewriterBase &rewriter, ToMemrefOp toMemref,
90 const BufferizationOptions &options) {
91 auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
92 if (!memrefToTensor)
93 return failure();
94
95 Type srcType = memrefToTensor.getMemref().getType();
96 Type destType = toMemref.getType();
97
98 // Directly rewrite if the type did not change.
99 if (srcType == destType) {
100 rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
101 return success();
102 }
103
104 auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType);
105 auto rankedDestType = llvm::dyn_cast<MemRefType>(destType);
106 auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType);
107
108 // Ranked memref -> Ranked memref cast.
109 if (rankedSrcType && rankedDestType) {
110 FailureOr<Value> replacement = castOrReallocMemRefValue(
111 rewriter, memrefToTensor.getMemref(), rankedDestType, options);
    if (failed(replacement))
113 return failure();
114
115 rewriter.replaceOp(toMemref, *replacement);
116 return success();
117 }
118
119 // Unranked memref -> Ranked memref cast: May require a copy.
120 // TODO: Not implemented at the moment.
121 if (unrankedSrcType && rankedDestType)
122 return failure();
123
124 // Unranked memref -> unranked memref cast
125 // Ranked memref -> unranked memref cast: No copy needed.
126 assert(memref::CastOp::areCastCompatible(srcType, destType) &&
127 "expected that types are cast compatible");
128 rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
129 memrefToTensor.getMemref());
130 return success();
131}
132
133void mlir::bufferization::populateDynamicDimSizes(
134 OpBuilder &b, Location loc, Value shapedValue,
135 SmallVector<Value> &dynamicDims) {
136 auto shapedType = llvm::cast<ShapedType>(shapedValue.getType());
137 for (int64_t i = 0; i < shapedType.getRank(); ++i) {
138 if (shapedType.isDynamicDim(i)) {
139 if (llvm::isa<MemRefType>(shapedType)) {
140 dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
141 } else {
142 assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor");
143 dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
144 }
145 }
146 }
147}
148
149//===----------------------------------------------------------------------===//
150// AllocTensorOp
151//===----------------------------------------------------------------------===//
152
153LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
154 const BufferizationOptions &options) {
155 OpBuilder::InsertionGuard g(rewriter);
156 Location loc = getLoc();
157
158 // Nothing to do for dead AllocTensorOps.
159 if (getOperation()->getUses().empty()) {
160 rewriter.eraseOp(getOperation());
161 return success();
162 }
163
164 // Get "copy" buffer.
165 Value copyBuffer;
166 if (getCopy()) {
167 FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options);
168 if (failed(maybeCopyBuffer))
169 return failure();
170 copyBuffer = *maybeCopyBuffer;
171 }
172
173 // Create memory allocation.
174 auto allocType = bufferization::getBufferType(getResult(), options);
175 if (failed(allocType))
176 return failure();
177 SmallVector<Value> dynamicDims = getDynamicSizes();
178 if (getCopy()) {
179 assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
180 populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
181 }
182 FailureOr<Value> alloc = options.createAlloc(
183 rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
184 if (failed(alloc))
185 return failure();
186
187 // Create memory copy (if any).
188 if (getCopy()) {
189 if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
190 return failure();
191 }
192
193 // Replace op.
194 replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);
195
196 return success();
197}
198
199bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult,
200 const AnalysisState &state) {
201 // AllocTensorOps do not write unless they have a `copy` value.
202 return static_cast<bool>(getCopy());
203}
204
205bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
206 const AnalysisState &state) {
207 assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
208 "expected copy operand");
209 return true;
210}
211
212bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
213 const AnalysisState &state) {
214 assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
215 "expected copy operand");
216 return false;
217}
218
219AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
220 const AnalysisState &state) {
221 // This is a new allocation. It does not alias with any other buffer.
222 return {};
223}
224
225FailureOr<BaseMemRefType>
226AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
227 SmallVector<Value> &invocationStack) {
228 assert(value == getResult() && "invalid value");
229
230 // Compute memory space of this allocation.
231 Attribute memorySpace;
232 if (getMemorySpace().has_value()) {
233 memorySpace = *getMemorySpace();
234 } else if (getCopy()) {
235 auto copyBufferType =
236 bufferization::getBufferType(getCopy(), options, invocationStack);
237 if (failed(copyBufferType))
238 return failure();
239 memorySpace = copyBufferType->getMemorySpace();
240 } else if (auto ms = options.defaultMemorySpaceFn(getType())) {
241 memorySpace = *ms;
242 } else {
243 return getOperation()->emitError("could not infer memory space");
244 }
245
246 return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
247}
248
249LogicalResult AllocTensorOp::verify() {
250 if (getCopy() && !getDynamicSizes().empty())
251 return emitError("dynamic sizes not needed when copying a tensor");
252 if (!getCopy() && getType().getNumDynamicDims() !=
253 static_cast<int64_t>(getDynamicSizes().size()))
254 return emitError("expected ")
255 << getType().getNumDynamicDims() << " dynamic sizes";
256 if (getCopy() && getCopy().getType() != getType())
257 return emitError("expected that `copy` and return type match");
258 return success();
259}
260
261void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
262 RankedTensorType type, ValueRange dynamicSizes) {
263 build(builder, result, type, dynamicSizes, /*copy=*/Value(),
264 /*size_hint=*/Value(),
265 /*memory_space=*/IntegerAttr());
266}
267
268void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
269 RankedTensorType type, ValueRange dynamicSizes,
270 Value copy) {
271 build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
272 /*memory_space=*/IntegerAttr());
273}
274
275void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
276 TensorType type, ValueRange dynamicSizes, Value copy,
277 IntegerAttr memorySpace) {
278 build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
279 memorySpace);
280}
281
282namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that in the original
/// operation were defined as dynamic, but whose size was defined using a
/// `constant` op. For example:
///
///  %c5 = arith.constant 5: index
///  %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
///
/// to
///
///  %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
294struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
295 using OpRewritePattern<AllocTensorOp>::OpRewritePattern;
296
297 LogicalResult matchAndRewrite(AllocTensorOp op,
298 PatternRewriter &rewriter) const override {
299 if (op.getCopy())
300 return failure();
301 SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
302 SmallVector<Value> newDynamicSizes;
303 unsigned int dynValCounter = 0;
304 for (int64_t i = 0; i < op.getType().getRank(); ++i) {
305 if (!op.isDynamicDim(i))
306 continue;
307 Value value = op.getDynamicSizes()[dynValCounter++];
308 APInt intVal;
309 if (matchPattern(value, m_ConstantInt(&intVal))) {
310 int64_t dim = intVal.getSExtValue();
311 if (dim >= 0)
312 newShape[i] = intVal.getSExtValue();
313 else
          newDynamicSizes.push_back(value);
      } else {
        newDynamicSizes.push_back(value);
317 }
318 }
319 RankedTensorType newType = RankedTensorType::get(
320 newShape, op.getType().getElementType(), op.getType().getEncoding());
321 if (newType == op.getType())
322 return failure();
323 auto newOp = rewriter.create<AllocTensorOp>(
324 op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
325 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
326 return success();
327 }
328};
329
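/// Fold a `tensor.dim` of an `AllocTensorOp` with a constant index and a
/// dynamic dimension to the corresponding dynamic size operand.
/// Illustrative example (hypothetical IR; names and types are arbitrary):
/// ```mlir
/// %0 = bufferization.alloc_tensor(%d) : tensor<?xf32>
/// %1 = tensor.dim %0, %c0 : tensor<?xf32>
/// ```
/// Here %1 is replaced by %d.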
330struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
331 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
332
333 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
334 PatternRewriter &rewriter) const override {
335 std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
336 auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
337 if (!allocTensorOp || !maybeConstantIndex)
338 return failure();
339 if (*maybeConstantIndex < 0 ||
340 *maybeConstantIndex >= allocTensorOp.getType().getRank())
341 return failure();
342 if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
343 return failure();
344 rewriter.replaceOp(
345 dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
346 return success();
347 }
348};
349} // namespace
350
351void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
352 MLIRContext *ctx) {
353 results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
354}
355
356LogicalResult AllocTensorOp::reifyResultShapes(
357 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
358 auto shapes = llvm::to_vector<4>(
359 llvm::map_range(llvm::seq<int64_t>(0, getType().getRank()),
360 [&](int64_t dim) -> OpFoldResult {
361 if (isDynamicDim(dim))
362 return getDynamicSize(builder, dim);
363 return builder.getIndexAttr(getStaticSize(dim));
364 }));
365 reifiedReturnShapes.emplace_back(std::move(shapes));
366 return success();
367}
368
369ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
370 SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
371 if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
372 parser.parseRParen())
373 return failure();
374 ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
375 OpAsmParser::UnresolvedOperand copyOperand;
376 if (copyKeyword.succeeded())
377 if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
378 parser.parseRParen())
379 return failure();
380 ParseResult sizeHintKeyword = parser.parseOptionalKeyword("size_hint");
381 OpAsmParser::UnresolvedOperand sizeHintOperand;
382 if (sizeHintKeyword.succeeded())
383 if (parser.parseEqual() || parser.parseOperand(sizeHintOperand))
384 return failure();
385 if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
386 return failure();
387
388 TensorType type;
389 if (parser.parseCustomTypeWithFallback(type))
390 return failure();
391 result.addTypes(type);
392
393 Type indexType = parser.getBuilder().getIndexType();
394 if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
395 return failure();
396 if (copyKeyword.succeeded())
397 if (parser.resolveOperand(copyOperand, type, result.operands))
398 return failure();
399 if (sizeHintKeyword.succeeded())
400 if (parser.resolveOperand(sizeHintOperand, indexType, result.operands))
401 return failure();
402 result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
403 parser.getBuilder().getDenseI32ArrayAttr(
404 {static_cast<int32_t>(dynamicSizesOperands.size()),
405 static_cast<int32_t>(copyKeyword.succeeded()),
406 static_cast<int32_t>(sizeHintKeyword.succeeded())}));
407 return success();
408}
409
410void AllocTensorOp::print(OpAsmPrinter &p) {
411 p << "(" << getDynamicSizes() << ")";
412 if (getCopy())
413 p << " copy(" << getCopy() << ")";
414 if (getSizeHint())
415 p << " size_hint=" << getSizeHint();
416 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
417 AllocTensorOp::getOperandSegmentSizeAttr()});
418 p << " : ";
419 auto type = getResult().getType();
420 if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type))
421 p.printStrippedAttrOrType(validType);
422 else
423 p << type;
424}
425
426Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
427 assert(isDynamicDim(idx) && "expected dynamic dim");
428 if (getCopy())
429 return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
430 return getOperand(getIndexOfDynamicSize(idx));
431}
432
433//===----------------------------------------------------------------------===//
434// CloneOp
435//===----------------------------------------------------------------------===//
436
437OpFoldResult CloneOp::fold(FoldAdaptor adaptor) {
438 return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
439}
440
441namespace {
442
443/// Merge the clone and its source (by converting the clone to a cast) when
444/// possible.
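/// Illustrative example (hypothetical IR; assumes the source buffer is
/// deallocated in a different block, if at all, and that no free-like op
/// occurs between the clone and the dealloc):
/// ```mlir
/// %1 = bufferization.clone %0 : memref<2xf32> to memref<2xf32>
/// "test.use"(%1) : (memref<2xf32>) -> ()
/// memref.dealloc %1 : memref<2xf32>
/// ```
/// simplifies to
/// ```mlir
/// "test.use"(%0) : (memref<2xf32>) -> ()
/// ```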
445struct SimplifyClones : public OpRewritePattern<CloneOp> {
446 using OpRewritePattern<CloneOp>::OpRewritePattern;
447
448 LogicalResult matchAndRewrite(CloneOp cloneOp,
449 PatternRewriter &rewriter) const override {
450 if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
452 return success();
453 }
454
455 Value source = cloneOp.getInput();
456 if (source.getType() != cloneOp.getType() &&
457 !memref::CastOp::areCastCompatible({source.getType()},
458 {cloneOp.getType()}))
459 return failure();
460
461 // Aims to find the dealloc op for the canonical source
462 // which otherwise could prevent removal of unnecessary allocs.
463 Value canonicalSource = source;
464 while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
465 canonicalSource.getDefiningOp()))
466 canonicalSource = iface.getViewSource();
467
    std::optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.getOutput());
    // Skip if either of them has > 1 deallocate operations.
    if (!maybeCloneDeallocOp.has_value())
      return failure();
    std::optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(canonicalSource);
    if (!maybeSourceDeallocOp.has_value())
      return failure();
477 Operation *cloneDeallocOp = *maybeCloneDeallocOp;
478 Operation *sourceDeallocOp = *maybeSourceDeallocOp;
479
480 // If both are deallocated in the same block, their in-block lifetimes
481 // might not fully overlap, so we cannot decide which one to drop.
482 if (cloneDeallocOp && sourceDeallocOp &&
483 cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
484 return failure();
485
486 Block *currentBlock = cloneOp->getBlock();
487 Operation *redundantDealloc = nullptr;
488 if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
489 redundantDealloc = cloneDeallocOp;
490 } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
491 redundantDealloc = sourceDeallocOp;
492 }
493
494 if (!redundantDealloc)
495 return failure();
496
    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
502 for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
503 pos = pos->getNextNode()) {
504 // Bail if we run out of operations while looking for a deallocation op.
505 if (!pos)
506 return failure();
507 auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
508 if (!effectInterface)
509 continue;
510 if (effectInterface.hasEffect<MemoryEffects::Free>())
511 return failure();
512 }
513
514 if (source.getType() != cloneOp.getType())
515 source = rewriter.create<memref::CastOp>(cloneOp.getLoc(),
516 cloneOp.getType(), source);
517 rewriter.replaceOp(cloneOp, source);
    rewriter.eraseOp(redundantDealloc);
519 return success();
520 }
521};
522
523} // namespace
524
525void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
526 MLIRContext *context) {
527 results.add<SimplifyClones>(context);
528}
529
530//===----------------------------------------------------------------------===//
531// DeallocTensorOp
532//===----------------------------------------------------------------------===//
533
534LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
535 const BufferizationOptions &options) {
536 FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
537 if (failed(buffer))
538 return failure();
539 rewriter.create<memref::DeallocOp>(getLoc(), *buffer);
540 rewriter.eraseOp(getOperation());
541 return success();
542}
543
544//===----------------------------------------------------------------------===//
545// MaterializeInDestinationOp
546//===----------------------------------------------------------------------===//
547
548bool MaterializeInDestinationOp::bufferizesToMemoryRead(
549 OpOperand &opOperand, const AnalysisState &state) {
550 return opOperand == getSourceMutable();
551}
552
553bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
554 OpOperand &opOperand, const AnalysisState &state) {
555 if (opOperand == getDestMutable()) {
556 assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
557 return true;
558 }
559 return false;
560}
561
562bool MaterializeInDestinationOp::mustBufferizeInPlace(
563 OpOperand &opOperand, const AnalysisState &state) {
564 // The source is only read and not written, so it always bufferizes in-place
565 // by default. The destination is written and is forced to bufferize in-place
566 // (if it is a tensor).
567 return true;
568}
569
570AliasingValueList
571MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
572 const AnalysisState &state) {
573 if (opOperand == getDestMutable()) {
574 assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
575 return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
576 }
577 return {};
578}
579
580LogicalResult
581MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
582 const BufferizationOptions &options) {
583 bool tensorDest = isa<TensorType>(getDest().getType());
584 Value buffer;
585 if (tensorDest) {
586 FailureOr<Value> maybeBuffer = getBuffer(rewriter, getDest(), options);
587 if (failed(maybeBuffer))
588 return failure();
589 buffer = *maybeBuffer;
590 } else {
591 assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
592 buffer = getDest();
593 }
594 auto srcBuffer = getBuffer(rewriter, getSource(), options);
595 if (failed(srcBuffer))
596 return failure();
597 if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
598 return failure();
599 replaceOpWithBufferizedValues(rewriter, getOperation(),
600 tensorDest ? ValueRange(buffer) : ValueRange());
601 return success();
602}
603
604bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
605 const AnalysisState &state, ArrayRef<OpOperand *> opOperands) {
606 // As elements are copied from the "source" buffer to the "dest" buffer,
607 // already copied elements are not read a second time.
608 return true;
609}
610
611LogicalResult MaterializeInDestinationOp::reifyResultShapes(
612 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
613 if (getOperation()->getNumResults() == 1) {
614 assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
615 reifiedReturnShapes.resize(1,
616 SmallVector<OpFoldResult>(getType().getRank()));
617 reifiedReturnShapes[0] =
618 tensor::getMixedSizes(builder, getLoc(), getDest());
619 }
620 return success();
621}
622
623Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
624 Location loc) {
625 if (isa<TensorType>(getDest().getType())) {
626 // The subset is the entire destination tensor.
627 return getDest();
628 }
629
  // The "restrict" attribute is transferred from this op to the newly created
  // to_tensor op. If this op does not have the "restrict" attribute, the
  // subset extraction cannot be built because there is no guarantee that there
  // is no pre-existing "restrict" to_tensor op with the same/an aliasing
  // destination.
634 if (!getRestrict())
635 return {};
636
637 // Build a bufferization.to_tensor op.
638 assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
639 assert(getRestrict() &&
640 "expected that ops with memrefs dest have 'restrict'");
641 setRestrict(false);
642 return builder.create<ToTensorOp>(loc, getDest(), /*restrict=*/true,
643 getWritable());
644}
645
646bool MaterializeInDestinationOp::isEquivalentSubset(
647 Value candidate, function_ref<bool(Value, Value)> equivalenceFn) {
648 return equivalenceFn(getDest(), candidate);
649}
650
651SmallVector<Value>
652MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() {
653 return {getDest()};
654}
655
656OpOperand &MaterializeInDestinationOp::getSourceOperand() {
657 return getOperation()->getOpOperand(0) /*source*/;
658}
659
660bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
661 SubsetOpInterface subsetOp,
662 function_ref<bool(Value, Value)> equivalenceFn) {
663 return false;
664}
665
666bool MaterializeInDestinationOp::operatesOnDisjointSubset(
667 SubsetOpInterface subsetOp,
668 function_ref<bool(Value, Value)> equivalenceFn) {
669 return false;
670}
671
672LogicalResult MaterializeInDestinationOp::verify() {
673 if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
674 return emitOpError("'dest' must be a tensor or a memref");
675 if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
676 if (getOperation()->getNumResults() != 1)
677 return emitOpError("tensor 'dest' implies exactly one tensor result");
678 if (destType != getResult().getType())
679 return emitOpError("result and 'dest' types must match");
680 }
681 if (isa<BaseMemRefType>(getDest().getType()) &&
682 getOperation()->getNumResults() != 0)
683 return emitOpError("memref 'dest' implies zero results");
684 if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
685 return emitOpError("'restrict' is valid only for memref destinations");
686 if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
687 return emitOpError("'writable' must be specified if and only if the "
688 "destination is of memref type");
689 return success();
690}
691
692void MaterializeInDestinationOp::build(OpBuilder &builder,
693 OperationState &state, Value source,
694 Value dest) {
695 auto destTensorType = dyn_cast<TensorType>(dest.getType());
696 build(builder, state, /*result=*/destTensorType ? destTensorType : Type(),
697 source, dest);
698}
699
700bool MaterializeInDestinationOp::isWritable(Value value,
701 const AnalysisState &state) {
702 return isa<TensorType>(getDest().getType()) ? true : getWritable();
703}
704
705MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
706 return getDestMutable();
707}
708
709void MaterializeInDestinationOp::getEffects(
710 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
711 &effects) {
712 if (isa<BaseMemRefType>(getDest().getType()))
713 effects.emplace_back(MemoryEffects::Write::get(), getDest(),
714 SideEffects::DefaultResource::get());
715}
716
717//===----------------------------------------------------------------------===//
718// ToTensorOp
719//===----------------------------------------------------------------------===//
720
721bool ToTensorOp::isWritable(Value value, const AnalysisState &state) {
722 return getWritable();
723}
724
725OpFoldResult ToTensorOp::fold(FoldAdaptor) {
726 if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
729 if (toMemref->getBlock() == this->getOperation()->getBlock() &&
730 toMemref->getNextNode() == this->getOperation())
731 return toMemref.getTensor();
732 return {};
733}
734
735namespace {
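/// Fold a `tensor.dim` of a `bufferization.to_tensor` into a `memref.dim` of
/// the underlying memref. Illustrative example (hypothetical IR):
/// ```mlir
/// %t = bufferization.to_tensor %m : memref<?xf32>
/// %d = tensor.dim %t, %c0 : tensor<?xf32>
/// ```
/// becomes
/// ```mlir
/// %d = memref.dim %m, %c0 : memref<?xf32>
/// ```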
736struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
737 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
738
739 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
740 PatternRewriter &rewriter) const override {
741 auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
742 if (!memrefToTensorOp)
743 return failure();
744
745 rewriter.replaceOpWithNewOp<memref::DimOp>(
746 dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex());
747 return success();
748 }
749};
750} // namespace
751
752void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
753 MLIRContext *context) {
754 results.add<DimOfToTensorFolder>(context);
755}
756
757//===----------------------------------------------------------------------===//
758// ToMemrefOp
759//===----------------------------------------------------------------------===//
760
761OpFoldResult ToMemrefOp::fold(FoldAdaptor) {
762 if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
763 if (memrefToTensor.getMemref().getType() == getType())
764 return memrefToTensor.getMemref();
765 return {};
766}
767
768namespace {
769
770/// Replace tensor.cast + to_memref by to_memref + memref.cast.
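/// Illustrative example (hypothetical IR; types are arbitrary):
/// ```mlir
/// %1 = tensor.cast %0 : tensor<4xf32> to tensor<?xf32>
/// %m = bufferization.to_memref %1 : memref<?xf32>
/// ```
/// is rewritten to
/// ```mlir
/// %2 = bufferization.to_memref %0 : memref<4xf32>
/// %m = memref.cast %2 : memref<4xf32> to memref<?xf32>
/// ```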
771struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
772 using OpRewritePattern<ToMemrefOp>::OpRewritePattern;
773
774 LogicalResult matchAndRewrite(ToMemrefOp toMemref,
775 PatternRewriter &rewriter) const final {
776 auto tensorCastOperand =
777 toMemref.getOperand().getDefiningOp<tensor::CastOp>();
778 if (!tensorCastOperand)
779 return failure();
780 auto srcTensorType = llvm::dyn_cast<RankedTensorType>(
781 tensorCastOperand.getOperand().getType());
782 if (!srcTensorType)
783 return failure();
784 auto memrefType = MemRefType::get(srcTensorType.getShape(),
785 srcTensorType.getElementType());
786 Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
787 tensorCastOperand.getOperand());
788 rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
789 memref);
790 return success();
791 }
792};
793
794/// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a
795/// cast if necessary.
796struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
797 using OpRewritePattern<ToMemrefOp>::OpRewritePattern;
798
799 LogicalResult matchAndRewrite(ToMemrefOp toMemref,
800 PatternRewriter &rewriter) const final {
801 BufferizationOptions options;
802 options.bufferAlignment = 0;
803 return foldToMemrefToTensorPair(rewriter, toMemref, options);
804 }
805};
806
/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
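/// Illustrative example (hypothetical IR):
/// ```mlir
/// %m = bufferization.to_memref %t : memref<?xf32>
/// %v = memref.load %m[%i] : memref<?xf32>
/// ```
/// becomes
/// ```mlir
/// %v = tensor.extract %t[%i] : tensor<?xf32>
/// ```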
809struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
810 using OpRewritePattern<memref::LoadOp>::OpRewritePattern;
811
812 LogicalResult matchAndRewrite(memref::LoadOp load,
813 PatternRewriter &rewriter) const override {
814 auto toMemref = load.getMemref().getDefiningOp<ToMemrefOp>();
815 if (!toMemref)
816 return failure();
817
818 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(),
819 load.getIndices());
820 return success();
821 }
822};
823
824/// Fold dim of a to_memref into the dim of the tensor.
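/// Illustrative example (hypothetical IR):
/// ```mlir
/// %m = bufferization.to_memref %t : memref<?x?xf32>
/// %d = memref.dim %m, %c1 : memref<?x?xf32>
/// ```
/// becomes
/// ```mlir
/// %d = tensor.dim %t, %c1 : tensor<?x?xf32>
/// ```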
825struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
826 using OpRewritePattern<memref::DimOp>::OpRewritePattern;
827
828 LogicalResult matchAndRewrite(memref::DimOp dimOp,
829 PatternRewriter &rewriter) const override {
830 auto castOp = dimOp.getSource().getDefiningOp<ToMemrefOp>();
831 if (!castOp)
832 return failure();
833 Value newSource = castOp.getOperand();
834 rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
835 dimOp.getIndex());
836 return success();
837 }
838};
839
840} // namespace
841
842void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
843 MLIRContext *context) {
844 results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast,
845 ToMemrefToTensorFolding>(context);
846}
847
848LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
849 const BufferizationOptions &options) {
850 // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
851 (void)foldToMemrefToTensorPair(rewriter, *this, options);
852 // Note: The return value of `bufferize` indicates whether there was an error
853 // or not. (And not whether the pattern matched or not.)
854 return success();
855}
856
857std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder,
858 Value alloc) {
859 return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
860 .getOperation();
861}
862
863std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
864 return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
865}
866
867//===----------------------------------------------------------------------===//
868// DeallocOp
869//===----------------------------------------------------------------------===//
870
871LogicalResult DeallocOp::inferReturnTypes(
872 MLIRContext *context, std::optional<::mlir::Location> location,
873 ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
874 RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
875 DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
876 inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(),
877 IntegerType::get(context, 1));
878 return success();
879}
880
881LogicalResult DeallocOp::verify() {
882 if (getMemrefs().size() != getConditions().size())
883 return emitOpError(
884 "must have the same number of conditions as memrefs to deallocate");
885 if (getRetained().size() != getUpdatedConditions().size())
886 return emitOpError("must have the same number of updated conditions "
887 "(results) as retained operands");
888 return success();
889}
890
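/// Replace the memref and condition operands of `deallocOp` in place with
/// `memrefs` and `conditions`. Returns failure if nothing would change, so
/// that the calling canonicalization patterns do not loop indefinitely.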
891static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
892 ValueRange memrefs,
893 ValueRange conditions,
894 PatternRewriter &rewriter) {
895 if (deallocOp.getMemrefs() == memrefs &&
896 deallocOp.getConditions() == conditions)
897 return failure();
898
899 rewriter.modifyOpInPlace(deallocOp, [&]() {
900 deallocOp.getMemrefsMutable().assign(memrefs);
901 deallocOp.getConditionsMutable().assign(conditions);
902 });
903 return success();
904}
905
906namespace {
907
/// Remove duplicate values in the list of memrefs to be deallocated. The
/// corresponding condition value has to be updated accordingly, since the two
/// conditions might not cover the same set of cases; in that case, they have
/// to be combined (by computing their disjunction).
912/// Example:
913/// ```mlir
914/// bufferization.dealloc (%arg0, %arg0 : ...) if (%arg1, %arg2)
915/// ```
916/// is canonicalized to
917/// ```mlir
918/// %0 = arith.ori %arg1, %arg2 : i1
919/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%0)
920/// ```
921struct DeallocRemoveDuplicateDeallocMemrefs
922 : public OpRewritePattern<DeallocOp> {
923 using OpRewritePattern<DeallocOp>::OpRewritePattern;
924
925 LogicalResult matchAndRewrite(DeallocOp deallocOp,
926 PatternRewriter &rewriter) const override {
927 // Unique memrefs to be deallocated.
928 DenseMap<Value, unsigned> memrefToCondition;
929 SmallVector<Value> newMemrefs, newConditions;
930 for (auto [i, memref, cond] :
931 llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
932 if (memrefToCondition.count(memref)) {
933 // If the dealloc conditions don't match, we need to make sure that the
934 // dealloc happens on the union of cases.
935 Value &newCond = newConditions[memrefToCondition[memref]];
936 if (newCond != cond)
937 newCond =
938 rewriter.create<arith::OrIOp>(deallocOp.getLoc(), newCond, cond);
939 } else {
940 memrefToCondition.insert({memref, newConditions.size()});
941 newMemrefs.push_back(memref);
942 newConditions.push_back(cond);
943 }
944 }
945
    // Return failure if we don't change anything, so that we don't run into an
    // infinite loop of pattern applications.
948 return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
949 rewriter);
950 }
951};
952
953/// Remove duplicate values in the list of retained memrefs. We need to make
954/// sure the corresponding result condition value is replaced properly.
955/// Example:
956/// ```mlir
957/// %0:2 = bufferization.dealloc retain (%arg3, %arg3 : ...)
958/// ```
959/// is canonicalized to
960/// ```mlir
961/// %0 = bufferization.dealloc retain (%arg3 : memref<2xi32>)
962/// ```
963struct DeallocRemoveDuplicateRetainedMemrefs
964 : public OpRewritePattern<DeallocOp> {
965 using OpRewritePattern<DeallocOp>::OpRewritePattern;
966
967 LogicalResult matchAndRewrite(DeallocOp deallocOp,
968 PatternRewriter &rewriter) const override {
969 // Unique retained values
970 DenseMap<Value, unsigned> seen;
971 SmallVector<Value> newRetained;
972 SmallVector<unsigned> resultReplacementIdx;
973 unsigned i = 0;
974 for (auto retained : deallocOp.getRetained()) {
975 if (seen.count(retained)) {
976 resultReplacementIdx.push_back(seen[retained]);
977 continue;
978 }
979
980 seen[retained] = i;
981 newRetained.push_back(retained);
982 resultReplacementIdx.push_back(i++);
983 }
984
    // Return failure if we don't change anything, so that we don't run into an
    // infinite loop of pattern applications.
987 if (newRetained.size() == deallocOp.getRetained().size())
988 return failure();
989
990 // We need to create a new op because the number of results is always the
991 // same as the number of condition operands.
992 auto newDeallocOp =
993 rewriter.create<DeallocOp>(deallocOp.getLoc(), deallocOp.getMemrefs(),
994 deallocOp.getConditions(), newRetained);
995 SmallVector<Value> replacements(
996 llvm::map_range(resultReplacementIdx, [&](unsigned idx) {
997 return newDeallocOp.getUpdatedConditions()[idx];
998 }));
999 rewriter.replaceOp(deallocOp, replacements);
1000 return success();
1001 }
1002};
1003
1004/// Erase deallocation operations where the variadic list of memrefs to
1005/// deallocate is empty. Example:
1006/// ```mlir
1007/// %0 = bufferization.dealloc retain (%arg0: memref<2xi32>)
1008/// ```
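/// Each updated-condition result is replaced with a constant `false`, since an
/// empty dealloc never deallocates anything (so %0 above folds to
/// `arith.constant false`).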
1009struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
1010 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1011
1012 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1013 PatternRewriter &rewriter) const override {
1014 if (deallocOp.getMemrefs().empty()) {
1015 Value constFalse = rewriter.create<arith::ConstantOp>(
1016 deallocOp.getLoc(), rewriter.getBoolAttr(false));
1017 rewriter.replaceOp(
1018 deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
1019 constFalse));
1020 return success();
1021 }
1022 return failure();
1023 }
1024};
1025
1026/// Removes memrefs from the deallocation list if their associated condition is
1027/// always 'false'.
1028///
1029/// Example:
1030/// ```
1031/// bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
1032/// if (%arg2, %false)
1033/// ```
1034/// becomes
1035/// ```
1036/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
1037/// ```
1038struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
1039 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1040
1041 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1042 PatternRewriter &rewriter) const override {
1043 SmallVector<Value> newMemrefs, newConditions;
1044 for (auto [memref, cond] :
1045 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
1046 if (!matchPattern(cond, m_Zero())) {
1047 newMemrefs.push_back(memref);
1048 newConditions.push_back(cond);
1049 }
1050 }
1051
1052 return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
1053 rewriter);
1054 }
1055};
1056
/// The `memref.extract_strided_metadata` op is often inserted to get the base
/// memref if the operand is not already guaranteed to be the result of a
/// memref allocation operation. This canonicalization pattern removes this
/// extraction operation if the operand is now produced by an allocation
/// operation (e.g., due to other canonicalizations simplifying the IR).
1062///
1063/// Example:
1064/// ```mlir
1065/// %alloc = memref.alloc() : memref<2xi32>
1066/// %base_memref, %offset, %size, %stride = memref.extract_strided_metadata
1067/// %alloc : memref<2xi32> -> memref<i32>, index, index, index
1068/// bufferization.dealloc (%base_memref : memref<i32>) if (%cond)
1069/// ```
1070/// is canonicalized to
1071/// ```mlir
1072/// %alloc = memref.alloc() : memref<2xi32>
1073/// bufferization.dealloc (%alloc : memref<2xi32>) if (%cond)
1074/// ```
1075struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
1076 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1077
1078 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1079 PatternRewriter &rewriter) const override {
1080 SmallVector<Value> newMemrefs(
1081 llvm::map_range(deallocOp.getMemrefs(), [&](Value memref) {
1082 auto extractStridedOp =
1083 memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
1084 if (!extractStridedOp)
1085 return memref;
1086 Value allocMemref = extractStridedOp.getOperand();
1087 auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>();
1088 if (!allocOp)
1089 return memref;
1090 if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref))
1091 return allocMemref;
1092 return memref;
1093 }));
1094
1095 return updateDeallocIfChanged(deallocOp, newMemrefs,
1096 deallocOp.getConditions(), rewriter);
1097 }
1098};
1099
1100/// Removes pairs of `bufferization.dealloc` and alloc operations if there is no
1101/// other user of the allocated value and the allocating operation can be safely
1102/// removed. If the same value is present multiple times, this pattern relies on
1103/// other canonicalization patterns to remove the duplicate first.
1104///
1105/// Example:
1106/// ```mlir
1107/// %alloc = memref.alloc() : memref<2xi32>
/// bufferization.dealloc (%alloc, %arg0 : ...) if (%true, %true)
1109/// ```
1110/// is canonicalized to
1111/// ```mlir
1112/// bufferization.dealloc (%arg0 : ...) if (%true)
1113/// ```
1114struct RemoveAllocDeallocPairWhenNoOtherUsers
1115 : public OpRewritePattern<DeallocOp> {
1116 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1117
1118 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1119 PatternRewriter &rewriter) const override {
1120 SmallVector<Value> newMemrefs, newConditions;
1121 SmallVector<Operation *> toDelete;
1122 for (auto [memref, cond] :
1123 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
1124 if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
1125 // Check that it is indeed an allocate effect, that the op has no other
1126 // side effects (which would not allow us to remove the op), and that
1127 // there are no other users.
1128 if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(memref) &&
1129 hasSingleEffect<MemoryEffects::Allocate>(allocOp, memref) &&
1130 memref.hasOneUse()) {
1131 toDelete.push_back(allocOp);
1132 continue;
1133 }
1134 }
1135
1136 newMemrefs.push_back(memref);
1137 newConditions.push_back(cond);
1138 }
1139
1140 if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
1141 rewriter)))
1142 return failure();
1143
1144 for (Operation *op : toDelete)
1145 rewriter.eraseOp(op);
1146
1147 return success();
1148 }
1149};
1150
1151} // anonymous namespace
1152
1153void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
1154 MLIRContext *context) {
1155 populateDeallocOpCanonicalizationPatterns(results, context);
1156}
1157
1158void bufferization::populateDeallocOpCanonicalizationPatterns(
1159 RewritePatternSet &patterns, MLIRContext *context) {
1160 patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
1161 DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
1162 EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
               RemoveAllocDeallocPairWhenNoOtherUsers>(context);
1164}
1165
1166//===----------------------------------------------------------------------===//
1167// TableGen'd op method definitions
1168//===----------------------------------------------------------------------===//
1169
1170#define GET_OP_CLASSES
1171#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"
1172
