//===- EraseUnusedOperandsAndResults.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
10
11#include "mlir/Dialect/Linalg/IR/Linalg.h"
12
13using namespace mlir;
14using namespace mlir::linalg;
15
16/// Return `true` if the `result` of an operation `genericOp` is dead.
17static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) {
18 if (!result.use_empty())
19 return false;
20 // If out operand not used in payload, we can drop it.
21 OpOperand *outputOpOperand =
22 genericOp.getDpsInitOperand(result.getResultNumber());
23 if (!genericOp.payloadUsesValueFromOperand(outputOpOperand))
24 return true;
25
26 // The out operand that is part of a payload can be dropped if
27 // these conditions are met:
28 // - Result from out operand is dead.
29 // - User of arg is yield.
30 // - outArg data is not being used by other outArgs.
31
32 // Check block arg and cycle from out operand has a single use.
33 BlockArgument outputArg =
34 genericOp.getRegionOutputArgs()[result.getResultNumber()];
35 if (!outputArg.hasOneUse())
36 return false;
37 Operation *argUserOp = *outputArg.user_begin();
38
39 // Check argUser has no other use.
40 if (!argUserOp->use_empty())
41 return false;
42
43 // Check that argUser is a yield.
44 auto yieldOp = dyn_cast<linalg::YieldOp>(argUserOp);
45 if (!yieldOp)
46 return false;
47
48 // Check outArg data is not being used by other outArgs.
49 if (yieldOp.getOperand(result.getResultNumber()) != outputArg)
50 return false;
51
52 return true;
53}
54
//===---------------------------------------------------------------------===//
// Helper methods for operand deduplication and dead results elimination
//===---------------------------------------------------------------------===//
58
59// Deduplicate input operands, and return the
60// - Mapping from operand position in the original op, to operand position in
61// the canonicalized op.
62// - The preserved input operands list (by reference).
63llvm::SmallDenseMap<unsigned, unsigned> static deduplicateInputOperands(
64 GenericOp genericOp, SmallVector<OpOperand *> &droppedOpOperands,
65 SmallVector<Value> &newInputOperands,
66 SmallVector<AffineMap> &newIndexingMaps) {
67 llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
68 llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
69 for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
70 OpOperand *inputOpOperand = en.value();
71 // Check if operand is dead and if dropping the indexing map makes the
72 // loops to shape computation invalid.
73 if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
74 // Add the current operands to the list of potentially droppable
75 // operands. If it cannot be dropped, this needs to be popped back.
76 droppedOpOperands.push_back(inputOpOperand);
77 if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
78 continue;
79 droppedOpOperands.pop_back();
80 }
81
82 // Check if this operand is a duplicate.
83 AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand);
84 auto it =
85 dedupedInputs.find(std::make_pair(inputOpOperand->get(), indexingMap));
86 if (it != dedupedInputs.end()) {
87 origToNewPos[en.index()] = it->second;
88 droppedOpOperands.push_back(inputOpOperand);
89 continue;
90 }
91
92 // This is a preserved argument.
93 origToNewPos[en.index()] = newInputOperands.size();
94 dedupedInputs[{inputOpOperand->get(), indexingMap}] =
95 newInputOperands.size();
96 newInputOperands.push_back(inputOpOperand->get());
97 newIndexingMaps.push_back(indexingMap);
98 }
99 return origToNewPos;
100}
101
102// Deduplicate output operands, and return the
103// - Mapping from operand position in the original op, to operand position in
104// the canonicalized op.
105// - The preserved output operands list (by reference).
106llvm::SmallDenseMap<unsigned, unsigned> static deduplicateOutputOperands(
107 GenericOp genericOp, SmallVector<OpOperand *> &droppedOpOperands,
108 SmallVector<Value> &newOutputOperands,
109 SmallVector<AffineMap> &newIndexingMaps, bool removeOutputs) {
110 llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
111 llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
112 dedupedOutpts;
113 // If the op doesn't have tensor semantics or outputs should not be removed,
114 // keep all the outputs as preserved.
115 if (!genericOp.hasPureTensorSemantics() || !removeOutputs) {
116 for (const auto &en : llvm::enumerate(genericOp.getDpsInitsMutable())) {
117 origToNewPos[en.index()] = newOutputOperands.size();
118 newOutputOperands.push_back(en.value().get());
119 newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(&en.value()));
120 }
121 return origToNewPos;
122 }
123 // Output argument can be dropped if the result has
124 // - no users, and
125 // - it is not used in the payload, and
126 // - the corresponding indexing maps are not needed for loop bound
127 // computation.
128 auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
129 for (const auto &outputOpOperand :
130 llvm::enumerate(genericOp.getDpsInitsMutable())) {
131 OpResult result = genericOp.getTiedOpResult(&outputOpOperand.value());
132 AffineMap indexingMap =
133 genericOp.getMatchingIndexingMap(&outputOpOperand.value());
134 auto key = std::make_tuple(outputOpOperand.value().get(), indexingMap,
135 yieldOp->getOperand(outputOpOperand.index()));
136 if (isResultValueDead(genericOp, result)) {
137 // Check if the opoperand can be dropped without affecting loop
138 // bound computation. Add the operand to the list of dropped op
139 // operand for checking. If it cannot be dropped, need to pop the
140 // value back.
141 droppedOpOperands.push_back(&outputOpOperand.value());
142 if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
143 continue;
144 }
145 droppedOpOperands.pop_back();
146 }
147
148 if (!genericOp.payloadUsesValueFromOperand(&outputOpOperand.value())) {
149 // The out operand can also be dropped if it is computed redundantly
150 // by another result, the conditions for that are
151 // - The same operand is used as the out operand
152 // - The same indexing map is used
153 // - The same yield value is used.
154 auto it = dedupedOutpts.find(key);
155 if (it != dedupedOutpts.end()) {
156 origToNewPos[outputOpOperand.index()] = it->second;
157 droppedOpOperands.push_back(&outputOpOperand.value());
158 continue;
159 }
160 }
161
162 origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
163 dedupedOutpts[key] = newOutputOperands.size();
164 newOutputOperands.push_back(outputOpOperand.value().get());
165 newIndexingMaps.push_back(
166 genericOp.getMatchingIndexingMap(&outputOpOperand.value()));
167 }
168 return origToNewPos;
169}
170
171// Populate the body of the canonicalized operation.
172static void populateOpPayload(
173 GenericOp genericOp, GenericOp newOp,
174 const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
175 const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
176 RewriterBase &rewriter) {
177 // Merge the body of the original op with the new op.
178 Block *newOpBlock = &newOp.getRegion().front();
179 assert(newOpBlock->empty() && "expected new op to have an empty payload");
180 Block *origOpBlock = &genericOp.getRegion().front();
181 SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);
182
183 // Replace all arguments in the original op, with arguments from the
184 // canonicalized op.
185 auto updateReplacements =
186 [&](SmallVector<OpOperand *> &origOperands,
187 SmallVector<OpOperand *> &newOperands,
188 const llvm::SmallDenseMap<unsigned, unsigned> &map) {
189 for (const auto &origOperand : llvm::enumerate(First&: origOperands)) {
190 auto it = map.find(Val: origOperand.index());
191 if (it == map.end())
192 continue;
193 OpOperand *newOperand = newOperands[it->second];
194 replacements[origOperand.value()->getOperandNumber()] =
195 newOpBlock->getArgument(i: newOperand->getOperandNumber());
196 }
197 };
198
199 SmallVector<OpOperand *> origInputOperands = genericOp.getDpsInputOperands();
200 SmallVector<OpOperand *> newInputOperands = newOp.getDpsInputOperands();
201 updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
202
203 SmallVector<OpOperand *> origOutputOperands = llvm::to_vector(llvm::map_range(
204 genericOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
205 SmallVector<OpOperand *> newOutputOperands = llvm::to_vector(llvm::map_range(
206 newOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
207 updateReplacements(origOutputOperands, newOutputOperands,
208 origOutsToNewOutsPos);
209
210 // Drop the unused yield args.
211 if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) {
212 OpBuilder::InsertionGuard g(rewriter);
213 YieldOp origYieldOp = cast<YieldOp>(origOpBlock->getTerminator());
214 rewriter.setInsertionPoint(origYieldOp);
215
216 SmallVector<Value> newYieldVals(newOp.getNumDpsInits(), nullptr);
217 for (const auto &yieldOpOperands :
218 llvm::enumerate(origYieldOp.getValues())) {
219 auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
220 if (it == origOutsToNewOutsPos.end())
221 continue;
222 newYieldVals[it->second] = yieldOpOperands.value();
223 }
224 rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
225 }
226
227 rewriter.mergeBlocks(source: origOpBlock, dest: newOpBlock, argValues: replacements);
228}
229
230FailureOr<linalg::GenericOp>
231mlir::linalg::deduplicateOperandsAndRemoveDeadResults(
232 RewriterBase &rewriter, linalg::GenericOp genericOp, bool removeOutputs) {
233 // Create a map from argument position in the original op to the argument
234 // position in the new op. If the argument is dropped it wont have an entry.
235 SmallVector<OpOperand *> droppedOpOperands;
236
237 // Information needed to build the new op.
238 SmallVector<Value> newInputOperands, newOutputOperands;
239 SmallVector<AffineMap> newIndexingMaps;
240
241 // Gather information about duplicate input operands.
242 llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =
243 deduplicateInputOperands(genericOp, droppedOpOperands, newInputOperands,
244 newIndexingMaps);
245
246 // Gather information about the dropped outputs.
247 llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =
248 deduplicateOutputOperands(genericOp, droppedOpOperands, newOutputOperands,
249 newIndexingMaps, removeOutputs);
250
251 // Check if there is any change to operands.
252 if (newInputOperands.size() + newOutputOperands.size() ==
253 genericOp->getNumOperands())
254 return genericOp;
255
256 // Create the new op with the body being empty.
257 Location loc = genericOp.getLoc();
258 SmallVector<Type> newResultTypes;
259 for (Value v : newOutputOperands)
260 if (isa<TensorType>(Val: v.getType()))
261 newResultTypes.push_back(Elt: v.getType());
262 auto newOp = rewriter.create<GenericOp>(
263 loc, newResultTypes, newInputOperands, newOutputOperands,
264 rewriter.getAffineMapArrayAttr(newIndexingMaps),
265 genericOp.getIteratorTypes(), genericOp.getDocAttr(),
266 genericOp.getLibraryCallAttr(),
267 [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
268 return;
269 });
270 // Copy over unknown attributes. They might be load bearing for some flow.
271 ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
272 for (NamedAttribute kv : genericOp->getAttrs())
273 if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
274 newOp->setAttr(kv.getName(), kv.getValue());
275
276 // Fix up the payload of the canonicalized operation.
277 populateOpPayload(genericOp, newOp, origInsToNewInsPos, origOutsToNewOutsPos,
278 rewriter);
279
280 // Replace all live uses of the op.
281 SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
282 for (const auto &result : llvm::enumerate(genericOp.getResults())) {
283 auto it = origOutsToNewOutsPos.find(result.index());
284 if (it == origOutsToNewOutsPos.end())
285 continue;
286 replacementsVals[result.index()] = newOp.getResult(it->second);
287 }
288 rewriter.replaceOp(genericOp, replacementsVals);
289 return newOp;
290}
291
292namespace {
293
294struct DeduplicateAndRemoveDeadOperandsAndResults
295 : public OpRewritePattern<GenericOp> {
296 DeduplicateAndRemoveDeadOperandsAndResults(MLIRContext *ctx,
297 bool removeOutputs)
298 : OpRewritePattern<GenericOp>(ctx), removeOutputs(removeOutputs) {}
299
300 LogicalResult matchAndRewrite(GenericOp genericOp,
301 PatternRewriter &rewriter) const override {
302 FailureOr<GenericOp> newOp = deduplicateOperandsAndRemoveDeadResults(
303 rewriter, genericOp, removeOutputs);
304 if (failed(Result: newOp) || newOp.value() == genericOp) {
305 return rewriter.notifyMatchFailure(
306 genericOp, "failed to dedup operands/remove dead results");
307 }
308 return success();
309 }
310
311private:
312 /// If unset, outputs are not modified by this pattern.
313 bool removeOutputs;
314};
315
316/// Remove unused cycles.
317/// We can remove unused cycle within a payload of generic region
318/// if these conditions are met:
319/// - Result from out operand is dead.
320/// - Block arg from out operand has a single use in the %cycle
321/// instruction.
322/// - Cycle has a single use and it is in yield.
323struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
324 using OpRewritePattern<GenericOp>::OpRewritePattern;
325
326 LogicalResult matchAndRewrite(GenericOp genericOp,
327 PatternRewriter &rewriter) const override {
328
329 // If the op doesnt have tensor semantics, preserve the outputs as is.
330 if (!genericOp.hasPureTensorSemantics())
331 return failure();
332
333 bool hasRemovedCycles = false;
334 // Iterate over output operands and remove any unused cycles.
335 for (const auto &outputOpOperand :
336 llvm::enumerate(genericOp.getDpsInits())) {
337
338 // Check that result from out operand is dead.
339 Value result = genericOp.getResult(outputOpOperand.index());
340 if (!result.use_empty())
341 continue;
342
343 // Check that outputArg has one use in cycle.
344 BlockArgument outputArg =
345 genericOp.getRegionOutputArgs()[outputOpOperand.index()];
346 if (!outputArg.hasOneUse())
347 continue;
348
349 // Check cycle has at most one use.
350 Operation *cycleOp = *outputArg.user_begin();
351 if (!cycleOp->hasOneUse())
352 continue;
353
354 // Check that the cycleUser is a yield.
355 Operation *cycleUserOp = *cycleOp->user_begin();
356 if (!isa<linalg::YieldOp>(cycleUserOp))
357 continue;
358
359 // Check that argIndex matches yieldIndex, else data is being used.
360 if (cycleUserOp->getOperand(outputOpOperand.index()) !=
361 cycleOp->getResult(0))
362 continue;
363
364 // Directly replace the cycle with the blockArg such that
365 // Deduplicate pattern can eliminate it along with unused yield.
366 rewriter.replaceOp(cycleOp, outputArg);
367 rewriter.modifyOpInPlace(genericOp, [] {});
368 hasRemovedCycles = true;
369 }
370
371 if (hasRemovedCycles) {
372 return success();
373 }
374
375 return failure();
376 }
377};
378
379/// Fold uses of duplicate inputs in the body of a linalg.generic. E.g.:
380/// ```
381/// linalg.generic ins(%a, %b, %a, %b) outs(%a)
382/// ^bb0(%in0, %in1, %in2, %in3, %out1)
383/// ```
384/// Assuming that all %a and %b have the same index map:
385/// * All uses of %in0 and %in2 are replaced with %out1
386/// * All uses of %in1 are replaced with %in3
387/// This pattern can enable additional canonicalizations: In the above example,
388/// %in0, %in1 and %in3 have no uses anymore and their corresponding operands
389/// can be folded away. This pattern does not modify uses of output block args.
390struct FoldDuplicateInputBbArgs : public OpRewritePattern<GenericOp> {
391 using OpRewritePattern<GenericOp>::OpRewritePattern;
392
393 LogicalResult matchAndRewrite(GenericOp genericOp,
394 PatternRewriter &rewriter) const override {
395 // Find replacement bbArgs for all input bbArg.
396 DenseMap<int, int> replacements;
397 for (int i = 0; i < genericOp.getNumDpsInputs(); ++i) {
398 // Skip bbArgs that have no uses.
399 if (genericOp.getBody()->getArgument(i).getUses().empty())
400 continue;
401 // Find replacement bbArg. This can be an input or an output bbArg.
402 for (int j = genericOp->getNumOperands() - 1; j > i; --j) {
403 if (genericOp->getOperand(i) == genericOp->getOperand(j) &&
404 genericOp.getIndexingMapsArray()[i] ==
405 genericOp.getIndexingMapsArray()[j]) {
406 replacements[i] = j;
407 break;
408 }
409 }
410 }
411
412 // Stop here if no replacements were found.
413 if (replacements.empty())
414 return failure();
415
416 // Rewrite the op.
417 rewriter.modifyOpInPlace(genericOp, [&]() {
418 for (auto [before, after] : replacements) {
419 BlockArgument bbArg = genericOp.getBody()->getArgument(before);
420 BlockArgument replacement = genericOp.getBody()->getArgument(after);
421 rewriter.replaceAllUsesWith(from: bbArg, to: replacement);
422 }
423 });
424
425 return success();
426 }
427};
428
429} // namespace
430
431void mlir::linalg::populateEraseUnusedOperandsAndResultsPatterns(
432 RewritePatternSet &patterns) {
433 patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults>(
434 arg: patterns.getContext(), /*removeOutputs=*/args: true);
435 patterns.insert<RemoveUnusedCycleInGenericOp>(arg: patterns.getContext());
436}
437
438void mlir::linalg::populateEraseUnnecessaryInputsPatterns(
439 RewritePatternSet &patterns) {
440 patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults>(
441 arg: patterns.getContext(), /*removeOutputs=*/args: false);
442 patterns.insert<FoldDuplicateInputBbArgs>(arg: patterns.getContext());
443}
444

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp