1//===- EraseUnusedOperandsAndResults.cpp ----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
12
13using namespace mlir;
14using namespace mlir::linalg;
15
16/// Return `true` if the `result` of an operation `genericOp` is dead.
17static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) {
18 if (!result.use_empty())
19 return false;
20 // If out operand not used in payload, we can drop it.
21 OpOperand *outputOpOperand =
22 genericOp.getDpsInitOperand(result.getResultNumber());
23 if (!genericOp.payloadUsesValueFromOperand(outputOpOperand))
24 return true;
25
26 // The out operand that is part of a payload can be dropped if
27 // these conditions are met:
28 // - Result from out operand is dead.
29 // - User of arg is yield.
30 // - outArg data is not being used by other outArgs.
31
32 // Check block arg and cycle from out operand has a single use.
33 BlockArgument outputArg =
34 genericOp.getRegionOutputArgs()[result.getResultNumber()];
35 if (!outputArg.hasOneUse())
36 return false;
37 Operation *argUserOp = *outputArg.user_begin();
38
39 // Check argUser has no other use.
40 if (!argUserOp->use_empty())
41 return false;
42
43 // Check that argUser is a yield.
44 auto yieldOp = dyn_cast<linalg::YieldOp>(argUserOp);
45 if (!yieldOp)
46 return false;
47
48 // Check outArg data is not being used by other outArgs.
49 if (yieldOp.getOperand(result.getResultNumber()) != outputArg)
50 return false;
51
52 return true;
53}
54
55namespace {
56
57struct DeduplicateAndRemoveDeadOperandsAndResults
58 : public OpRewritePattern<GenericOp> {
59 DeduplicateAndRemoveDeadOperandsAndResults(MLIRContext *ctx,
60 bool removeOutputs)
61 : OpRewritePattern<GenericOp>(ctx), removeOutputs(removeOutputs) {}
62
63 LogicalResult matchAndRewrite(GenericOp genericOp,
64 PatternRewriter &rewriter) const override {
65 // Create a map from argument position in the original op to the argument
66 // position in the new op. If the argument is dropped it wont have an entry.
67 SmallVector<OpOperand *> droppedOpOperands;
68
69 // Information needed to build the new op.
70 SmallVector<Value> newInputOperands, newOutputOperands;
71 SmallVector<AffineMap> newIndexingMaps;
72
73 // Gather information about duplicate input operands.
74 llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =
75 deduplicateInputOperands(genericOp, droppedOpOperands, newInputOperands,
76 newIndexingMaps);
77
78 // Gather information about the dropped outputs.
79 llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =
80 deduplicateOutputOperands(genericOp, droppedOpOperands,
81 newOutputOperands, newIndexingMaps);
82
83 // Check if there is any change to operands.
84 if (newInputOperands.size() + newOutputOperands.size() ==
85 genericOp->getNumOperands())
86 return failure();
87
88 // Create the new op with the body being empty.
89 Location loc = genericOp.getLoc();
90 SmallVector<Type> newResultTypes;
91 for (Value v : newOutputOperands)
92 if (isa<TensorType>(Val: v.getType()))
93 newResultTypes.push_back(Elt: v.getType());
94 auto newOp = rewriter.create<GenericOp>(
95 loc, newResultTypes, newInputOperands, newOutputOperands,
96 rewriter.getAffineMapArrayAttr(newIndexingMaps),
97 genericOp.getIteratorTypes(), genericOp.getDocAttr(),
98 genericOp.getLibraryCallAttr(),
99 [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
100 return;
101 });
102 // Copy over unknown attributes. They might be load bearing for some flow.
103 ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
104 for (NamedAttribute kv : genericOp->getAttrs())
105 if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
106 newOp->setAttr(kv.getName(), kv.getValue());
107
108 // Fix up the payload of the canonicalized operation.
109 populateOpPayload(genericOp, newOp, origInsToNewInsPos,
110 origOutsToNewOutsPos, rewriter);
111
112 // Replace all live uses of the op.
113 SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
114 for (const auto &result : llvm::enumerate(genericOp.getResults())) {
115 auto it = origOutsToNewOutsPos.find(result.index());
116 if (it == origOutsToNewOutsPos.end())
117 continue;
118 replacementsVals[result.index()] = newOp.getResult(it->second);
119 }
120 rewriter.replaceOp(genericOp, replacementsVals);
121 return success();
122 }
123
124private:
125 /// If unset, outputs are not modified by this pattern.
126 bool removeOutputs;
127
128 // Deduplicate input operands, and return the
129 // - Mapping from operand position in the original op, to operand position in
130 // the canonicalized op.
131 // - The preserved input operands list (by reference).
132 llvm::SmallDenseMap<unsigned, unsigned>
133 deduplicateInputOperands(GenericOp genericOp,
134 SmallVector<OpOperand *> &droppedOpOperands,
135 SmallVector<Value> &newInputOperands,
136 SmallVector<AffineMap> &newIndexingMaps) const {
137 llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
138 llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
139 for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
140 OpOperand *inputOpOperand = en.value();
141 // Check if operand is dead and if dropping the indexing map makes the
142 // loops to shape computation invalid.
143 if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
144 // Add the current operands to the list of potentially droppable
145 // operands. If it cannot be dropped, this needs to be popped back.
146 droppedOpOperands.push_back(inputOpOperand);
147 if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
148 continue;
149 droppedOpOperands.pop_back();
150 }
151
152 // Check if this operand is a duplicate.
153 AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand);
154 auto it = dedupedInputs.find(
155 std::make_pair(inputOpOperand->get(), indexingMap));
156 if (it != dedupedInputs.end()) {
157 origToNewPos[en.index()] = it->second;
158 droppedOpOperands.push_back(inputOpOperand);
159 continue;
160 }
161
162 // This is a preserved argument.
163 origToNewPos[en.index()] = newInputOperands.size();
164 dedupedInputs[{inputOpOperand->get(), indexingMap}] =
165 newInputOperands.size();
166 newInputOperands.push_back(inputOpOperand->get());
167 newIndexingMaps.push_back(indexingMap);
168 }
169 return origToNewPos;
170 }
171
172 // Deduplicate output operands, and return the
173 // - Mapping from operand position in the original op, to operand position in
174 // the canonicalized op.
175 // - The preserved output operands list (by reference).
176 llvm::SmallDenseMap<unsigned, unsigned>
177 deduplicateOutputOperands(GenericOp genericOp,
178 SmallVector<OpOperand *> &droppedOpOperands,
179 SmallVector<Value> &newOutputOperands,
180 SmallVector<AffineMap> &newIndexingMaps) const {
181 llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
182 llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
183 dedupedOutpts;
184 // If the op doesn't have tensor semantics or outputs should not be removed,
185 // keep all the outputs as preserved.
186 if (!genericOp.hasPureTensorSemantics() || !removeOutputs) {
187 for (const auto &en : llvm::enumerate(genericOp.getDpsInitsMutable())) {
188 origToNewPos[en.index()] = newOutputOperands.size();
189 newOutputOperands.push_back(en.value().get());
190 newIndexingMaps.push_back(
191 genericOp.getMatchingIndexingMap(&en.value()));
192 }
193 return origToNewPos;
194 }
195 // Output argument can be dropped if the result has
196 // - no users, and
197 // - it is not used in the payload, and
198 // - the corresponding indexing maps are not needed for loop bound
199 // computation.
200 auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
201 for (const auto &outputOpOperand :
202 llvm::enumerate(genericOp.getDpsInitsMutable())) {
203 OpResult result = genericOp.getTiedOpResult(&outputOpOperand.value());
204 AffineMap indexingMap =
205 genericOp.getMatchingIndexingMap(&outputOpOperand.value());
206 auto key = std::make_tuple(outputOpOperand.value().get(), indexingMap,
207 yieldOp->getOperand(outputOpOperand.index()));
208 if (isResultValueDead(genericOp, result)) {
209 // Check if the opoperand can be dropped without affecting loop
210 // bound computation. Add the operand to the list of dropped op
211 // operand for checking. If it cannot be dropped, need to pop the
212 // value back.
213 droppedOpOperands.push_back(&outputOpOperand.value());
214 if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
215 continue;
216 }
217 droppedOpOperands.pop_back();
218 }
219
220 if (!genericOp.payloadUsesValueFromOperand(&outputOpOperand.value())) {
221 // The out operand can also be dropped if it is computed redundantly
222 // by another result, the conditions for that are
223 // - The same operand is used as the out operand
224 // - The same indexing map is used
225 // - The same yield value is used.
226 auto it = dedupedOutpts.find(key);
227 if (it != dedupedOutpts.end()) {
228 origToNewPos[outputOpOperand.index()] = it->second;
229 droppedOpOperands.push_back(&outputOpOperand.value());
230 continue;
231 }
232 }
233
234 origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
235 dedupedOutpts[key] = newOutputOperands.size();
236 newOutputOperands.push_back(outputOpOperand.value().get());
237 newIndexingMaps.push_back(
238 genericOp.getMatchingIndexingMap(&outputOpOperand.value()));
239 }
240 return origToNewPos;
241 }
242
243 // Populate the body of the canonicalized operation.
244 void populateOpPayload(
245 GenericOp genericOp, GenericOp newOp,
246 const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
247 const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
248 PatternRewriter &rewriter) const {
249 // Merge the body of the original op with the new op.
250 Block *newOpBlock = &newOp.getRegion().front();
251 assert(newOpBlock->empty() && "expected new op to have an empty payload");
252 Block *origOpBlock = &genericOp.getRegion().front();
253 SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);
254
255 // Replace all arguments in the original op, with arguments from the
256 // canonicalized op.
257 auto updateReplacements =
258 [&](SmallVector<OpOperand *> &origOperands,
259 SmallVector<OpOperand *> &newOperands,
260 const llvm::SmallDenseMap<unsigned, unsigned> &map) {
261 for (const auto &origOperand : llvm::enumerate(First&: origOperands)) {
262 auto it = map.find(Val: origOperand.index());
263 if (it == map.end())
264 continue;
265 OpOperand *newOperand = newOperands[it->second];
266 replacements[origOperand.value()->getOperandNumber()] =
267 newOpBlock->getArgument(i: newOperand->getOperandNumber());
268 }
269 };
270
271 SmallVector<OpOperand *> origInputOperands =
272 genericOp.getDpsInputOperands();
273 SmallVector<OpOperand *> newInputOperands = newOp.getDpsInputOperands();
274 updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
275
276 SmallVector<OpOperand *> origOutputOperands =
277 llvm::to_vector(llvm::map_range(genericOp.getDpsInitsMutable(),
278 [](OpOperand &o) { return &o; }));
279 SmallVector<OpOperand *> newOutputOperands =
280 llvm::to_vector(llvm::map_range(newOp.getDpsInitsMutable(),
281 [](OpOperand &o) { return &o; }));
282 updateReplacements(origOutputOperands, newOutputOperands,
283 origOutsToNewOutsPos);
284
285 // Drop the unused yield args.
286 if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) {
287 OpBuilder::InsertionGuard g(rewriter);
288 YieldOp origYieldOp = cast<YieldOp>(origOpBlock->getTerminator());
289 rewriter.setInsertionPoint(origYieldOp);
290
291 SmallVector<Value> newYieldVals(newOp.getNumDpsInits(), nullptr);
292 for (const auto &yieldOpOperands :
293 llvm::enumerate(origYieldOp.getValues())) {
294 auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
295 if (it == origOutsToNewOutsPos.end())
296 continue;
297 newYieldVals[it->second] = yieldOpOperands.value();
298 }
299 rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
300 }
301
302 rewriter.mergeBlocks(source: origOpBlock, dest: newOpBlock, argValues: replacements);
303 }
304};
305
306/// Remove unused cycles.
307/// We can remove unused cycle within a payload of generic region
308/// if these conditions are met:
309/// - Result from out operand is dead.
310/// - Block arg from out operand has a single use in the %cycle
311/// instruction.
312/// - Cycle has a single use and it is in yield.
313struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
314 using OpRewritePattern<GenericOp>::OpRewritePattern;
315
316 LogicalResult matchAndRewrite(GenericOp genericOp,
317 PatternRewriter &rewriter) const override {
318
319 // If the op doesnt have tensor semantics, preserve the outputs as is.
320 if (!genericOp.hasPureTensorSemantics())
321 return failure();
322
323 bool hasRemovedCycles = false;
324 // Iterate over output operands and remove any unused cycles.
325 for (const auto &outputOpOperand :
326 llvm::enumerate(genericOp.getDpsInits())) {
327
328 // Check that result from out operand is dead.
329 Value result = genericOp.getResult(outputOpOperand.index());
330 if (!result.use_empty())
331 continue;
332
333 // Check that outputArg has one use in cycle.
334 BlockArgument outputArg =
335 genericOp.getRegionOutputArgs()[outputOpOperand.index()];
336 if (!outputArg.hasOneUse())
337 continue;
338
339 // Check cycle has at most one use.
340 Operation *cycleOp = *outputArg.user_begin();
341 if (!cycleOp->hasOneUse())
342 continue;
343
344 // Check that the cycleUser is a yield.
345 Operation *cycleUserOp = *cycleOp->user_begin();
346 if (!isa<linalg::YieldOp>(cycleUserOp))
347 continue;
348
349 // Check that argIndex matches yieldIndex, else data is being used.
350 if (cycleUserOp->getOperand(outputOpOperand.index()) !=
351 cycleOp->getResult(0))
352 continue;
353
354 // Directly replace the cycle with the blockArg such that
355 // Deduplicate pattern can eliminate it along with unused yield.
356 rewriter.replaceOp(cycleOp, outputArg);
357 rewriter.modifyOpInPlace(genericOp, [] {});
358 hasRemovedCycles = true;
359 }
360
361 if (hasRemovedCycles) {
362 return success();
363 }
364
365 return failure();
366 }
367};
368
369/// Fold uses of duplicate inputs in the body of a linalg.generic. E.g.:
370/// ```
371/// linalg.generic ins(%a, %b, %a, %b) outs(%a)
372/// ^bb0(%in0, %in1, %in2, %in3, %out1)
373/// ```
374/// Assuming that all %a and %b have the same index map:
375/// * All uses of %in0 and %in2 are replaced with %out1
376/// * All uses of %in1 are replaced with %in3
377/// This pattern can enable additional canonicalizations: In the above example,
378/// %in0, %in1 and %in3 have no uses anymore and their corresponding operands
379/// can be folded away. This pattern does not modify uses of output block args.
380struct FoldDuplicateInputBbArgs : public OpRewritePattern<GenericOp> {
381 using OpRewritePattern<GenericOp>::OpRewritePattern;
382
383 LogicalResult matchAndRewrite(GenericOp genericOp,
384 PatternRewriter &rewriter) const override {
385 // Find replacement bbArgs for all input bbArg.
386 DenseMap<int, int> replacements;
387 for (int i = 0; i < genericOp.getNumDpsInputs(); ++i) {
388 // Skip bbArgs that have no uses.
389 if (genericOp.getBody()->getArgument(i).getUses().empty())
390 continue;
391 // Find replacement bbArg. This can be an input or an output bbArg.
392 for (int j = genericOp->getNumOperands() - 1; j > i; --j) {
393 if (genericOp->getOperand(i) == genericOp->getOperand(j) &&
394 genericOp.getIndexingMapsArray()[i] ==
395 genericOp.getIndexingMapsArray()[j]) {
396 replacements[i] = j;
397 break;
398 }
399 }
400 }
401
402 // Stop here if no replacements were found.
403 if (replacements.empty())
404 return failure();
405
406 // Rewrite the op.
407 rewriter.modifyOpInPlace(genericOp, [&]() {
408 for (auto [before, after] : replacements) {
409 BlockArgument bbArg = genericOp.getBody()->getArgument(before);
410 BlockArgument replacement = genericOp.getBody()->getArgument(after);
411 rewriter.replaceAllUsesWith(from: bbArg, to: replacement);
412 }
413 });
414
415 return success();
416 }
417};
418
419} // namespace
420
421void mlir::linalg::populateEraseUnusedOperandsAndResultsPatterns(
422 RewritePatternSet &patterns) {
423 patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults>(
424 arg: patterns.getContext(), /*removeOutputs=*/args: true);
425 patterns.insert<RemoveUnusedCycleInGenericOp>(arg: patterns.getContext());
426}
427
428void mlir::linalg::populateEraseUnnecessaryInputsPatterns(
429 RewritePatternSet &patterns) {
430 patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults>(
431 arg: patterns.getContext(), /*removeOutputs=*/args: false);
432 patterns.insert<FoldDuplicateInputBbArgs>(arg: patterns.getContext());
433}
434

source code of mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp