1//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
10#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
11#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
12#include "mlir/Dialect/Linalg/IR/Linalg.h"
13#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
14#include "mlir/IR/Dialect.h"
15#include "mlir/IR/Operation.h"
16#include "mlir/Interfaces/DestinationStyleOpInterface.h"
17
18using namespace mlir;
19using namespace linalg;
20using namespace mlir::bufferization;
21
22namespace {
23
24/// Generic conversion for any DestinationStyleOpInterface on tensors.
25static LogicalResult bufferizeDestinationStyleOpInterface(
26 RewriterBase &rewriter, DestinationStyleOpInterface op,
27 const BufferizationOptions &options, const BufferizationState &state) {
28 // Take a guard before anything else.
29 OpBuilder::InsertionGuard g(rewriter);
30 rewriter.setInsertionPoint(op);
31
32 // Nothing to do. This op is already bufferized.
33 if (op.hasPureBufferSemantics())
34 return success();
35
36 // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
37 // basis.
38 if (!op.hasPureTensorSemantics())
39 return op->emitError() << "op does not have pure tensor semantics";
40
41 // New input operands for the cloned op.
42 SmallVector<Value> newInputBuffers;
43 newInputBuffers.reserve(N: op.getNumDpsInputs());
44 for (OpOperand *opOperand : op.getDpsInputOperands()) {
45 if (op.isScalar(opOperand)) {
46 newInputBuffers.push_back(Elt: opOperand->get());
47 continue;
48 }
49 FailureOr<Value> buffer =
50 getBuffer(rewriter, value: opOperand->get(), options, state);
51 if (failed(Result: buffer))
52 return failure();
53 newInputBuffers.push_back(Elt: *buffer);
54 }
55
56 // New output operands for the cloned op.
57 SmallVector<Value> newOutputBuffers;
58 for (OpResult opResult : op->getOpResults()) {
59 OpOperand *opOperand = op.getDpsInitOperand(i: opResult.getResultNumber());
60 FailureOr<Value> resultBuffer =
61 getBuffer(rewriter, value: opOperand->get(), options, state);
62 if (failed(Result: resultBuffer))
63 return failure();
64 newOutputBuffers.push_back(Elt: *resultBuffer);
65 }
66
67 // Merge input/output operands.
68 SmallVector<Value> newOperands = newInputBuffers;
69 newOperands.append(in_start: newOutputBuffers.begin(), in_end: newOutputBuffers.end());
70
71 // Set insertion point now that potential alloc/dealloc are introduced.
72 rewriter.setInsertionPoint(op);
73 // Clone the op, but use the new operands. Move the existing block into the
74 // new op. Since the new op does not have any tensor results, it does not
75 // return anything.
76 assert(op->getNumRegions() == 1 && "expected that op has 1 region");
77 OperationState opState(op->getLoc(), op->getName(), newOperands, TypeRange{},
78 op->getAttrs());
79 opState.addRegion();
80 Operation *newOp = Operation::create(state: opState);
81 newOp->getRegion(index: 0).getBlocks().splice(where: newOp->getRegion(index: 0).begin(),
82 L2&: op->getRegion(index: 0).getBlocks());
83
84 // We don't want the rewriter tracks an incomplete operation, so insert new
85 // operation after op was fully constructed.
86 rewriter.insert(op: newOp);
87
88 // Replace the results of the old op with the new output buffers.
89 replaceOpWithBufferizedValues(rewriter, op, values: newOutputBuffers);
90
91 return success();
92}
93
94/// Bufferization of linalg.generic. Replace with a new linalg.generic that
95/// operates entirely on memrefs.
96template <typename OpTy>
97struct LinalgOpInterface
98 : public DstBufferizableOpInterfaceExternalModel<LinalgOpInterface<OpTy>,
99 OpTy> {
100 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
101 const AnalysisState &state) const {
102 // Operand is read if it is used in the computation.
103 auto linalgOp = cast<linalg::LinalgOp>(Val: op);
104 return linalgOp.payloadUsesValueFromOperand(opOperand: &opOperand);
105 }
106
107 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
108 const AnalysisState &state) const {
109 // Operand is written to if it is not an input/init.
110 auto dpsOp = cast<DestinationStyleOpInterface>(Val: op);
111 return dpsOp.isDpsInit(opOperand: &opOperand);
112 }
113
114 bool bufferizesToElementwiseAccess(Operation *op, const AnalysisState &state,
115 ArrayRef<OpOperand *> opOperands) const {
116 auto linalgOp = cast<linalg::LinalgOp>(Val: op);
117
118 // Accesses into sparse data structures are not necessarily elementwise.
119 if (sparse_tensor::hasAnySparseOperand(op: linalgOp))
120 return false;
121
122 // All loops must be parallel.
123 if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
124 return false;
125
126 // All index maps of tensors must be identity maps.
127 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
128 assert(linalgOp->getNumOperands() == indexingMaps.size() &&
129 "unexpected number of indexing maps");
130 for (auto [operand, map] :
131 llvm::zip(t: linalgOp->getOpOperands(), u&: indexingMaps)) {
132 // Non-tensors do not participate in bufferization, so they can be
133 // ignored.
134 if (!isa<RankedTensorType, MemRefType>(Val: operand.get().getType()))
135 continue;
136 // Only consider operands in `opOperands`.
137 if (!llvm::is_contained(Range&: opOperands, Element: &operand))
138 continue;
139 // TODO: This could be generalized to other indexing maps. (All indexing
140 // must be the same.)
141 if (!map.isIdentity())
142 return false;
143 }
144
145 return true;
146 }
147
148 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
149 const BufferizationOptions &options,
150 BufferizationState &state) const {
151 return bufferizeDestinationStyleOpInterface(
152 rewriter, op: cast<DestinationStyleOpInterface>(Val: op), options, state);
153 }
154};
155
156/// Helper structure that iterates over all LinalgOps in `OpTys` and registers
157/// the `BufferizableOpInterface` with each of them.
158template <typename... Ops>
159struct LinalgOpInterfaceHelper {
160 static void registerOpInterface(MLIRContext *ctx) {
161 (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), ...);
162 }
163};
164
165struct SoftmaxOpInterface
166 : public DstBufferizableOpInterfaceExternalModel<SoftmaxOpInterface,
167 linalg::SoftmaxOp> {
168 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
169 const AnalysisState &state) const {
170 // Output operand is not read.
171 auto softmaxOp = cast<linalg::SoftmaxOp>(Val: op);
172 return &opOperand == &softmaxOp.getInputMutable();
173 }
174
175 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
176 const BufferizationOptions &options,
177 BufferizationState &state) const {
178 auto softmaxOp = cast<linalg::SoftmaxOp>(Val: op);
179 FailureOr<Value> inputBuffer =
180 getBuffer(rewriter, value: softmaxOp.getInput(), options, state);
181 if (failed(Result: inputBuffer))
182 return failure();
183 FailureOr<Value> outputBuffer =
184 getBuffer(rewriter, value: softmaxOp.getOutput(), options, state);
185 if (failed(Result: outputBuffer))
186 return failure();
187 rewriter.create<linalg::SoftmaxOp>(location: softmaxOp.getLoc(),
188 /*result=*/args: TypeRange(), args&: *inputBuffer,
189 args&: *outputBuffer, args: softmaxOp.getDimension());
190 replaceOpWithBufferizedValues(rewriter, op, values: *outputBuffer);
191 return success();
192 }
193};
194} // namespace
195
196void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
197 DialectRegistry &registry) {
198 registry.addExtension(extensionFn: +[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
199 // Register all Linalg structured ops. `LinalgOp` is an interface and it is
200 // not possible to attach an external interface to an existing interface.
201 // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one.
202 LinalgOpInterfaceHelper<
203#define GET_OP_LIST
204#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
205
206 >::registerOpInterface(ctx);
207
208 SoftmaxOp::attachInterface<SoftmaxOpInterface>(context&: *ctx);
209 });
210}
211

source code of mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp