//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Arith/IR/Arith.h"
10#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
11#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12#include "mlir/Dialect/Func/IR/FuncOps.h"
13#include "mlir/Dialect/MemRef/IR/MemRef.h"
14#include "mlir/Dialect/Tensor/IR/Tensor.h"
15#include "mlir/IR/Matchers.h"
16#include <optional>
17
18using namespace mlir;
19using namespace mlir::bufferization;
20
//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//
24
25FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
26 OpBuilder &b, Value value, MemRefType destType,
27 const BufferizationOptions &options) {
28 auto srcType = llvm::cast<MemRefType>(Val: value.getType());
29
30 // Element type and rank must match.
31 if (srcType.getElementType() != destType.getElementType())
32 return failure();
33 if (srcType.getRank() != destType.getRank())
34 return failure();
35
36 // In case the affine maps are different, we may need to use a copy if we go
37 // from dynamic to static offset or stride (the canonicalization cannot know
38 // at this point that it is really cast compatible).
39 auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
40 int64_t sourceOffset, targetOffset;
41 SmallVector<int64_t, 4> sourceStrides, targetStrides;
42 if (failed(Result: source.getStridesAndOffset(strides&: sourceStrides, offset&: sourceOffset)) ||
43 failed(Result: target.getStridesAndOffset(strides&: targetStrides, offset&: targetOffset)))
44 return false;
45 auto dynamicToStatic = [](int64_t a, int64_t b) {
46 return ShapedType::isDynamic(dValue: a) && ShapedType::isStatic(dValue: b);
47 };
48 if (dynamicToStatic(sourceOffset, targetOffset))
49 return false;
50 for (auto it : zip(t&: sourceStrides, u&: targetStrides))
51 if (dynamicToStatic(std::get<0>(t&: it), std::get<1>(t&: it)))
52 return false;
53 return true;
54 };
55
56 // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
57 // ensure that we only generate casts that always succeed at runtime, we check
58 // a fix extra conditions in `isGuaranteedCastCompatible`.
59 if (memref::CastOp::areCastCompatible(inputs: srcType, outputs: destType) &&
60 isGuaranteedCastCompatible(srcType, destType)) {
61 Value casted = b.create<memref::CastOp>(location: value.getLoc(), args&: destType, args&: value);
62 return casted;
63 }
64
65 auto loc = value.getLoc();
66 SmallVector<Value, 4> dynamicOperands;
67 for (int i = 0; i < destType.getRank(); ++i) {
68 if (destType.getShape()[i] != ShapedType::kDynamic)
69 continue;
70 Value size = b.create<memref::DimOp>(location: loc, args&: value, args&: i);
71 dynamicOperands.push_back(Elt: size);
72 }
73
74 FailureOr<Value> copy =
75 options.createAlloc(b, loc, type: destType, dynShape: dynamicOperands);
76 if (failed(Result: copy))
77 return failure();
78 if (failed(Result: options.createMemCpy(b, loc, from: value, to: *copy)))
79 return failure();
80 return copy;
81}
82
83/// Try to fold to_buffer(to_tensor(x)). If x's type and the result type of the
84/// to_buffer op are different, a memref.cast is needed.
85LogicalResult mlir::bufferization::foldToBufferToTensorPair(
86 RewriterBase &rewriter, ToBufferOp toBuffer,
87 const BufferizationOptions &options) {
88 auto bufferToTensor = toBuffer.getTensor().getDefiningOp<ToTensorOp>();
89 if (!bufferToTensor)
90 return failure();
91
92 Type srcType = bufferToTensor.getBuffer().getType();
93 Type destType = toBuffer.getType();
94
95 // Directly rewrite if the type did not change.
96 if (srcType == destType) {
97 rewriter.replaceOp(op: toBuffer, newValues: bufferToTensor.getBuffer());
98 return success();
99 }
100
101 auto rankedSrcType = llvm::dyn_cast<MemRefType>(Val&: srcType);
102 auto rankedDestType = llvm::dyn_cast<MemRefType>(Val&: destType);
103 auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(Val&: srcType);
104
105 // Ranked memref -> Ranked memref cast.
106 if (rankedSrcType && rankedDestType) {
107 FailureOr<Value> replacement = castOrReallocMemRefValue(
108 b&: rewriter, value: bufferToTensor.getBuffer(), destType: rankedDestType, options);
109 if (failed(Result: replacement))
110 return failure();
111
112 rewriter.replaceOp(op: toBuffer, newValues: *replacement);
113 return success();
114 }
115
116 // Unranked memref -> Ranked memref cast: May require a copy.
117 // TODO: Not implemented at the moment.
118 if (unrankedSrcType && rankedDestType)
119 return failure();
120
121 // Unranked memref -> unranked memref cast
122 // Ranked memref -> unranked memref cast: No copy needed.
123 assert(memref::CastOp::areCastCompatible(srcType, destType) &&
124 "expected that types are cast compatible");
125 rewriter.replaceOpWithNewOp<memref::CastOp>(op: toBuffer, args&: destType,
126 args: bufferToTensor.getBuffer());
127 return success();
128}
129
130void mlir::bufferization::populateDynamicDimSizes(
131 OpBuilder &b, Location loc, Value shapedValue,
132 SmallVector<Value> &dynamicDims) {
133 auto shapedType = llvm::cast<ShapedType>(Val: shapedValue.getType());
134 for (int64_t i = 0; i < shapedType.getRank(); ++i) {
135 if (shapedType.isDynamicDim(idx: i)) {
136 if (llvm::isa<MemRefType>(Val: shapedType)) {
137 dynamicDims.push_back(Elt: b.create<memref::DimOp>(location: loc, args&: shapedValue, args&: i));
138 } else {
139 assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor");
140 dynamicDims.push_back(Elt: b.create<tensor::DimOp>(location: loc, args&: shapedValue, args&: i));
141 }
142 }
143 }
144}
145
//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//
149
150LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
151 const BufferizationOptions &options,
152 BufferizationState &state) {
153 OpBuilder::InsertionGuard g(rewriter);
154 Location loc = getLoc();
155
156 // Nothing to do for dead AllocTensorOps.
157 if (getOperation()->getUses().empty()) {
158 rewriter.eraseOp(op: getOperation());
159 return success();
160 }
161
162 // Get "copy" buffer.
163 Value copyBuffer;
164 if (getCopy()) {
165 FailureOr<Value> maybeCopyBuffer =
166 getBuffer(rewriter, value: getCopy(), options, state);
167 if (failed(Result: maybeCopyBuffer))
168 return failure();
169 copyBuffer = *maybeCopyBuffer;
170 }
171
172 // Create memory allocation.
173 auto allocType = bufferization::getBufferType(value: getResult(), options, state);
174 if (failed(Result: allocType))
175 return failure();
176 SmallVector<Value> dynamicDims = getDynamicSizes();
177 if (getCopy()) {
178 assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
179 populateDynamicDimSizes(b&: rewriter, loc, shapedValue: copyBuffer, dynamicDims);
180 }
181 FailureOr<Value> alloc = options.createAlloc(
182 b&: rewriter, loc, type: llvm::cast<MemRefType>(Val&: *allocType), dynShape: dynamicDims);
183 if (failed(Result: alloc))
184 return failure();
185
186 // Create memory copy (if any).
187 if (getCopy()) {
188 if (failed(Result: options.createMemCpy(b&: rewriter, loc, from: copyBuffer, to: *alloc)))
189 return failure();
190 }
191
192 // Replace op.
193 replaceOpWithBufferizedValues(rewriter, op: getOperation(), values: *alloc);
194
195 return success();
196}
197
198bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult,
199 const AnalysisState &state) {
200 // AllocTensorOps do not write unless they have a `copy` value.
201 return static_cast<bool>(getCopy());
202}
203
204bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
205 const AnalysisState &state) {
206 assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
207 "expected copy operand");
208 return true;
209}
210
211bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
212 const AnalysisState &state) {
213 assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
214 "expected copy operand");
215 return false;
216}
217
218AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
219 const AnalysisState &state) {
220 // This is a new allocation. It does not alias with any other buffer.
221 return {};
222}
223
224FailureOr<BufferLikeType>
225AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
226 const BufferizationState &state,
227 SmallVector<Value> &invocationStack) {
228 assert(value == getResult() && "invalid value");
229
230 // Compute memory space of this allocation.
231 Attribute memorySpace;
232 if (getMemorySpace().has_value()) {
233 memorySpace = *getMemorySpace();
234 } else if (getCopy()) {
235 auto copyBufferType =
236 bufferization::detail::asMemRefType(bufferType: bufferization::getBufferType(
237 value: getCopy(), options, state, invocationStack));
238 if (failed(Result: copyBufferType))
239 return failure();
240 memorySpace = copyBufferType->getMemorySpace();
241 } else if (auto ms = options.defaultMemorySpaceFn(getType())) {
242 memorySpace = *ms;
243 } else {
244 return getOperation()->emitError(message: "could not infer memory space");
245 }
246
247 return cast<BufferLikeType>(
248 Val: getMemRefTypeWithStaticIdentityLayout(tensorType: getType(), memorySpace));
249}
250
251LogicalResult AllocTensorOp::verify() {
252 if (getCopy() && !getDynamicSizes().empty())
253 return emitError(message: "dynamic sizes not needed when copying a tensor");
254 if (!getCopy() && getType().getNumDynamicDims() != getDynamicSizes().size())
255 return emitError(message: "expected ")
256 << getType().getNumDynamicDims() << " dynamic sizes";
257 if (getCopy() && getCopy().getType() != getType())
258 return emitError(message: "expected that `copy` and return type match");
259 return success();
260}
261
262void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
263 RankedTensorType type, ValueRange dynamicSizes) {
264 build(odsBuilder&: builder, odsState&: result, result: type, dynamic_sizes: dynamicSizes, /*copy=*/Value(),
265 /*size_hint=*/Value(),
266 /*memory_space=*/IntegerAttr());
267}
268
269void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
270 RankedTensorType type, ValueRange dynamicSizes,
271 Value copy) {
272 build(odsBuilder&: builder, odsState&: result, result: type, dynamic_sizes: dynamicSizes, copy, /*size_hint=*/Value(),
273 /*memory_space=*/IntegerAttr());
274}
275
276void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
277 TensorType type, ValueRange dynamicSizes, Value copy,
278 IntegerAttr memorySpace) {
279 build(odsBuilder&: builder, odsState&: result, result: type, dynamic_sizes: dynamicSizes, copy, /*size_hint=*/Value(),
280 memory_space: memorySpace);
281}
282
283namespace {
284/// Change the type of the result of a `bufferization.alloc_tensor` by making
285/// the result type statically sized along dimension that in the original
286/// operation where defined as dynamic, but the size was defined using a
287/// `constant` op. For example:
288///
289/// %c5 = arith.constant 5: index
290/// %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
291///
292/// to
293///
294/// %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
295struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
296 using OpRewritePattern<AllocTensorOp>::OpRewritePattern;
297
298 LogicalResult matchAndRewrite(AllocTensorOp op,
299 PatternRewriter &rewriter) const override {
300 if (op.getCopy())
301 return failure();
302 SmallVector<int64_t> newShape = llvm::to_vector(Range: op.getType().getShape());
303 SmallVector<Value> newDynamicSizes;
304 unsigned int dynValCounter = 0;
305 for (int64_t i = 0; i < op.getType().getRank(); ++i) {
306 if (!op.isDynamicDim(idx: i))
307 continue;
308 Value value = op.getDynamicSizes()[dynValCounter++];
309 APInt intVal;
310 if (matchPattern(value, pattern: m_ConstantInt(bind_value: &intVal))) {
311 int64_t dim = intVal.getSExtValue();
312 if (dim >= 0)
313 newShape[i] = intVal.getSExtValue();
314 else
315 newDynamicSizes.push_back(Elt: value);
316 } else {
317 newDynamicSizes.push_back(Elt: value);
318 }
319 }
320 RankedTensorType newType = RankedTensorType::get(
321 shape: newShape, elementType: op.getType().getElementType(), encoding: op.getType().getEncoding());
322 if (newType == op.getType())
323 return failure();
324 auto newOp = rewriter.create<AllocTensorOp>(
325 location: op.getLoc(), args&: newType, args&: newDynamicSizes, /*copy=*/args: Value());
326 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, args: op.getType(), args&: newOp);
327 return success();
328 }
329};
330
331struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
332 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
333
334 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
335 PatternRewriter &rewriter) const override {
336 std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
337 auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
338 if (!allocTensorOp || !maybeConstantIndex)
339 return failure();
340 if (*maybeConstantIndex < 0 ||
341 *maybeConstantIndex >= allocTensorOp.getType().getRank())
342 return failure();
343 if (!allocTensorOp.getType().isDynamicDim(idx: *maybeConstantIndex))
344 return failure();
345 rewriter.replaceOp(
346 op: dimOp, newValues: allocTensorOp.getDynamicSize(b&: rewriter, idx: *maybeConstantIndex));
347 return success();
348 }
349};
350} // namespace
351
352void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
353 MLIRContext *ctx) {
354 results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(arg&: ctx);
355}
356
357LogicalResult AllocTensorOp::reifyResultShapes(
358 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
359 auto shapes = llvm::to_vector<4>(
360 Range: llvm::map_range(C: llvm::seq<int64_t>(Begin: 0, End: getType().getRank()),
361 F: [&](int64_t dim) -> OpFoldResult {
362 if (isDynamicDim(idx: dim))
363 return getDynamicSize(b&: builder, idx: dim);
364 return builder.getIndexAttr(value: getStaticSize(idx: dim));
365 }));
366 reifiedReturnShapes.emplace_back(Args: std::move(shapes));
367 return success();
368}
369
370ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
371 SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
372 if (parser.parseLParen() || parser.parseOperandList(result&: dynamicSizesOperands) ||
373 parser.parseRParen())
374 return failure();
375 ParseResult copyKeyword = parser.parseOptionalKeyword(keyword: "copy");
376 OpAsmParser::UnresolvedOperand copyOperand;
377 if (copyKeyword.succeeded())
378 if (parser.parseLParen() || parser.parseOperand(result&: copyOperand) ||
379 parser.parseRParen())
380 return failure();
381 ParseResult sizeHintKeyword = parser.parseOptionalKeyword(keyword: "size_hint");
382 OpAsmParser::UnresolvedOperand sizeHintOperand;
383 if (sizeHintKeyword.succeeded())
384 if (parser.parseEqual() || parser.parseOperand(result&: sizeHintOperand))
385 return failure();
386 if (parser.parseOptionalAttrDict(result&: result.attributes) || parser.parseColon())
387 return failure();
388
389 TensorType type;
390 if (parser.parseCustomTypeWithFallback(result&: type))
391 return failure();
392 result.addTypes(newTypes: type);
393
394 Type indexType = parser.getBuilder().getIndexType();
395 if (parser.resolveOperands(operands&: dynamicSizesOperands, type: indexType, result&: result.operands))
396 return failure();
397 if (copyKeyword.succeeded())
398 if (parser.resolveOperand(operand: copyOperand, type, result&: result.operands))
399 return failure();
400 if (sizeHintKeyword.succeeded())
401 if (parser.resolveOperand(operand: sizeHintOperand, type: indexType, result&: result.operands))
402 return failure();
403 result.addAttribute(name: AllocTensorOp::getOperandSegmentSizeAttr(),
404 attr: parser.getBuilder().getDenseI32ArrayAttr(
405 values: {static_cast<int32_t>(dynamicSizesOperands.size()),
406 static_cast<int32_t>(copyKeyword.succeeded()),
407 static_cast<int32_t>(sizeHintKeyword.succeeded())}));
408 return success();
409}
410
411void AllocTensorOp::print(OpAsmPrinter &p) {
412 p << "(" << getDynamicSizes() << ")";
413 if (getCopy())
414 p << " copy(" << getCopy() << ")";
415 if (getSizeHint())
416 p << " size_hint=" << getSizeHint();
417 p.printOptionalAttrDict(attrs: (*this)->getAttrs(), /*elidedAttrs=*/{
418 AllocTensorOp::getOperandSegmentSizeAttr()});
419 p << " : ";
420 auto type = getResult().getType();
421 if (auto validType = llvm::dyn_cast<::mlir::TensorType>(Val&: type))
422 p.printStrippedAttrOrType(attrOrType: validType);
423 else
424 p << type;
425}
426
427Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
428 assert(isDynamicDim(idx) && "expected dynamic dim");
429 if (getCopy())
430 return b.create<tensor::DimOp>(location: getLoc(), args: getCopy(), args&: idx);
431 return getOperand(i: getIndexOfDynamicSize(idx));
432}
433
//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//
437
438OpFoldResult CloneOp::fold(FoldAdaptor adaptor) {
439 return succeeded(Result: memref::foldMemRefCast(op: *this)) ? getResult() : Value();
440}
441
442namespace {
443
444/// Merge the clone and its source (by converting the clone to a cast) when
445/// possible.
446struct SimplifyClones : public OpRewritePattern<CloneOp> {
447 using OpRewritePattern<CloneOp>::OpRewritePattern;
448
449 LogicalResult matchAndRewrite(CloneOp cloneOp,
450 PatternRewriter &rewriter) const override {
451 if (cloneOp.use_empty()) {
452 rewriter.eraseOp(op: cloneOp);
453 return success();
454 }
455
456 Value source = cloneOp.getInput();
457 if (source.getType() != cloneOp.getType() &&
458 !memref::CastOp::areCastCompatible(inputs: {source.getType()},
459 outputs: {cloneOp.getType()}))
460 return failure();
461
462 // Aims to find the dealloc op for the canonical source
463 // which otherwise could prevent removal of unnecessary allocs.
464 Value canonicalSource = source;
465 while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
466 Val: canonicalSource.getDefiningOp()))
467 canonicalSource = iface.getViewSource();
468
469 std::optional<Operation *> maybeCloneDeallocOp =
470 memref::findDealloc(allocValue: cloneOp.getOutput());
471 // Skip if either of them has > 1 deallocate operations.
472 if (!maybeCloneDeallocOp.has_value())
473 return failure();
474 std::optional<Operation *> maybeSourceDeallocOp =
475 memref::findDealloc(allocValue: canonicalSource);
476 if (!maybeSourceDeallocOp.has_value())
477 return failure();
478 Operation *cloneDeallocOp = *maybeCloneDeallocOp;
479 Operation *sourceDeallocOp = *maybeSourceDeallocOp;
480
481 // If both are deallocated in the same block, their in-block lifetimes
482 // might not fully overlap, so we cannot decide which one to drop.
483 if (cloneDeallocOp && sourceDeallocOp &&
484 cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
485 return failure();
486
487 Block *currentBlock = cloneOp->getBlock();
488 Operation *redundantDealloc = nullptr;
489 if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
490 redundantDealloc = cloneDeallocOp;
491 } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
492 redundantDealloc = sourceDeallocOp;
493 }
494
495 if (!redundantDealloc)
496 return failure();
497
498 // Safety check that there are no other deallocations inbetween
499 // cloneOp and redundantDealloc, as otherwise we might deallocate an alias
500 // of source before the uses of the clone. With alias information, we could
501 // restrict this to only fail of the dealloc's operand is an alias
502 // of the source.
503 for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
504 pos = pos->getNextNode()) {
505 // Bail if we run out of operations while looking for a deallocation op.
506 if (!pos)
507 return failure();
508 auto effectInterface = dyn_cast<MemoryEffectOpInterface>(Val: pos);
509 if (!effectInterface)
510 continue;
511 if (effectInterface.hasEffect<MemoryEffects::Free>())
512 return failure();
513 }
514
515 if (source.getType() != cloneOp.getType())
516 source = rewriter.create<memref::CastOp>(location: cloneOp.getLoc(),
517 args: cloneOp.getType(), args&: source);
518 rewriter.replaceOp(op: cloneOp, newValues: source);
519 rewriter.eraseOp(op: redundantDealloc);
520 return success();
521 }
522};
523
524} // namespace
525
526void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
527 MLIRContext *context) {
528 results.add<SimplifyClones>(arg&: context);
529}
530
//===----------------------------------------------------------------------===//
// DeallocTensorOp
//===----------------------------------------------------------------------===//
534
535LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
536 const BufferizationOptions &options,
537 BufferizationState &state) {
538 FailureOr<Value> buffer = getBuffer(rewriter, value: getTensor(), options, state);
539 if (failed(Result: buffer))
540 return failure();
541 rewriter.create<memref::DeallocOp>(location: getLoc(), args&: *buffer);
542 rewriter.eraseOp(op: getOperation());
543 return success();
544}
545
//===----------------------------------------------------------------------===//
// MaterializeInDestinationOp
//===----------------------------------------------------------------------===//
549
550bool MaterializeInDestinationOp::bufferizesToMemoryRead(
551 OpOperand &opOperand, const AnalysisState &state) {
552 return opOperand == getSourceMutable();
553}
554
555bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
556 OpOperand &opOperand, const AnalysisState &state) {
557 if (opOperand == getDestMutable()) {
558 assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
559 return true;
560 }
561 return false;
562}
563
564bool MaterializeInDestinationOp::mustBufferizeInPlace(
565 OpOperand &opOperand, const AnalysisState &state) {
566 // The source is only read and not written, so it always bufferizes in-place
567 // by default. The destination is written and is forced to bufferize in-place
568 // (if it is a tensor).
569 return true;
570}
571
572AliasingValueList
573MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
574 const AnalysisState &state) {
575 if (opOperand == getDestMutable()) {
576 assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
577 return {{getOperation()->getResult(idx: 0), BufferRelation::Equivalent}};
578 }
579 return {};
580}
581
582LogicalResult
583MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
584 const BufferizationOptions &options,
585 BufferizationState &state) {
586 bool tensorDest = isa<TensorType>(Val: getDest().getType());
587 Value buffer;
588 if (tensorDest) {
589 FailureOr<Value> maybeBuffer =
590 getBuffer(rewriter, value: getDest(), options, state);
591 if (failed(Result: maybeBuffer))
592 return failure();
593 buffer = *maybeBuffer;
594 } else {
595 assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
596 buffer = getDest();
597 }
598 auto srcBuffer = getBuffer(rewriter, value: getSource(), options, state);
599 if (failed(Result: srcBuffer))
600 return failure();
601 if (failed(Result: options.createMemCpy(b&: rewriter, loc: getLoc(), from: *srcBuffer, to: buffer)))
602 return failure();
603 replaceOpWithBufferizedValues(rewriter, op: getOperation(),
604 values: tensorDest ? ValueRange(buffer) : ValueRange());
605 return success();
606}
607
608bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
609 const AnalysisState &state, ArrayRef<OpOperand *> opOperands) {
610 // As elements are copied from the "source" buffer to the "dest" buffer,
611 // already copied elements are not read a second time.
612 return true;
613}
614
615LogicalResult MaterializeInDestinationOp::reifyResultShapes(
616 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
617 if (getOperation()->getNumResults() == 1) {
618 assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
619 reifiedReturnShapes.resize(N: 1,
620 NV: SmallVector<OpFoldResult>(getType().getRank()));
621 reifiedReturnShapes[0] =
622 tensor::getMixedSizes(builder, loc: getLoc(), value: getDest());
623 }
624 return success();
625}
626
627Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
628 Location loc) {
629 if (isa<TensorType>(Val: getDest().getType())) {
630 // The subset is the entire destination tensor.
631 return getDest();
632 }
633
634 // The "restrict" attribute is transferred from this op to the newly created
635 // to_tensor op. If this op does not the "restrict" attribute, the subset
636 // extraction cannot be built because there is no guarantee that there is no
637 // pre-existing "restrict" to_tensor op with the same/an aliasing destination.
638 if (!getRestrict())
639 return {};
640
641 // Build a bufferization.to_tensor op.
642 assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
643 assert(getRestrict() &&
644 "expected that ops with memrefs dest have 'restrict'");
645 setRestrict(false);
646 return builder.create<ToTensorOp>(
647 location: loc, args: memref::getTensorTypeFromMemRefType(type: getDest().getType()), args: getDest(),
648 /*restrict=*/args: true, args: getWritable());
649}
650
651bool MaterializeInDestinationOp::isEquivalentSubset(
652 Value candidate, function_ref<bool(Value, Value)> equivalenceFn) {
653 return equivalenceFn(getDest(), candidate);
654}
655
656SmallVector<Value>
657MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() {
658 return {getDest()};
659}
660
661OpOperand &MaterializeInDestinationOp::getSourceOperand() {
662 return getOperation()->getOpOperand(idx: 0) /*source*/;
663}
664
665bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
666 SubsetOpInterface subsetOp,
667 function_ref<bool(Value, Value)> equivalenceFn) {
668 return false;
669}
670
671bool MaterializeInDestinationOp::operatesOnDisjointSubset(
672 SubsetOpInterface subsetOp,
673 function_ref<bool(Value, Value)> equivalenceFn) {
674 return false;
675}
676
677LogicalResult MaterializeInDestinationOp::verify() {
678 if (!isa<TensorType, BaseMemRefType>(Val: getDest().getType()))
679 return emitOpError(message: "'dest' must be a tensor or a memref");
680 if (auto destType = dyn_cast<TensorType>(Val: getDest().getType())) {
681 if (getOperation()->getNumResults() != 1)
682 return emitOpError(message: "tensor 'dest' implies exactly one tensor result");
683 if (destType != getResult().getType())
684 return emitOpError(message: "result and 'dest' types must match");
685 }
686 if (isa<BaseMemRefType>(Val: getDest().getType()) &&
687 getOperation()->getNumResults() != 0)
688 return emitOpError(message: "memref 'dest' implies zero results");
689 if (getRestrict() && !isa<BaseMemRefType>(Val: getDest().getType()))
690 return emitOpError(message: "'restrict' is valid only for memref destinations");
691 if (getWritable() != isa<BaseMemRefType>(Val: getDest().getType()))
692 return emitOpError(message: "'writable' must be specified if and only if the "
693 "destination is of memref type");
694 TensorType srcType = getSource().getType();
695 ShapedType destType = cast<ShapedType>(Val: getDest().getType());
696 if (srcType.hasRank() != destType.hasRank())
697 return emitOpError(message: "source/destination shapes are incompatible");
698 if (srcType.hasRank()) {
699 if (srcType.getRank() != destType.getRank())
700 return emitOpError(message: "rank mismatch between source and destination shape");
701 for (auto [src, dest] :
702 llvm::zip(t: srcType.getShape(), u: destType.getShape())) {
703 if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
704 // Cannot verify dynamic dimension size. Assume that that they match at
705 // runtime.
706 continue;
707 }
708 if (src != dest)
709 return emitOpError(message: "source/destination shapes are incompatible");
710 }
711 }
712 return success();
713}
714
715void MaterializeInDestinationOp::build(OpBuilder &builder,
716 OperationState &state, Value source,
717 Value dest) {
718 auto destTensorType = dyn_cast<TensorType>(Val: dest.getType());
719 build(odsBuilder&: builder, odsState&: state, /*result=*/destTensorType ? destTensorType : Type(),
720 source, dest);
721}
722
723bool MaterializeInDestinationOp::isWritable(Value value,
724 const AnalysisState &state) {
725 return isa<TensorType>(Val: getDest().getType()) ? true : getWritable();
726}
727
728MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
729 return getDestMutable();
730}
731
732void MaterializeInDestinationOp::getEffects(
733 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
734 &effects) {
735 if (isa<BaseMemRefType>(Val: getDest().getType()))
736 effects.emplace_back(Args: MemoryEffects::Write::get(), Args: &getDestMutable(),
737 Args: SideEffects::DefaultResource::get());
738}
739
//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//
743
744bool ToTensorOp::isWritable(Value value, const AnalysisState &state) {
745 return getWritable();
746}
747
748OpFoldResult ToTensorOp::fold(FoldAdaptor) {
749 if (auto toBuffer = getBuffer().getDefiningOp<ToBufferOp>())
750 // Approximate alias analysis by conservatively folding only when no there
751 // is no interleaved operation.
752 if (toBuffer->getBlock() == this->getOperation()->getBlock() &&
753 toBuffer->getNextNode() == this->getOperation())
754 return toBuffer.getTensor();
755 return {};
756}
757
758namespace {
759struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
760 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
761
762 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
763 PatternRewriter &rewriter) const override {
764 auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
765 if (!memrefToTensorOp)
766 return failure();
767
768 rewriter.replaceOpWithNewOp<memref::DimOp>(
769 op: dimOp, args: memrefToTensorOp.getBuffer(), args: dimOp.getIndex());
770 return success();
771 }
772};
773} // namespace
774
775void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
776 MLIRContext *context) {
777 results.add<DimOfToTensorFolder>(arg&: context);
778}
779
//===----------------------------------------------------------------------===//
// ToBufferOp
//===----------------------------------------------------------------------===//
783
784OpFoldResult ToBufferOp::fold(FoldAdaptor) {
785 if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
786 if (memrefToTensor.getBuffer().getType() == getType())
787 return memrefToTensor.getBuffer();
788 return {};
789}
790
791namespace {
792
793/// Replace tensor.cast + to_buffer by to_buffer + memref.cast.
794struct ToBufferOfCast : public OpRewritePattern<ToBufferOp> {
795 using OpRewritePattern<ToBufferOp>::OpRewritePattern;
796
797 LogicalResult matchAndRewrite(ToBufferOp toBuffer,
798 PatternRewriter &rewriter) const final {
799 auto tensorCastOperand =
800 toBuffer.getOperand().getDefiningOp<tensor::CastOp>();
801 if (!tensorCastOperand)
802 return failure();
803 auto srcTensorType = llvm::dyn_cast<RankedTensorType>(
804 Val: tensorCastOperand.getOperand().getType());
805 if (!srcTensorType)
806 return failure();
807 auto memrefType = MemRefType::get(shape: srcTensorType.getShape(),
808 elementType: srcTensorType.getElementType());
809 Value memref = rewriter.create<ToBufferOp>(location: toBuffer.getLoc(), args&: memrefType,
810 args: tensorCastOperand.getOperand());
811 rewriter.replaceOpWithNewOp<memref::CastOp>(op: toBuffer, args: toBuffer.getType(),
812 args&: memref);
813 return success();
814 }
815};
816
/// Canonicalize bufferization.to_tensor + bufferization.to_buffer. Insert a
/// cast if necessary.
struct ToBufferToTensorFolding : public OpRewritePattern<ToBufferOp> {
  using OpRewritePattern<ToBufferOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToBufferOp toBuffer,
                                PatternRewriter &rewriter) const final {
    // NOTE(review): default-constructed options with bufferAlignment = 0 —
    // presumably "no alignment requirement" for any buffer realloc the fold
    // may insert; confirm against foldToBufferToTensorPair/castOrRealloc.
    BufferizationOptions options;
    options.bufferAlignment = 0;
    return foldToBufferToTensorPair(rewriter, toBuffer, options);
  }
};
829
830/// Fold a load on a to_buffer operation into an tensor.extract on the
831/// corresponding tensor.
832struct LoadOfToBuffer : public OpRewritePattern<memref::LoadOp> {
833 using OpRewritePattern<memref::LoadOp>::OpRewritePattern;
834
835 LogicalResult matchAndRewrite(memref::LoadOp load,
836 PatternRewriter &rewriter) const override {
837 auto toBuffer = load.getMemref().getDefiningOp<ToBufferOp>();
838 if (!toBuffer)
839 return failure();
840
841 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op: load, args: toBuffer.getTensor(),
842 args: load.getIndices());
843 return success();
844 }
845};
846
847/// Fold dim of a to_buffer into the dim of the tensor.
848struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
849 using OpRewritePattern<memref::DimOp>::OpRewritePattern;
850
851 LogicalResult matchAndRewrite(memref::DimOp dimOp,
852 PatternRewriter &rewriter) const override {
853 auto castOp = dimOp.getSource().getDefiningOp<ToBufferOp>();
854 if (!castOp)
855 return failure();
856 Value newSource = castOp.getOperand();
857 rewriter.replaceOpWithNewOp<tensor::DimOp>(op: dimOp, args&: newSource,
858 args: dimOp.getIndex());
859 return success();
860 }
861};
862
863} // namespace
864
/// Register canonicalization patterns for `bufferization.to_buffer`.
void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToBuffer, ToBufferOfCast,
              ToBufferToTensorFolding>(arg&: context);
}
870
/// Bufferize a to_buffer op: attempt to eliminate it by folding against a
/// producing to_tensor op. Always succeeds, regardless of whether the fold
/// applied.
LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationOptions &options,
                                    BufferizationState &state) {
  // Fold to_buffer(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToBufferToTensorPair(rewriter, toBuffer: *this, options);
  // Note: The return value of `bufferize` indicates whether there was an error
  // or not. (And not whether the pattern matched or not.)
  return success();
}
880
881std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder,
882 Value alloc) {
883 return builder.create<memref::DeallocOp>(location: alloc.getLoc(), args&: alloc)
884 .getOperation();
885}
886
887std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
888 return builder.create<CloneOp>(location: alloc.getLoc(), args&: alloc).getResult();
889}
890
891//===----------------------------------------------------------------------===//
892// DeallocOp
893//===----------------------------------------------------------------------===//
894
895LogicalResult DeallocOp::inferReturnTypes(
896 MLIRContext *context, std::optional<::mlir::Location> location,
897 ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
898 RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
899 DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
900 inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(),
901 IntegerType::get(context, width: 1));
902 return success();
903}
904
905LogicalResult DeallocOp::verify() {
906 if (getMemrefs().size() != getConditions().size())
907 return emitOpError(
908 message: "must have the same number of conditions as memrefs to deallocate");
909 if (getRetained().size() != getUpdatedConditions().size())
910 return emitOpError(message: "must have the same number of updated conditions "
911 "(results) as retained operands");
912 return success();
913}
914
915static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
916 ValueRange memrefs,
917 ValueRange conditions,
918 PatternRewriter &rewriter) {
919 if (deallocOp.getMemrefs() == memrefs &&
920 deallocOp.getConditions() == conditions)
921 return failure();
922
923 rewriter.modifyOpInPlace(root: deallocOp, callable: [&]() {
924 deallocOp.getMemrefsMutable().assign(values: memrefs);
925 deallocOp.getConditionsMutable().assign(values: conditions);
926 });
927 return success();
928}
929
930namespace {
931
/// Remove duplicate values in the list of memrefs to be deallocated. We need to
/// make sure the corresponding condition value is updated accordingly since
/// their two conditions might not cover the same set of cases. In that case, we
/// have to combine them (by computing the disjunction of them).
/// Example:
/// ```mlir
/// bufferization.dealloc (%arg0, %arg0 : ...) if (%arg1, %arg2)
/// ```
/// is canonicalized to
/// ```mlir
/// %0 = arith.ori %arg1, %arg2 : i1
/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%0)
/// ```
struct DeallocRemoveDuplicateDeallocMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique memrefs to be deallocated.
    // Maps each memref to the index of its single surviving entry in
    // `newConditions`, so later duplicates can merge their condition into it.
    DenseMap<Value, unsigned> memrefToCondition;
    SmallVector<Value> newMemrefs, newConditions;
    for (auto [i, memref, cond] :
         llvm::enumerate(First: deallocOp.getMemrefs(), Rest: deallocOp.getConditions())) {
      if (memrefToCondition.count(Val: memref)) {
        // If the dealloc conditions don't match, we need to make sure that the
        // dealloc happens on the union of cases (logical OR of conditions).
        Value &newCond = newConditions[memrefToCondition[memref]];
        if (newCond != cond)
          newCond =
              rewriter.create<arith::OrIOp>(location: deallocOp.getLoc(), args&: newCond, args&: cond);
      } else {
        // First occurrence: keep this memref/condition pair.
        memrefToCondition.insert(KV: {memref, newConditions.size()});
        newMemrefs.push_back(Elt: memref);
        newConditions.push_back(Elt: cond);
      }
    }

    // Return failure if we don't change anything such that we don't run into an
    // infinite loop of pattern applications.
    return updateDeallocIfChanged(deallocOp, memrefs: newMemrefs, conditions: newConditions,
                                  rewriter);
  }
};
976
/// Remove duplicate values in the list of retained memrefs. We need to make
/// sure the corresponding result condition value is replaced properly.
/// Example:
/// ```mlir
/// %0:2 = bufferization.dealloc retain (%arg3, %arg3 : ...)
/// ```
/// is canonicalized to
/// ```mlir
/// %0 = bufferization.dealloc retain (%arg3 : memref<2xi32>)
/// ```
struct DeallocRemoveDuplicateRetainedMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique retained values
    // `seen` maps a retained value to the result index of its first
    // occurrence; `resultReplacementIdx[j]` records which new result replaces
    // the j-th old result.
    DenseMap<Value, unsigned> seen;
    SmallVector<Value> newRetained;
    SmallVector<unsigned> resultReplacementIdx;
    unsigned i = 0;
    for (auto retained : deallocOp.getRetained()) {
      if (seen.count(Val: retained)) {
        // Duplicate: reuse the result of the first occurrence.
        resultReplacementIdx.push_back(Elt: seen[retained]);
        continue;
      }

      seen[retained] = i;
      newRetained.push_back(Elt: retained);
      resultReplacementIdx.push_back(Elt: i++);
    }

    // Return failure if we don't change anything such that we don't run into an
    // infinite loop of pattern applications.
    if (newRetained.size() == deallocOp.getRetained().size())
      return failure();

    // We need to create a new op because the number of results is always the
    // same as the number of condition operands.
    auto newDeallocOp =
        rewriter.create<DeallocOp>(location: deallocOp.getLoc(), args: deallocOp.getMemrefs(),
                                   args: deallocOp.getConditions(), args&: newRetained);
    // Map every old result to the updated condition of its deduplicated
    // retained value.
    SmallVector<Value> replacements(
        llvm::map_range(C&: resultReplacementIdx, F: [&](unsigned idx) {
          return newDeallocOp.getUpdatedConditions()[idx];
        }));
    rewriter.replaceOp(op: deallocOp, newValues: replacements);
    return success();
  }
};
1027
1028/// Erase deallocation operations where the variadic list of memrefs to
1029/// deallocate is empty. Example:
1030/// ```mlir
1031/// %0 = bufferization.dealloc retain (%arg0: memref<2xi32>)
1032/// ```
1033struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
1034 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1035
1036 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1037 PatternRewriter &rewriter) const override {
1038 if (deallocOp.getMemrefs().empty()) {
1039 Value constFalse = rewriter.create<arith::ConstantOp>(
1040 location: deallocOp.getLoc(), args: rewriter.getBoolAttr(value: false));
1041 rewriter.replaceOp(
1042 op: deallocOp, newValues: SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
1043 constFalse));
1044 return success();
1045 }
1046 return failure();
1047 }
1048};
1049
1050/// Removes memrefs from the deallocation list if their associated condition is
1051/// always 'false'.
1052///
1053/// Example:
1054/// ```
1055/// bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
1056/// if (%arg2, %false)
1057/// ```
1058/// becomes
1059/// ```
1060/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
1061/// ```
1062struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
1063 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1064
1065 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1066 PatternRewriter &rewriter) const override {
1067 SmallVector<Value> newMemrefs, newConditions;
1068 for (auto [memref, cond] :
1069 llvm::zip(t: deallocOp.getMemrefs(), u: deallocOp.getConditions())) {
1070 if (!matchPattern(value: cond, pattern: m_Zero())) {
1071 newMemrefs.push_back(Elt: memref);
1072 newConditions.push_back(Elt: cond);
1073 }
1074 }
1075
1076 return updateDeallocIfChanged(deallocOp, memrefs: newMemrefs, conditions: newConditions,
1077 rewriter);
1078 }
1079};
1080
/// The `memref.extract_strided_metadata` is often inserted to get the base
/// memref if the operand is not already guaranteed to be the result of a memref
/// allocation operation. This canonicalization pattern removes this extraction
/// operation if the operand is now produced by an allocation operation (e.g.,
/// due to other canonicalizations simplifying the IR).
///
/// Example:
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// %base_memref, %offset, %size, %stride = memref.extract_strided_metadata
///   %alloc : memref<2xi32> -> memref<i32>, index, index, index
/// bufferization.dealloc (%base_memref : memref<i32>) if (%cond)
/// ```
/// is canonicalized to
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// bufferization.dealloc (%alloc : memref<2xi32>) if (%cond)
/// ```
struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // For each memref operand: if it is the base-memref result of an
    // extract_strided_metadata whose source is directly produced by an
    // allocating op, dealloc the allocation itself instead.
    SmallVector<Value> newMemrefs(
        llvm::map_range(C: deallocOp.getMemrefs(), F: [&](Value memref) {
          auto extractStridedOp =
              memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
          if (!extractStridedOp)
            return memref;
          Value allocMemref = extractStridedOp.getOperand();
          auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>();
          if (!allocOp)
            return memref;
          // Only skip the metadata op when the source value really carries an
          // Allocate effect on the defining op.
          if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(value: allocMemref))
            return allocMemref;
          return memref;
        }));

    return updateDeallocIfChanged(deallocOp, memrefs: newMemrefs,
                                  conditions: deallocOp.getConditions(), rewriter);
  }
};
1123
/// Removes pairs of `bufferization.dealloc` and alloc operations if there is no
/// other user of the allocated value and the allocating operation can be safely
/// removed. If the same value is present multiple times, this pattern relies on
/// other canonicalization patterns to remove the duplicate first.
///
/// Example:
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// bufferization.dealloc (%alloc, %arg0 : ...) if (%true, %true)
/// ```
/// is canonicalized to
/// ```mlir
/// bufferization.dealloc (%arg0 : ...) if (%true)
/// ```
struct RemoveAllocDeallocPairWhenNoOtherUsers
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs, newConditions;
    // Allocating ops whose only use was this dealloc; erased after the
    // dealloc op itself has been updated.
    SmallVector<Operation *> toDelete;
    for (auto [memref, cond] :
         llvm::zip(t: deallocOp.getMemrefs(), u: deallocOp.getConditions())) {
      if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
        // Check that it is indeed an allocate effect, that the op has no other
        // side effects (which would not allow us to remove the op), and that
        // there are no other users.
        if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(value: memref) &&
            hasSingleEffect<MemoryEffects::Allocate>(op: allocOp, value: memref) &&
            memref.hasOneUse()) {
          toDelete.push_back(Elt: allocOp);
          continue;
        }
      }

      // Keep this memref/condition pair in the dealloc.
      newMemrefs.push_back(Elt: memref);
      newConditions.push_back(Elt: cond);
    }

    if (failed(Result: updateDeallocIfChanged(deallocOp, memrefs: newMemrefs, conditions: newConditions,
                                       rewriter)))
      return failure();

    for (Operation *op : toDelete)
      rewriter.eraseOp(op);

    return success();
  }
};
1174
1175} // anonymous namespace
1176
/// Register canonicalization patterns for `bufferization.dealloc` by
/// delegating to the publicly exposed population function.
void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  populateDeallocOpCanonicalizationPatterns(patterns&: results, context);
}
1181
/// Populate the given pattern set with all `bufferization.dealloc`
/// canonicalization patterns (exposed so other passes can reuse them).
void bufferization::populateDeallocOpCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
               DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
               EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
               RemoveAllocDeallocPairWhenNoOtherUsers>(arg&: context);
}
1189
1190//===----------------------------------------------------------------------===//
1191// TableGen'd op method definitions
1192//===----------------------------------------------------------------------===//
1193
1194#define GET_OP_CLASSES
1195#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"
1196

source code of mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp