1//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
10
11#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13#include "mlir/Dialect/Shape/IR/Shape.h"
14#include "mlir/IR/Dialect.h"
15#include "mlir/IR/Operation.h"
16#include "mlir/IR/PatternMatch.h"
17
18using namespace mlir;
19using namespace mlir::bufferization;
20using namespace mlir::shape;
21
22namespace mlir {
23namespace shape {
24namespace {
25
26/// Bufferization of shape.assuming.
27struct AssumingOpInterface
28 : public BufferizableOpInterface::ExternalModel<AssumingOpInterface,
29 shape::AssumingOp> {
30 AliasingOpOperandList
31 getAliasingOpOperands(Operation *op, Value value,
32 const AnalysisState &state) const {
33 // AssumingOps do not have tensor OpOperands. The yielded value can be any
34 // SSA value that is in scope. To allow for use-def chain traversal through
35 // AssumingOps in the analysis, the corresponding yield value is considered
36 // to be aliasing with the result.
37 auto assumingOp = cast<shape::AssumingOp>(op);
38 size_t resultNum = std::distance(first: op->getOpResults().begin(),
39 last: llvm::find(Range: op->getOpResults(), Val: value));
40 // TODO: Support multiple blocks.
41 assert(llvm::hasSingleElement(assumingOp.getDoRegion().getBlocks()) &&
42 "expected exactly 1 block");
43 auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
44 assumingOp.getDoRegion().front().getTerminator());
45 assert(yieldOp && "expected shape.assuming_yield terminator");
46 return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
47 }
48
49 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
50 const BufferizationOptions &options,
51 BufferizationState &state) const {
52 auto assumingOp = cast<shape::AssumingOp>(op);
53 assert(llvm::hasSingleElement(assumingOp.getDoRegion().getBlocks()) &&
54 "only 1 block supported");
55 auto yieldOp = cast<shape::AssumingYieldOp>(
56 assumingOp.getDoRegion().front().getTerminator());
57
58 // Create new op and move over region.
59 TypeRange newResultTypes(yieldOp.getOperands());
60 auto newOp = rewriter.create<shape::AssumingOp>(
61 op->getLoc(), newResultTypes, assumingOp.getWitness());
62 newOp.getDoRegion().takeBody(assumingOp.getRegion());
63
64 // Update all uses of the old op.
65 rewriter.setInsertionPointAfter(newOp);
66 SmallVector<Value> newResults;
67 for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) {
68 if (isa<TensorType>(it.value())) {
69 newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
70 assumingOp.getLoc(), newOp->getResult(it.index())));
71 } else {
72 newResults.push_back(newOp->getResult(it.index()));
73 }
74 }
75
76 // Replace old op.
77 rewriter.replaceOp(assumingOp, newResults);
78
79 return success();
80 }
81};
82
83/// Bufferization of shape.assuming_yield. Bufferized as part of their enclosing
84/// ops, so this is for analysis only.
85struct AssumingYieldOpInterface
86 : public BufferizableOpInterface::ExternalModel<AssumingYieldOpInterface,
87 shape::AssumingYieldOp> {
88 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
89 const AnalysisState &state) const {
90 return true;
91 }
92
93 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
94 const AnalysisState &state) const {
95 return false;
96 }
97
98 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
99 const AnalysisState &state) const {
100 assert(isa<shape::AssumingOp>(op->getParentOp()) &&
101 "expected that parent is an AssumingOp");
102 OpResult opResult =
103 op->getParentOp()->getResult(idx: opOperand.getOperandNumber());
104 return {{opResult, BufferRelation::Equivalent}};
105 }
106
107 bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
108 const AnalysisState &state) const {
109 // Yield operands always bufferize inplace. Otherwise, an alloc + copy
110 // may be generated inside the block. We should not return/yield allocations
111 // when possible.
112 return true;
113 }
114
115 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
116 const BufferizationOptions &options,
117 BufferizationState &state) const {
118 auto yieldOp = cast<shape::AssumingYieldOp>(op);
119 SmallVector<Value> newResults;
120 for (Value value : yieldOp.getOperands()) {
121 if (isa<TensorType>(value.getType())) {
122 FailureOr<Value> buffer = getBuffer(rewriter, value, options, state);
123 if (failed(buffer))
124 return failure();
125 newResults.push_back(*buffer);
126 } else {
127 newResults.push_back(value);
128 }
129 }
130 replaceOpWithNewBufferizedOp<shape::AssumingYieldOp>(rewriter, op,
131 newResults);
132 return success();
133 }
134};
135
136} // namespace
137} // namespace shape
138} // namespace mlir
139
140void mlir::shape::registerBufferizableOpInterfaceExternalModels(
141 DialectRegistry &registry) {
142 registry.addExtension(extensionFn: +[](MLIRContext *ctx, shape::ShapeDialect *dialect) {
143 shape::AssumingOp::attachInterface<AssumingOpInterface>(*ctx);
144 shape::AssumingYieldOp::attachInterface<AssumingYieldOpInterface>(*ctx);
145 });
146}
147

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp