1 | //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" |
10 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
11 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
12 | #include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h" |
13 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
14 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
15 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
16 | #include "mlir/IR/Dialect.h" |
17 | #include "mlir/IR/Operation.h" |
18 | #include "mlir/Interfaces/DestinationStyleOpInterface.h" |
19 | |
20 | using namespace mlir; |
21 | using namespace linalg; |
22 | using namespace mlir::bufferization; |
23 | |
24 | namespace { |
25 | |
26 | /// Generic conversion for any DestinationStyleOpInterface on tensors. |
27 | static LogicalResult |
28 | bufferizeDestinationStyleOpInterface(RewriterBase &rewriter, |
29 | DestinationStyleOpInterface op, |
30 | const BufferizationOptions &options) { |
31 | // Take a guard before anything else. |
32 | OpBuilder::InsertionGuard g(rewriter); |
33 | rewriter.setInsertionPoint(op); |
34 | |
35 | // Nothing to do. This op is already bufferized. |
36 | if (op.hasPureBufferSemantics()) |
37 | return success(); |
38 | |
39 | // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need |
40 | // basis. |
41 | if (!op.hasPureTensorSemantics()) |
42 | return op->emitError() << "op does not have pure tensor semantics" ; |
43 | |
44 | // New input operands for the cloned op. |
45 | SmallVector<Value> newInputBuffers; |
46 | newInputBuffers.reserve(N: op.getNumDpsInputs()); |
47 | for (OpOperand *opOperand : op.getDpsInputOperands()) { |
48 | if (op.isScalar(opOperand)) { |
49 | newInputBuffers.push_back(opOperand->get()); |
50 | continue; |
51 | } |
52 | FailureOr<Value> buffer = getBuffer(rewriter, opOperand->get(), options); |
53 | if (failed(buffer)) |
54 | return failure(); |
55 | newInputBuffers.push_back(*buffer); |
56 | } |
57 | |
58 | // New output operands for the cloned op. |
59 | SmallVector<Value> newOutputBuffers; |
60 | for (OpResult opResult : op->getOpResults()) { |
61 | OpOperand *opOperand = op.getDpsInitOperand(opResult.getResultNumber()); |
62 | FailureOr<Value> resultBuffer = |
63 | getBuffer(rewriter, opOperand->get(), options); |
64 | if (failed(resultBuffer)) |
65 | return failure(); |
66 | newOutputBuffers.push_back(*resultBuffer); |
67 | } |
68 | |
69 | // Merge input/output operands. |
70 | SmallVector<Value> newOperands = newInputBuffers; |
71 | newOperands.append(in_start: newOutputBuffers.begin(), in_end: newOutputBuffers.end()); |
72 | |
73 | // Set insertion point now that potential alloc/dealloc are introduced. |
74 | rewriter.setInsertionPoint(op); |
75 | // Clone the op, but use the new operands. Move the existing block into the |
76 | // new op. Since the new op does not have any tensor results, it does not |
77 | // return anything. |
78 | assert(op->getNumRegions() == 1 && "expected that op has 1 region" ); |
79 | auto newOp = cast<DestinationStyleOpInterface>(cloneWithoutRegions( |
80 | rewriter, op, /*newResultTypes=*/TypeRange{}, newOperands)); |
81 | rewriter.inlineRegionBefore(op->getRegion(0), newOp->getRegion(0), |
82 | newOp->getRegion(0).begin()); |
83 | |
84 | // Replace the results of the old op with the new output buffers. |
85 | replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers); |
86 | |
87 | return success(); |
88 | } |
89 | |
90 | /// Bufferization of linalg.generic. Replace with a new linalg.generic that |
91 | /// operates entirely on memrefs. |
92 | template <typename OpTy> |
93 | struct LinalgOpInterface |
94 | : public DstBufferizableOpInterfaceExternalModel<LinalgOpInterface<OpTy>, |
95 | OpTy> { |
96 | bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, |
97 | const AnalysisState &state) const { |
98 | // Operand is read if it is used in the computation. |
99 | auto linalgOp = cast<linalg::LinalgOp>(op); |
100 | return linalgOp.payloadUsesValueFromOperand(&opOperand); |
101 | } |
102 | |
103 | bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, |
104 | const AnalysisState &state) const { |
105 | // Operand is written to if it is not an input/init. |
106 | auto dpsOp = cast<DestinationStyleOpInterface>(op); |
107 | return dpsOp.isDpsInit(&opOperand); |
108 | } |
109 | |
110 | bool bufferizesToElementwiseAccess(Operation *op, const AnalysisState &state, |
111 | ArrayRef<OpOperand *> opOperands) const { |
112 | auto linalgOp = cast<linalg::LinalgOp>(op); |
113 | |
114 | // Accesses into sparse data structures are not necessarily elementwise. |
115 | if (sparse_tensor::hasAnySparseOperand(op: linalgOp)) |
116 | return false; |
117 | |
118 | // All loops must be parallel. |
119 | if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops()) |
120 | return false; |
121 | |
122 | // All index maps of tensors must be identity maps. |
123 | SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray(); |
124 | assert(linalgOp->getNumOperands() == indexingMaps.size() && |
125 | "unexpected number of indexing maps" ); |
126 | for (auto [operand, map] : |
127 | llvm::zip(linalgOp->getOpOperands(), indexingMaps)) { |
128 | // Non-tensors do not participate in bufferization, so they can be |
129 | // ignored. |
130 | if (!isa<RankedTensorType, MemRefType>(operand.get().getType())) |
131 | continue; |
132 | // Only consider operands in `opOperands`. |
133 | if (!llvm::is_contained(opOperands, &operand)) |
134 | continue; |
135 | // TODO: This could be generalized to other indexing maps. (All indexing |
136 | // must be the same.) |
137 | if (!map.isIdentity()) |
138 | return false; |
139 | } |
140 | |
141 | return true; |
142 | } |
143 | |
144 | LogicalResult bufferize(Operation *op, RewriterBase &rewriter, |
145 | const BufferizationOptions &options) const { |
146 | return bufferizeDestinationStyleOpInterface( |
147 | rewriter, cast<DestinationStyleOpInterface>(op), options); |
148 | } |
149 | }; |
150 | |
151 | /// Helper structure that iterates over all LinalgOps in `OpTys` and registers |
152 | /// the `BufferizableOpInterface` with each of them. |
153 | template <typename... Ops> |
154 | struct LinalgOpInterfaceHelper { |
155 | static void registerOpInterface(MLIRContext *ctx) { |
156 | (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), ...); |
157 | } |
158 | }; |
159 | } // namespace |
160 | |
/// Registers the `BufferizableOpInterface` external models for all Linalg
/// structured ops with `registry`. The extension lambda runs lazily, when the
/// Linalg dialect is loaded into a context.
void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
    // Register all Linalg structured ops. `LinalgOp` is an interface and it is
    // not possible to attach an external interface to an existing interface.
    // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one.
    // The generated GET_OP_LIST expansion provides the comma-separated list of
    // concrete op types used as template arguments.
    LinalgOpInterfaceHelper<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
        >::registerOpInterface(ctx);
  });
}
173 | |