//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"

using namespace mlir;
using namespace linalg;
using namespace mlir::bufferization;

namespace {

/// Generic conversion for any DestinationStyleOpInterface op on tensors.
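///
/// As a schematic example (names and types are illustrative), an op such as
///   %r = linalg.generic ... ins(%in : tensor<?xf32>)
///                           outs(%init : tensor<?xf32>) -> tensor<?xf32>
/// is replaced with an equivalent op on buffers that has no results:
///   linalg.generic ... ins(%in_buf : memref<?xf32>)
///                      outs(%init_buf : memref<?xf32>)
/// and the uses of %r are rewired to the bufferized init operand.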
static LogicalResult bufferizeDestinationStyleOpInterface(
    RewriterBase &rewriter, DestinationStyleOpInterface op,
    const BufferizationOptions &options, const BufferizationState &state) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  // Nothing to do. This op is already bufferized.
  if (op.hasPureBufferSemantics())
    return success();

  // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
  // basis.
  if (!op.hasPureTensorSemantics())
    return op->emitError() << "op does not have pure tensor semantics";

  // New input operands for the cloned op.
  SmallVector<Value> newInputBuffers;
  newInputBuffers.reserve(op.getNumDpsInputs());
  for (OpOperand *opOperand : op.getDpsInputOperands()) {
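    // Scalar operands do not participate in bufferization; forward them
    // unchanged.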
    if (op.isScalar(opOperand)) {
      newInputBuffers.push_back(opOperand->get());
      continue;
    }
    FailureOr<Value> buffer =
        getBuffer(rewriter, opOperand->get(), options, state);
    if (failed(buffer))
      return failure();
    newInputBuffers.push_back(*buffer);
  }

  // New output operands for the cloned op.
  SmallVector<Value> newOutputBuffers;
  for (OpResult opResult : op->getOpResults()) {
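    // In destination style, the i-th result is tied to the i-th init operand;
    // the result buffer is therefore the bufferized init operand.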
    OpOperand *opOperand = op.getDpsInitOperand(opResult.getResultNumber());
    FailureOr<Value> resultBuffer =
        getBuffer(rewriter, opOperand->get(), options, state);
    if (failed(resultBuffer))
      return failure();
    newOutputBuffers.push_back(*resultBuffer);
  }

  // Merge input/output operands.
  SmallVector<Value> newOperands = newInputBuffers;
  newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());

  // Set insertion point now that potential alloc/dealloc are introduced.
  rewriter.setInsertionPoint(op);
  // Clone the op, but use the new operands. Move the existing block into the
  // new op. Since the new op does not have any tensor results, it does not
  // return anything.
  assert(op->getNumRegions() == 1 && "expected that op has 1 region");
  OperationState opState(op->getLoc(), op->getName(), newOperands, TypeRange{},
                         op->getAttrs());
  opState.addRegion();
  Operation *newOp = Operation::create(opState);
  newOp->getRegion(0).getBlocks().splice(newOp->getRegion(0).begin(),
                                         op->getRegion(0).getBlocks());

  // We don't want the rewriter to track an incomplete operation, so insert the
  // new operation only after it has been fully constructed.
  rewriter.insert(newOp);

  // Replace the results of the old op with the new output buffers.
  replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);

  return success();
}

/// Bufferization of linalg structured ops (e.g., linalg.generic). Replace with
/// a new op of the same kind that operates entirely on memrefs.
template <typename OpTy>
struct LinalgOpInterface
    : public DstBufferizableOpInterfaceExternalModel<LinalgOpInterface<OpTy>,
                                                     OpTy> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Operand is read if it is used in the computation.
    auto linalgOp = cast<linalg::LinalgOp>(op);
    return linalgOp.payloadUsesValueFromOperand(&opOperand);
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Operand is written to if it is a DPS init (output) operand.
    auto dpsOp = cast<DestinationStyleOpInterface>(op);
    return dpsOp.isDpsInit(&opOperand);
  }

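  /// Return "true" if all tensor operands in `opOperands` are accessed
  /// elementwise: with all-parallel iterators and identity indexing maps,
  /// every iteration touches element i of each operand exactly once. A
  /// schematic (hypothetical) qualifying example:
  ///   linalg.generic {indexing_maps = [#id, #id],
  ///                   iterator_types = ["parallel"]}
  ///     ins(%in : tensor<?xf32>) outs(%out : tensor<?xf32>)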
  bool bufferizesToElementwiseAccess(Operation *op, const AnalysisState &state,
                                     ArrayRef<OpOperand *> opOperands) const {
    auto linalgOp = cast<linalg::LinalgOp>(op);

    // Accesses into sparse data structures are not necessarily elementwise.
    if (sparse_tensor::hasAnySparseOperand(linalgOp))
      return false;

    // All loops must be parallel.
    if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
      return false;

    // All index maps of tensors must be identity maps.
    SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
    assert(linalgOp->getNumOperands() == indexingMaps.size() &&
           "unexpected number of indexing maps");
    for (auto [operand, map] :
         llvm::zip(linalgOp->getOpOperands(), indexingMaps)) {
      // Non-tensors do not participate in bufferization, so they can be
      // ignored.
      if (!isa<RankedTensorType, MemRefType>(operand.get().getType()))
        continue;
      // Only consider operands in `opOperands`.
      if (!llvm::is_contained(opOperands, &operand))
        continue;
      // TODO: This could be generalized to other indexing maps. (All indexing
      // must be the same.)
      if (!map.isIdentity())
        return false;
    }

    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    return bufferizeDestinationStyleOpInterface(
        rewriter, cast<DestinationStyleOpInterface>(op), options, state);
  }
};

/// Helper structure that iterates over all LinalgOps in `Ops` and registers
/// the `BufferizableOpInterface` with each of them.
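///
/// For instance (ops chosen for illustration), the fold expression in
/// `LinalgOpInterfaceHelper<GenericOp, MatmulOp>::registerOpInterface(ctx)`
/// expands to:
///   GenericOp::attachInterface<LinalgOpInterface<GenericOp>>(*ctx);
///   MatmulOp::attachInterface<LinalgOpInterface<MatmulOp>>(*ctx);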
template <typename... Ops>
struct LinalgOpInterfaceHelper {
  static void registerOpInterface(MLIRContext *ctx) {
    (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), ...);
  }
};

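/// Bufferization of linalg.softmax. Replace with a new linalg.softmax that
/// operates entirely on memrefs.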
struct SoftmaxOpInterface
    : public DstBufferizableOpInterfaceExternalModel<SoftmaxOpInterface,
                                                     linalg::SoftmaxOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Output operand is not read.
    auto softmaxOp = cast<linalg::SoftmaxOp>(op);
    return &opOperand == &softmaxOp.getInputMutable();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto softmaxOp = cast<linalg::SoftmaxOp>(op);
    FailureOr<Value> inputBuffer =
        getBuffer(rewriter, softmaxOp.getInput(), options, state);
    if (failed(inputBuffer))
      return failure();
    FailureOr<Value> outputBuffer =
        getBuffer(rewriter, softmaxOp.getOutput(), options, state);
    if (failed(outputBuffer))
      return failure();
    rewriter.create<linalg::SoftmaxOp>(softmaxOp.getLoc(),
                                       /*result=*/TypeRange(), *inputBuffer,
                                       *outputBuffer, softmaxOp.getDimension());
    replaceOpWithBufferizedValues(rewriter, op, *outputBuffer);
    return success();
  }
};
} // namespace

void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
    // Register all Linalg structured ops. `LinalgOp` is an interface and it is
    // not possible to attach an external interface to an existing interface.
    // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one.
    LinalgOpInterfaceHelper<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"

        >::registerOpInterface(ctx);

    SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx);
  });
}