1//===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the linalg dialect Fusion pass.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Linalg/IR/Linalg.h"
14#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15#include "mlir/Dialect/Linalg/Utils/Utils.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17#include "mlir/Dialect/Tensor/IR/Tensor.h"
18#include "mlir/Dialect/Tensor/Utils/Utils.h"
19#include "mlir/IR/AffineExpr.h"
20#include "mlir/IR/AffineMap.h"
21#include "mlir/IR/Dominance.h"
22#include "mlir/Support/LLVM.h"
23#include "llvm/ADT/SmallBitVector.h"
24#include "llvm/Support/Debug.h"
25
26#define DEBUG_TYPE "linalg-fusion"
27
28using namespace mlir;
29using namespace mlir::linalg;
30
31/// Implements a simple high-level fusion pass on linalg structured operations.
32///
33/// In each block, linalg ops are processed in reverse textual order.
34/// Given a linalg op `O`, fusion occurs by:
35/// 1. inspecting the linalg ops that write into the views read by `O`. There
36/// are 2 cases:
37/// a) buffer case: use the SSA value of the views and a simple alias
38/// analysis on subview ops to determine producer-consumer dependences;
39/// b) tensor case: use SSA use-def chains on extract_slice ops;
40/// 2. greedily fuse the linalg ops that produce the subview/extract_slice.
41/// 3. inspect the fused ops and determine whether they have other remaining
42/// LinalgOp uses. If not, then erase the original producing linalg op.
43///
44/// More advanced use cases, analyses as well as profitability heuristics are
45/// left for future work.
46
/// A (shape, dimension) pair: a shaped operand value of a LinalgOp together
/// with the dimension of that operand whose size defines a loop range (see
/// getShapeDefiningLoopRange and its use in fuse()).
struct ShapeDimension {
  // Shaped operand (memref or tensor) whose size is queried.
  Value shape;
  // Dimension of `shape` that carries the loop extent.
  unsigned dimension;
};
51
52// Given an `op`, returns the first (`shape`, `dimension`) pair that identifies
53// the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
54// guarantees at least one such dimension is found. If multiple candidates exist
55// they must agree by construction (i.e. have the same size) and we just return
56// the first one.
57static ShapeDimension
58getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
59 bool fromSubViewOpOnly = false) {
60 // Iterate over the inputs and outputs in order.
61 // Extract the subranges from the linearized ranges.
62 for (OpOperand &opOperand : op->getOpOperands()) {
63 // The method `getRangeFromOperandShape` requires using SubViewOp or
64 // ExtractSliceOps. If the value isn't defined from there continue.
65 // todo: The method should be adapted to get the values from
66 // `ViewInterface`. The interface needs a `getOrCreateRanges` method which
67 // currently returns a `linalg.range`. The fix here is to move this op to
68 // `std` dialect and add the method to `ViewInterface`.
69 if (fromSubViewOpOnly &&
70 !isa_and_nonnull<memref::SubViewOp, tensor::ExtractSliceOp>(
71 Val: opOperand.get().getDefiningOp()))
72 continue;
73
74 AffineMap map = op.getMatchingIndexingMap(opOperand: &opOperand);
75 LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange I/O idx: "
76 << opOperand.getOperandNumber() << "\n");
77 LLVM_DEBUG(llvm::dbgs()
78 << "getShapeDefiningLoopRange map: " << map << "\n");
79 for (const auto &en : llvm::enumerate(First: map.getResults())) {
80 auto dimExpr = dyn_cast<AffineDimExpr>(Val: en.value());
81 if (!dimExpr)
82 continue;
83 if (loopDepth == cast<AffineDimExpr>(Val: en.value()).getPosition()) {
84 LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
85 << loopDepth << "\n");
86 LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange shape: "
87 << opOperand.get() << "\n");
88 return ShapeDimension{.shape: opOperand.get(),
89 .dimension: static_cast<unsigned>(en.index())};
90 }
91 }
92 }
93 llvm_unreachable("Expect to be able to extract a shape defining loop range");
94}
95
96static SmallVector<Value> getTiledOperands(LinalgOp producer) {
97 return producer->getOperands();
98}
99
100/// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges`
101/// provides the loop range information for the fused loops. The rest are
102/// obtained from the producer itself, since they are not tiled + fused.
103static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
104 const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
105 SmallVector<OpFoldResult> ivs, tileSizes, sizeBounds;
106 SmallVector<Range> loopRanges;
107 Location loc = producer.getLoc();
108
109 for (unsigned i = 0, e = producer.getNumLoops(); i < e; ++i) {
110 auto shapeDim = getShapeDefiningLoopRange(op: producer, loopDepth: i);
111 OpFoldResult dim =
112 createFoldedDimOp(b, loc, val: shapeDim.shape, dim: shapeDim.dimension);
113 sizeBounds.push_back(Elt: dim);
114 auto it = fusedLoopsAndRanges.find(Val: i);
115 if (it != fusedLoopsAndRanges.end()) {
116 ivs.push_back(Elt: it->second.offset);
117 tileSizes.push_back(Elt: it->second.size);
118 loopRanges.push_back(Elt: it->second);
119 LLVM_DEBUG(llvm::dbgs() << "tiled loop#" << i << " with LoopRange "
120 << loopRanges.back() << "\n");
121 } else {
122 tileSizes.push_back(Elt: b.getIndexAttr(value: 0));
123 loopRanges.push_back(Elt: Range{.offset: b.getIndexAttr(value: 0), .size: dim, .stride: b.getIndexAttr(value: 1)});
124 LLVM_DEBUG(llvm::dbgs() << "full loop#" << i << " with LoopRange "
125 << loopRanges.back() << "\n");
126 }
127 }
128
129 SmallVector<Value, 8> clonedShapes;
130 clonedShapes.reserve(N: producer->getNumOperands());
131
132 // Compute subranges for all tensor input/output operands.
133 clonedShapes.append(RHS: makeTiledShapes(
134 builder&: b, loc, linalgOp: producer, valuesToTile: getTiledOperands(producer), ivs, tileSizes, sizeBounds,
135 /**omitPartialTileCheck=*/omitPartialTileCheck: false));
136
137 // Take result types from the tiled init operands.
138 MutableOperandRange producerDpsInits = producer.getDpsInitsMutable();
139 SmallVector<Type, 4> resultTypes;
140 resultTypes.reserve(N: producer->getNumResults());
141 int64_t firstInitOperandIdx =
142 producerDpsInits.getAsOperandRange().getBeginOperandIndex();
143 for (int64_t i = 0, e = producer->getNumResults(); i < e; ++i) {
144 resultTypes.push_back(Elt: clonedShapes[firstInitOperandIdx + i].getType());
145 }
146
147 // Clone the producer with new operands and result types.
148 LinalgOp clonedOp = clone(b, op: producer, newResultTypes: resultTypes, newOperands: clonedShapes);
149
150 // Shift all IndexOp results by the tile offset.
151 SmallVector<OpFoldResult> allIvs = llvm::to_vector(
152 Range: llvm::map_range(C&: loopRanges, F: [&](Range range) { return range.offset; }));
153 offsetIndices(b, linalgOp: clonedOp, offests: allIvs);
154
155 return clonedOp;
156}
157
158/// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is
159/// expected to be defined by a subview op or an extract_slice op.
160static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
161 Value shapedOperand, unsigned dim) {
162 Operation *shapeProducingOp = shapedOperand.getDefiningOp();
163 if (auto subViewOp = dyn_cast<memref::SubViewOp>(Val: shapeProducingOp))
164 return subViewOp.getOrCreateRanges(b, loc)[dim];
165 if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(Val: shapeProducingOp))
166 return sliceOp.getOrCreateRanges(b, loc)[dim];
167 llvm_unreachable("SubviewOp or ExtractSliceOp expected");
168}
169
170/// Fuses the producer into the loop immediately enclosing the consumer.
171/// This is achieved by "recomputing" the producer at the time it
172/// is needed just before the consumer.
173static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap,
174 OpOperand &consumerOpOperand) {
175 LLVM_DEBUG(llvm::dbgs() << "Producer map: " << producerMap << "\n");
176 DenseMap<unsigned, Range> fusedLoopsAndRanges;
177 Value shapedOperand = consumerOpOperand.get();
178 for (const auto &en : llvm::enumerate(First: producerMap.getResults())) {
179 unsigned posInProducerLoop = cast<AffineDimExpr>(Val: en.value()).getPosition();
180 fusedLoopsAndRanges[posInProducerLoop] = getRangeFromOperandShape(
181 b, loc: consumerOpOperand.getOwner()->getLoc(), shapedOperand, dim: en.index());
182 }
183 return fuse(b, producer: producerOp, fusedLoopsAndRanges);
184}
185
/// Walk back the use-def chain of `tensor` through tensor.extract_slice ops
/// and scf::For init args. Sets `opResult` to the producing result if a
/// LinalgOp producer is found.
189// TODO(ravishankarm, ntv): This can be moved into the dependence graphs
190// dependence tracking since the dependence tracking is similar to what is done
191// w.r.t to buffers.
192static void getProducerOfTensor(Value tensor, OpResult &opResult) {
193 if (!isa<RankedTensorType>(Val: tensor.getType()))
194 return;
195
196 while (true) {
197 LLVM_DEBUG(llvm::dbgs() << "\ngetProducerOfTensor: " << tensor);
198 if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
199 opResult = cast<OpResult>(Val&: tensor);
200 return;
201 }
202 if (auto sliceOp = tensor.getDefiningOp<tensor::ExtractSliceOp>()) {
203 tensor = sliceOp.getSource();
204 continue;
205 }
206 if (auto blockArg = dyn_cast<BlockArgument>(Val&: tensor)) {
207 if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
208 tensor = forOp.getInitArgs()[blockArg.getArgNumber()];
209 continue;
210 }
211 }
212 return;
213 }
214}
215
216FailureOr<FusionInfo>
217mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
218 Value inputTensor = consumerOpOperand.get();
219 OpResult producerOpResult;
220 getProducerOfTensor(tensor: inputTensor, opResult&: producerOpResult);
221 if (!producerOpResult) {
222 LLVM_DEBUG(llvm::dbgs() << "\nUnable to find producer");
223 return failure();
224 }
225 return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
226}
227
228FailureOr<FusionInfo>
229mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
230 OpOperand &consumerOpOperand) {
231 auto producerOp = dyn_cast<LinalgOp>(Val: producerOpResult.getOwner());
232 if (!producerOp)
233 return failure();
234
235 LinalgOp consumerOp = dyn_cast<LinalgOp>(Val: consumerOpOperand.getOwner());
236 if (!consumerOp)
237 return failure();
238
239 Value inputTensor = consumerOpOperand.get();
240
241 // Must be an extract_slice op to guarantee there are loops we can fuse into.
242 auto sliceOp = inputTensor.getDefiningOp<tensor::ExtractSliceOp>();
243 if (!sliceOp) {
244 LLVM_DEBUG(llvm::dbgs()
245 << "\nNot fusable, not an extract_slice op: " << inputTensor);
246 return failure();
247 }
248
249 // If producer is already in the same block as consumer, we are done.
250 if (consumerOpOperand.get().getParentBlock() ==
251 producerOpResult.getParentBlock())
252 return failure();
253
254 // Insert fused `producer` just before `consumer`.
255 OpBuilder::InsertionGuard g(b);
256 b.setInsertionPoint(consumerOp);
257 LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n");
258 OpOperand *opOperand =
259 producerOp.getDpsInitOperand(i: producerOpResult.getResultNumber());
260 LinalgOp fusedProducer =
261 fuse(b, producerOp, producerMap: producerOp.getMatchingIndexingMap(opOperand),
262 consumerOpOperand);
263
264 // Replace use.
265 Value def = fusedProducer->getResult(idx: producerOpResult.getResultNumber());
266 Type consumerType = consumerOpOperand.get().getType();
267 // Check if rank-reduction occurred as part of the extract_slice. If yes,
268 // collapse the dropped dimensions.
269 if (cast<ShapedType>(Val&: consumerType).getRank() !=
270 cast<ShapedType>(Val: def.getType()).getRank()) {
271 llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
272 def =
273 tensor::dropGivenUnitDims(b, loc: fusedProducer.getLoc(), src: def, dropDims: droppedDims);
274 }
275 // Canonicalizations are not guaranteed to have happened before constructing
276 // `fusedProducer`. In the tensor case this can result in temporary type
277 // mismatches. Insert a `tensor.cast` op to propagate the transformation
278 // invariant that types are compatible.
279 if (consumerType != def.getType())
280 def = b.create<tensor::CastOp>(location: fusedProducer.getLoc(), args&: consumerType, args&: def);
281 consumerOpOperand.set(def);
282 return FusionInfo{.originalProducer: cast<LinalgOp>(Val: producerOpResult.getOwner()), .fusedProducer: fusedProducer};
283}
284

source code of mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp