//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Fusion on tensors operations pass.
//
//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Linalg/Passes.h"
14
15#include "mlir/Dialect/Affine/IR/AffineOps.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Arith/Utils/Utils.h"
18#include "mlir/Dialect/Linalg/IR/Linalg.h"
19#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
21#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
22#include "mlir/IR/AffineExpr.h"
23#include "mlir/IR/AffineMap.h"
24#include "mlir/IR/Matchers.h"
25#include "mlir/IR/PatternMatch.h"
26#include "mlir/Support/LLVM.h"
27#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28#include "mlir/Transforms/RegionUtils.h"
29#include <optional>
30#include <utility>
31
32namespace mlir {
33#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
34#include "mlir/Dialect/Linalg/Passes.h.inc"
35} // namespace mlir
36
37using namespace mlir;
38using namespace mlir::linalg;
39
//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//
43
44/// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
45/// the `producer` to use in the fused operation given the indexing map of the
46/// result of the producer in the consumer.
47static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
48 OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
49 AffineMap fusedConsumerArgIndexMap) {
50 // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
51 // from consumer loop -> consumer arg tensor index/producer result tensor
52 // index. The fused loop is same as the consumer loop. For each producer arg
53 // the indexing map to be computed is a map from consumer loop -> producer
54 // arg tensor index.
55 // producerResultIndexMap is a map from producer loop -> tensor index.
56 // Compute the inverse to get map from tensor index -> producer loop.
57 // The inverse is a map from producer result tensor index -> producer loop.
58 AffineMap invProducerResultIndexMap =
59 inversePermutation(map: producerResultIndexMap);
60 assert(invProducerResultIndexMap &&
61 "expected producer result indexing map to be invertible");
62
63 LinalgOp producer = cast<LinalgOp>(Val: producerOpOperand->getOwner());
64 // argMap is a map from producer loop -> producer arg tensor index.
65 AffineMap argMap = producer.getMatchingIndexingMap(opOperand: producerOpOperand);
66
67 // Compose argMap with invProducerResultIndexMap to get a map from
68 // producer result tensor index -> producer arg tensor index.
69 AffineMap t1 = argMap.compose(map: invProducerResultIndexMap);
70
71 // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
72 // consumer loop/ fused loop -> producer arg tensor index.
73 return t1.compose(map: fusedConsumerArgIndexMap);
74}
75
76// Checks if the given operand can be dropped, and the remaining operands
77// of the fused producer & consumer after the fusion can still compute the
78// bounds of the op.
79static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
80 GenericOp producer, GenericOp consumer,
81 ArrayRef<OpOperand *> opOperandsToIgnore) {
82 SmallVector<AffineMap> indexingMaps;
83
84 SmallVector<GenericOp> ops = {producer, consumer};
85 for (auto &op : ops) {
86 for (auto &opOperand : op->getOpOperands()) {
87 if (llvm::is_contained(Range&: opOperandsToIgnore, Element: &opOperand)) {
88 continue;
89 }
90 indexingMaps.push_back(Elt: op.getMatchingIndexingMap(opOperand: &opOperand));
91 }
92 }
93 if (indexingMaps.empty()) {
94 // If there are no indexing maps, the operand can only be dropped
95 // if neither op has loops.
96 return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
97 }
98
99 // The concatanation of the remained indexing maps must be invertible, so
100 // the bounds of the op can be still computed after dropping the selected
101 // operand. inversePermutation returns an empty AffineMap in case the
102 // concatanated indexing maps are not invertible.
103 return inversePermutation(map: concatAffineMaps(
104 maps: indexingMaps, context: producer.getContext())) != AffineMap();
105}
106
107/// Returns a set of indices of the producer's results which would
108/// be preserved after the fusion.
109/// * There is a chance that the implementation of the transformation does not
110/// agree with the result of this method. This function gives a prediction based
111/// on an optimized fusion.
112llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
113 GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
114 llvm::SmallDenseSet<int> preservedProducerResults;
115 llvm::SmallVector<OpOperand *> opOperandsToIgnore;
116
117 // The fusedOperand will be removed during the fusion
118 opOperandsToIgnore.emplace_back(Args&: fusedOperand);
119
120 for (const auto &producerResult : llvm::enumerate(First: producer->getResults())) {
121 auto *outputOperand = producer.getDpsInitOperand(i: producerResult.index());
122 opOperandsToIgnore.emplace_back(Args&: outputOperand);
123 if (producer.payloadUsesValueFromOperand(opOperand: outputOperand) ||
124 !isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
125 opOperandsToIgnore) ||
126 llvm::any_of(Range: producerResult.value().getUsers(), P: [&](Operation *user) {
127 return user != consumer.getOperation();
128 })) {
129 preservedProducerResults.insert(V: producerResult.index());
130
131 // In case the operand can't be dropped
132 (void)opOperandsToIgnore.pop_back_val();
133 }
134 }
135 return preservedProducerResults;
136}
137
138/// Conditions for elementwise fusion of generic operations.
139bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
140 if (!fusedOperand)
141 return false;
142
143 auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
144 auto consumer = dyn_cast<GenericOp>(Val: fusedOperand->getOwner());
145
146 // Check producer and consumer are generic ops.
147 if (!producer || !consumer)
148 return false;
149
150 // Consumer can have mixed semantics, just check operand itself has tensor
151 // type. Producer must have full tensor semantics to avoid potential
152 // aliasing between producer and consumer memrefs.
153 if (!producer.hasPureTensorSemantics() ||
154 !isa<RankedTensorType>(Val: fusedOperand->get().getType()))
155 return false;
156
157 // Verify that
158 // - the producer has all "parallel" iterator type.
159 if (producer.getNumParallelLoops() != producer.getNumLoops())
160 return false;
161
162 // Only allow fusing the producer of an input operand for now.
163 // TODO: allow fusing the producer of an output operand.
164 if (!consumer.isDpsInput(opOperand: fusedOperand))
165 return false;
166
167 // Get the consumer index map. The number of results of the consumer index
168 // map must match the number of loops of the producer.
169 AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(opOperand: fusedOperand);
170 if (consumerIndexMap.getNumResults() != producer.getNumLoops())
171 return false;
172
173 // Finally the index_map for the result must be invertible. For now just
174 // verify it is a permutation.
175 AffineMap producerResultIndexMap =
176 producer.getMatchingIndexingMap(opOperand: producer.getDpsInitOperand(i: 0));
177 if (!producerResultIndexMap.isPermutation())
178 return false;
179
180 // Ensure that the fusion does not remove size information required to
181 // get the loop bounds. For non-reduction generics, this is trivially the
182 // case due to the output operand. For reductions, we need to check that after
183 // the fusion, each loop dimension has at least one input that defines it.
184 if ((consumer.getNumReductionLoops())) {
185 BitVector coveredDims(consumer.getNumLoops(), false);
186
187 auto addToCoveredDims = [&](AffineMap map) {
188 for (auto result : map.getResults())
189 if (auto dimExpr = dyn_cast<AffineDimExpr>(Val&: result))
190 coveredDims[dimExpr.getPosition()] = true;
191 };
192
193 for (auto pair :
194 llvm::zip(t: consumer->getOperands(), u: consumer.getIndexingMapsArray())) {
195 Value operand = std::get<0>(t&: pair);
196 if (operand == fusedOperand->get())
197 continue;
198 AffineMap operandMap = std::get<1>(t&: pair);
199 addToCoveredDims(operandMap);
200 }
201
202 for (OpOperand *operand : producer.getDpsInputOperands()) {
203 AffineMap newIndexingMap =
204 getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
205 producerOpOperand: operand, producerResultIndexMap, fusedConsumerArgIndexMap: consumerIndexMap);
206 addToCoveredDims(newIndexingMap);
207 }
208 if (!coveredDims.all())
209 return false;
210 }
211
212 return true;
213}
214
215/// Generate the region of the fused tensor operation. The region of the fused
216/// op must be empty.
217static void generateFusedElementwiseOpRegion(
218 RewriterBase &rewriter, GenericOp fusedOp,
219 AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
220 unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
221 auto producer = cast<GenericOp>(Val: fusedOperand->get().getDefiningOp());
222 auto consumer = cast<GenericOp>(Val: fusedOperand->getOwner());
223 // Build the region of the fused op.
224 Block &producerBlock = producer->getRegion(index: 0).front();
225 Block &consumerBlock = consumer->getRegion(index: 0).front();
226 OpBuilder::InsertionGuard guard(rewriter);
227 Block *fusedBlock = rewriter.createBlock(parent: &fusedOp.getRegion());
228 IRMapping mapper;
229
230 // 2. Add an index operation for every fused loop dimension and use the
231 // `consumerToProducerLoopsMap` to map the producer indices.
232 if (producer.hasIndexSemantics()) {
233 // Add an index operation for every fused loop dimension.
234 unsigned numFusedOpLoops = fusedOp.getNumLoops();
235 SmallVector<Value> fusedIndices;
236 fusedIndices.reserve(N: numFusedOpLoops);
237 llvm::transform(Range: llvm::seq<uint64_t>(Begin: 0, End: numFusedOpLoops),
238 d_first: std::back_inserter(x&: fusedIndices), F: [&](uint64_t dim) {
239 return rewriter.create<IndexOp>(location: producer.getLoc(), args&: dim);
240 });
241 for (IndexOp indexOp :
242 llvm::make_early_inc_range(Range: producerBlock.getOps<IndexOp>())) {
243 Value newIndex = rewriter.create<affine::AffineApplyOp>(
244 location: producer.getLoc(),
245 args: consumerToProducerLoopsMap.getSubMap(resultPos: indexOp.getDim()), args&: fusedIndices);
246 mapper.map(from: indexOp.getResult(), to: newIndex);
247 }
248 }
249 // TODO: allow fusing the producer of an output operand.
250 assert(consumer.isDpsInput(fusedOperand) &&
251 "expected producer of input operand");
252 // 3. Consumer input operands up to consumerIdx (exclusive).
253 for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
254 N: fusedOperand->getOperandNumber())) // input assumption.
255 mapper.map(from: bbArg, to: fusedBlock->addArgument(type: bbArg.getType(), loc: bbArg.getLoc()));
256
257 // Replacing consumerIdx requires getting the cloned, yielded, value from
258 // the (cloned) producer block. This happens in step 9.
259
260 // 4. Splice in producer's input operands.
261 for (BlockArgument bbArg :
262 producerBlock.getArguments().take_front(N: producer.getNumDpsInputs()))
263 mapper.map(from: bbArg, to: fusedBlock->addArgument(type: bbArg.getType(), loc: bbArg.getLoc()));
264
265 // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
266 for (BlockArgument bbArg :
267 consumerBlock.getArguments()
268 .take_front(N: consumer.getNumDpsInputs())
269 .drop_front(N: fusedOperand->getOperandNumber() + 1))
270 mapper.map(from: bbArg, to: fusedBlock->addArgument(type: bbArg.getType(), loc: bbArg.getLoc()));
271
272 // 6. All of the producer's output operands
273 for (const auto &bbArg : llvm::enumerate(
274 First: producerBlock.getArguments().take_back(N: producer.getNumDpsInits()))) {
275 if (!preservedProducerResults.count(V: bbArg.index()))
276 continue;
277 mapper.map(from: bbArg.value(), to: fusedBlock->addArgument(type: bbArg.value().getType(),
278 loc: bbArg.value().getLoc()));
279 }
280
281 // 7. All of consumer's output operands.
282 for (BlockArgument bbArg :
283 consumerBlock.getArguments().take_back(N: consumer.getNumDpsInits()))
284 mapper.map(from: bbArg, to: fusedBlock->addArgument(type: bbArg.getType(), loc: bbArg.getLoc()));
285
286 // 8. Clone all producer operations except for the yield and index operations
287 // to the fused operation.
288 for (auto &op : producerBlock.without_terminator()) {
289 if (!isa<IndexOp>(Val: op))
290 rewriter.clone(op, mapper);
291 }
292 // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
293 // forward the yield operand.
294 auto producerYieldOp = cast<linalg::YieldOp>(Val: producerBlock.getTerminator());
295 unsigned producerResultNumber =
296 cast<OpResult>(Val: fusedOperand->get()).getResultNumber();
297 Value replacement =
298 mapper.lookupOrDefault(from: producerYieldOp.getOperand(i: producerResultNumber));
299
300 // Sanity checks, if replacement is not already in the mapper then it must be
301 // produced outside.
302 if (replacement == producerYieldOp.getOperand(i: producerResultNumber)) {
303 if (auto bb = dyn_cast<BlockArgument>(Val&: replacement))
304 assert(bb.getOwner() != &producerBlock &&
305 "yielded block argument must have been mapped");
306 else
307 assert(!producer->isAncestor(replacement.getDefiningOp()) &&
308 "yielded value must have been mapped");
309 }
310 mapper.map(from: consumerBlock.getArgument(i: fusedOperand->getOperandNumber()),
311 to: replacement);
312 // 10. Clone operations from the consumer to the fused op.
313 for (auto &op : consumerBlock.without_terminator())
314 rewriter.clone(op, mapper);
315
316 // 11. Include the final yield (which is the remapped values for all the
317 // yield)
318 auto consumerYieldOp = cast<linalg::YieldOp>(Val: consumerBlock.getTerminator());
319 SmallVector<Value> fusedYieldValues;
320 fusedYieldValues.reserve(N: producerYieldOp.getNumOperands() +
321 consumerYieldOp.getNumOperands());
322 for (const auto &producerYieldVal :
323 llvm::enumerate(First: producerYieldOp.getOperands())) {
324 if (preservedProducerResults.count(V: producerYieldVal.index()))
325 fusedYieldValues.push_back(
326 Elt: mapper.lookupOrDefault(from: producerYieldVal.value()));
327 }
328 for (auto consumerYieldVal : consumerYieldOp.getOperands())
329 fusedYieldValues.push_back(Elt: mapper.lookupOrDefault(from: consumerYieldVal));
330 rewriter.create<YieldOp>(location: fusedOp.getLoc(), args&: fusedYieldValues);
331
332 // Sanity checks.
333 assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
334 "Ill-formed GenericOp region");
335}
336
337FailureOr<mlir::linalg::ElementwiseOpFusionResult>
338mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
339 OpOperand *fusedOperand) {
340 assert(areElementwiseOpsFusable(fusedOperand) &&
341 "expected elementwise operation pre-conditions to pass");
342 auto producerResult = cast<OpResult>(Val: fusedOperand->get());
343 auto producer = cast<GenericOp>(Val: producerResult.getOwner());
344 auto consumer = cast<GenericOp>(Val: fusedOperand->getOwner());
345 // TODO: allow fusing the producer of an output operand.
346 assert(consumer.isDpsInput(fusedOperand) &&
347 "expected producer of input operand");
348 /// Find the results of the producer that have uses outside of the consumer,
349 /// after the fusion.
350 llvm::SmallDenseSet<int> preservedProducerResults =
351 mlir::linalg::getPreservedProducerResults(producer, consumer,
352 fusedOperand);
353
354 // Compute the fused operands list and indexing maps.
355 SmallVector<Value> fusedInputOperands, fusedOutputOperands;
356 SmallVector<Type> fusedResultTypes;
357 SmallVector<AffineMap> fusedIndexMaps;
358 fusedInputOperands.reserve(N: producer.getNumDpsInputs() +
359 consumer.getNumDpsInputs());
360 fusedOutputOperands.reserve(N: preservedProducerResults.size() +
361 consumer.getNumDpsInits());
362 fusedResultTypes.reserve(N: preservedProducerResults.size() +
363 consumer.getNumDpsInits());
364 fusedIndexMaps.reserve(N: producer->getNumOperands() +
365 consumer->getNumOperands());
366 // In the following, numbering matches that of `generateFusedTensorOpRegion`.
367 // 3. Consumer input operands/maps up to consumerIdx (exclusive).
368 auto consumerInputs = consumer.getDpsInputOperands();
369 auto *it = llvm::find_if(Range&: consumerInputs, P: [&](OpOperand *operand) {
370 return operand == fusedOperand;
371 });
372 assert(it != consumerInputs.end() && "expected to find the consumer operand");
373 for (OpOperand *opOperand : llvm::make_range(x: consumerInputs.begin(), y: it)) {
374 fusedInputOperands.push_back(Elt: opOperand->get());
375 fusedIndexMaps.push_back(Elt: consumer.getMatchingIndexingMap(opOperand));
376 }
377 // 4. Splice in producer's input operands/maps.
378 AffineMap producerResultIndexMap =
379 producer.getIndexingMapMatchingResult(result: producerResult);
380 for (OpOperand *opOperand : producer.getDpsInputOperands()) {
381 fusedInputOperands.push_back(Elt: opOperand->get());
382 // Compute indexing maps for the producer args in the fused operation.
383 AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
384 producerOpOperand: opOperand, producerResultIndexMap,
385 fusedConsumerArgIndexMap: consumer.getMatchingIndexingMap(opOperand: fusedOperand));
386 fusedIndexMaps.push_back(Elt: map);
387 }
388 // 5. Remaining consumer's input operands/maps (drop past index
389 // `consumerIdx`).
390 for (OpOperand *opOperand :
391 llvm::make_range(x: std::next(x: it), y: consumerInputs.end())) {
392 fusedInputOperands.push_back(Elt: opOperand->get());
393 fusedIndexMaps.push_back(Elt: consumer.getMatchingIndexingMap(opOperand));
394 }
395
396 // 6. Collect all of the producer outputs.
397 for (const auto &opOperand : llvm::enumerate(First: producer.getDpsInitsMutable())) {
398 if (!preservedProducerResults.count(V: opOperand.index()))
399 continue;
400
401 fusedOutputOperands.push_back(Elt: opOperand.value().get());
402 AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
403 producerOpOperand: &opOperand.value(), producerResultIndexMap,
404 fusedConsumerArgIndexMap: consumer.getMatchingIndexingMap(opOperand: fusedOperand));
405 fusedIndexMaps.push_back(Elt: map);
406 fusedResultTypes.push_back(Elt: opOperand.value().get().getType());
407 }
408
409 // 7. All of consumer's output operands (skip operands: added by the builder).
410 for (OpOperand &opOperand : consumer.getDpsInitsMutable()) {
411 fusedOutputOperands.push_back(Elt: opOperand.get());
412 fusedIndexMaps.push_back(Elt: consumer.getMatchingIndexingMap(opOperand: &opOperand));
413 Type resultType = opOperand.get().getType();
414 if (!isa<MemRefType>(Val: resultType))
415 fusedResultTypes.push_back(Elt: resultType);
416 }
417
418 // Generate the fused op.
419 auto fusedOp = rewriter.create<GenericOp>(
420 location: consumer.getLoc(), args&: fusedResultTypes, args&: fusedInputOperands,
421 args&: fusedOutputOperands, args: rewriter.getAffineMapArrayAttr(values: fusedIndexMaps),
422 args: consumer.getIteratorTypes(),
423 /*doc=*/args: nullptr,
424 /*library_call=*/args: nullptr);
425 if (!fusedOp.getShapesToLoopsMap()) {
426 // Fused op has invalid indexing maps. Typically this means something is off
427 // in the input, but going ahead here would result in verification errors.
428 // So cleanup and abort.
429 rewriter.eraseOp(op: fusedOp);
430 return rewriter.notifyMatchFailure(
431 arg&: fusedOp, msg: "fused op failed loop bound computation check");
432 }
433
434 // Construct an AffineMap from consumer loops to producer loops.
435 // consumer loop -> tensor index
436 AffineMap consumerResultIndexMap =
437 consumer.getMatchingIndexingMap(opOperand: fusedOperand);
438 // tensor index -> producer loop
439 AffineMap invProducerResultIndexMap =
440 inversePermutation(map: producerResultIndexMap);
441 assert(invProducerResultIndexMap &&
442 "expected producer result indexig map to be invertible");
443 // consumer loop -> producer loop
444 AffineMap consumerToProducerLoopsMap =
445 invProducerResultIndexMap.compose(map: consumerResultIndexMap);
446
447 generateFusedElementwiseOpRegion(
448 rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
449 nloops: consumer.getNumLoops(), preservedProducerResults);
450 ElementwiseOpFusionResult result;
451 result.fusedOp = fusedOp;
452 int resultNum = 0;
453 for (auto [index, producerResult] : llvm::enumerate(First: producer->getResults()))
454 if (preservedProducerResults.count(V: index))
455 result.replacements[producerResult] = fusedOp->getResult(idx: resultNum++);
456 for (auto consumerResult : consumer->getResults())
457 result.replacements[consumerResult] = fusedOp->getResult(idx: resultNum++);
458 return result;
459}
460
461namespace {
462/// Patterns to fuse a generic op, with the producer of its operands.
463class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
464public:
465 FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
466 PatternBenefit benefit = 1)
467 : OpRewritePattern<GenericOp>(context, benefit),
468 controlFn(std::move(fun)) {}
469
470 LogicalResult matchAndRewrite(GenericOp genericOp,
471 PatternRewriter &rewriter) const override {
472 // Find the first operand that is defined by another generic op on tensors.
473 for (OpOperand &opOperand : genericOp->getOpOperands()) {
474 if (!areElementwiseOpsFusable(fusedOperand: &opOperand))
475 continue;
476 if (!controlFn(&opOperand))
477 continue;
478
479 Operation *producer = opOperand.get().getDefiningOp();
480
481 // Find the producer of the operand.
482 FailureOr<ElementwiseOpFusionResult> fusionResult =
483 fuseElementwiseOps(rewriter, fusedOperand: &opOperand);
484 if (failed(Result: fusionResult))
485 return rewriter.notifyMatchFailure(arg&: genericOp, msg: "fusion failed");
486
487 // Perform the fusion.
488 for (auto [origVal, replacement] : fusionResult->replacements) {
489 rewriter.replaceUsesWithIf(from: origVal, to: replacement, functor: [&](OpOperand &use) {
490 // Only replace consumer uses.
491 return use.get().getDefiningOp() != producer;
492 });
493 }
494 rewriter.eraseOp(op: genericOp);
495 return success();
496 }
497 return failure();
498 }
499
500private:
501 ControlFusionFn controlFn;
502};
503} // namespace
504
//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------------===//
509
510/// Conditions for folding a structured linalg operation with a reshape op by
511/// expanding the iteration space dimensionality for tensor operations. These
512/// are preconditions assumed by `foldReshapeByDimExpansion` which implements
513/// the following fusion pattern.
514///
515/// Consider
516///
517/// %c = linalg.generic ins(%a, %b : memref<?x?x?xf32>, memref<?x?xf32>)
518/// indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
519/// affine_map<(d0, d1, d2) -> (d1, d2)>,
520/// affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
521/// %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
522/// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
523///
524/// The reshape can be folded into the `linalgOp` if its loop dimensionality
525/// is increased to match the result (operand) of the tensor.expand_shape.
526/// The indexing_map of the fused tensor in the `linalgOp` and the
527/// reassociation map helps compute the indexing maps of the modified op.
528/// For the above example, based on the reassociation map it
529/// can be concluded that
530///
531/// - The loop used to access the first dimension of the fused tensor is split
532/// into two.
533/// - The loop used to access the second dimension of the fused tensor is kept
534/// as is.
535/// - The loop used to access the third dimension of the fused tensor is split
536/// into three.
537///
538/// i.e. (e0, e1, e2, e3, e4) is the domain of the indexing map of the modified
539/// op, then
540///
541/// d0 -> e0, e1
542/// d1 -> e2, e3, e4
543/// d2 -> e5
544///
545/// substituting this, the structured op can be rewritten as
546///
547/// %d = linalg.generic ins(%0, %1 : )
548/// indexing_maps =
549/// [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
550/// affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
551/// affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
552///
553/// Since operands to the linalg generic are now 5D, reshapes can be introduced
554/// to make it consistent
555///
556/// %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
557/// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
558/// %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
559/// : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
560///
561/// The added reshapes are again expanding patterns, so they will get fused
562/// with its producers if possible.
563static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
564 OpOperand *fusableOpOperand) {
565 // Is fusable only if:
566 // - All the indexing maps for operands and results are projected
567 // permutations.
568 // - The fused tensor is not a scalar.
569 SmallVector<utils::IteratorType> iteratorTypes =
570 linalgOp.getIteratorTypesArray();
571 AffineMap operandMap = linalgOp.getMatchingIndexingMap(opOperand: fusableOpOperand);
572 return linalgOp.hasPureTensorSemantics() &&
573 llvm::all_of(Range: linalgOp.getIndexingMaps().getValue(),
574 P: [](Attribute attr) {
575 return cast<AffineMapAttr>(Val&: attr)
576 .getValue()
577 .isProjectedPermutation();
578 }) &&
579 operandMap.getNumResults() > 0;
580}
581
582namespace {
583/// Information needed to expand a generic operation to fold the reshape with
584/// it.
585class ExpansionInfo {
586public:
587 // Computes the mapping from original dimensions of the op to the dimensions
588 // of the expanded op given the `indexingMap` of the fused operand/result of
589 // the generic op, the `reassocationMaps` of the reshape op and the shape of
590 // the expanded op.
591 LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
592 ArrayRef<AffineMap> reassociationMaps,
593 ArrayRef<OpFoldResult> expandedShape,
594 PatternRewriter &rewriter);
595 unsigned getOrigOpNumDims() const { return reassociation.size(); }
596 unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
597 ReassociationIndicesRef getExpandedDims(unsigned i) const {
598 return reassociation[i];
599 }
600 ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
601 return expandedShapeMap[i];
602 }
603 ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }
604
605private:
606 /// Reassociation from the dimensions in the original operation to the
607 /// dimension of the expanded operation.
608 SmallVector<ReassociationIndices> reassociation;
609 /// Mapping from extent of loops in the original operation, to the extent of
610 /// loops in the expanded operation.
611 SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
612 /// Extent of the loop in the original operation.
613 SmallVector<OpFoldResult> originalLoopExtent;
614 unsigned expandedOpNumDims;
615};
616} // namespace
617
618LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
619 OpOperand *fusableOpOperand,
620 ArrayRef<AffineMap> reassociationMaps,
621 ArrayRef<OpFoldResult> expandedShape,
622 PatternRewriter &rewriter) {
623 if (reassociationMaps.empty())
624 return failure();
625 AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(opOperand: fusableOpOperand);
626
627 OpBuilder::InsertionGuard g(rewriter);
628 rewriter.setInsertionPoint(linalgOp);
629 originalLoopExtent = llvm::map_to_vector(
630 C: linalgOp.createLoopRanges(b&: rewriter, loc: linalgOp->getLoc()),
631 F: [](Range r) { return r.size; });
632
633 reassociation.clear();
634 expandedShapeMap.clear();
635 // Compute the number of dimension in the expanded op that correspond to each
636 // dimension of the original op.
637 SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
638 expandedShapeMap.resize(N: fusedIndexMap.getNumDims());
639 for (const auto &resultExpr : llvm::enumerate(First: fusedIndexMap.getResults())) {
640 unsigned pos = cast<AffineDimExpr>(Val: resultExpr.value()).getPosition();
641 AffineMap foldedDims = reassociationMaps[resultExpr.index()];
642 numExpandedDims[pos] = foldedDims.getNumResults();
643 ArrayRef<OpFoldResult> shape =
644 expandedShape.slice(N: foldedDims.getDimPosition(idx: 0), M: numExpandedDims[pos]);
645 expandedShapeMap[pos].assign(in_start: shape.begin(), in_end: shape.end());
646 }
647 // The remaining dimensions remain the same.
648 for (unsigned i : llvm::seq<unsigned>(Begin: 0, End: fusedIndexMap.getNumDims()))
649 if (expandedShapeMap[i].empty())
650 expandedShapeMap[i] = {originalLoopExtent[i]};
651
652 // Compute reassociation map from the original op to the expanded op.
653 unsigned sum = 0;
654 reassociation.reserve(N: fusedIndexMap.getNumDims());
655 for (const auto &numFoldedDim : llvm::enumerate(First&: numExpandedDims)) {
656 auto seq = llvm::seq<int64_t>(Begin: sum, End: sum + numFoldedDim.value());
657 reassociation.emplace_back(Args: seq.begin(), Args: seq.end());
658 sum += numFoldedDim.value();
659 }
660 expandedOpNumDims = sum;
661 return success();
662}
663
664/// Return the indexing map to use in the expanded op for a given the
665/// `indexingMap` of the original operation.
666static AffineMap
667getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
668 const ExpansionInfo &expansionInfo) {
669 SmallVector<AffineExpr> newExprs;
670 for (AffineExpr expr : indexingMap.getResults()) {
671 unsigned pos = cast<AffineDimExpr>(Val&: expr).getPosition();
672 SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
673 Range: llvm::map_range(C: expansionInfo.getExpandedDims(i: pos), F: [&](int64_t v) {
674 return builder.getAffineDimExpr(position: static_cast<unsigned>(v));
675 }));
676 newExprs.append(in_start: expandedExprs.begin(), in_end: expandedExprs.end());
677 }
678 return AffineMap::get(dimCount: expansionInfo.getExpandedOpNumDims(),
679 symbolCount: indexingMap.getNumSymbols(), results: newExprs,
680 context: builder.getContext());
681}
682
683/// Return the shape and type of the operand/result to use in the expanded op
684/// given the type in the original op.
685static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
686getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
687 const ExpansionInfo &expansionInfo) {
688 SmallVector<OpFoldResult> expandedShape;
689 for (AffineExpr expr : indexingMap.getResults()) {
690 unsigned dim = cast<AffineDimExpr>(Val&: expr).getPosition();
691 ArrayRef<OpFoldResult> dimExpansion =
692 expansionInfo.getExpandedShapeOfDim(i: dim);
693 expandedShape.append(in_start: dimExpansion.begin(), in_end: dimExpansion.end());
694 }
695 SmallVector<int64_t> expandedStaticShape;
696 std::tie(args&: expandedStaticShape, args: std::ignore) =
697 decomposeMixedValues(mixedValues: expandedShape);
698 return {expandedShape, RankedTensorType::get(shape: expandedStaticShape,
699 elementType: originalType.getElementType())};
700}
701
702/// Returns the reassociation maps to use in the `tensor.expand_shape`
703/// operation to convert the operands of the original operation to operands of
704/// the expanded operation. The same method is used to compute the
705/// `tensor.collapse_shape` used to collapse the result of the expanded
706/// op to get the value that can replace all uses of the results of the original
707/// op.
708static SmallVector<ReassociationIndices>
709getReassociationForExpansion(AffineMap indexingMap,
710 const ExpansionInfo &expansionInfo) {
711 SmallVector<ReassociationIndices> reassociation;
712 unsigned numReshapeDims = 0;
713 for (AffineExpr expr : indexingMap.getResults()) {
714 unsigned dim = cast<AffineDimExpr>(Val&: expr).getPosition();
715 auto numExpandedDims = expansionInfo.getExpandedDims(i: dim).size();
716 SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
717 Range: llvm::seq<int64_t>(Begin: numReshapeDims, End: numReshapeDims + numExpandedDims));
718 reassociation.emplace_back(Args: std::move(indices));
719 numReshapeDims += numExpandedDims;
720 }
721 return reassociation;
722}
723
/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  // Early-inc iteration is required because each linalg.index op is replaced
  // (erased) as it is visited.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(Range: fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(i: indexOp.getDim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.getDim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    // Shapes of all expanded dims except the outermost; these are the strides
    // needed for the Horner-style linearization below.
    ArrayRef<OpFoldResult> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(i: indexOp.getDim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(N: expandedDims.size() - 1);
    llvm::transform(
        Range: expandedDims.drop_front(), d_first: std::back_inserter(x&: expandedIndices),
        F: [&](int64_t dim) { return rewriter.create<IndexOp>(location: loc, args&: dim); });
    // Fold `newIndex = idx + acc * shape` one expanded dim at a time, i.e.
    // `(((d0) * s1 + d1) * s2 + d2) ...`, using composed affine.apply ops so
    // constants fold away.
    OpFoldResult newIndex =
        rewriter.create<IndexOp>(location: loc, args: expandedDims.front()).getResult();
    for (auto [expandedShape, expandedIndex] :
         llvm::zip(t&: expandedDimsShape, u&: expandedIndices)) {
      AffineExpr idx, acc, shape;
      bindDims(ctx: rewriter.getContext(), exprs&: idx, exprs&: acc);
      bindSymbols(ctx: rewriter.getContext(), exprs&: shape);
      newIndex = affine::makeComposedFoldedAffineApply(
          b&: rewriter, loc: indexOp.getLoc(), expr: idx + acc * shape,
          operands: ArrayRef<OpFoldResult>{expandedIndex, newIndex, expandedShape});
    }
    // Materialize the folded result as a Value and replace the original index.
    Value newIndexVal =
        getValueOrCreateConstantIndexOp(b&: rewriter, loc: indexOp.getLoc(), ofr: newIndex);
    rewriter.replaceOp(op: indexOp, newValues: newIndexVal);
  }
}
770
771// Create an expanded transpose op.
772// the reassociation map is already permuted hence we inverse permute and then
773// flatten it. Then we inverse permute it again to get the final expanded
774// transpose permutation. For example,
775//
776// permutation = [2, 0, 1]
777// reassociation_map for expansion = [[0, 1], [2], [3, 4, 5]]
778//
779// inverse permutation = [1, 2, 0]
780// applied to reassocation_map and then flattened becomes
781// flatened permutation = [2, 3, 4, 5, 0, 1]
782// final permuation is the inverse of the flattened permutation.
783//
784// Becomes
785//
786// permutation=[4, 5, 0, 1, 2, 3]
787
788static Operation *createExpandedTransposeOp(PatternRewriter &rewriter,
789 TransposeOp transposeOp,
790 Value expandedInput, Value output,
791 ExpansionInfo &expansionInfo) {
792 SmallVector<int64_t> newPerm;
793 for (int64_t perm : invertPermutationVector(permutation: transposeOp.getPermutation())) {
794 auto reassoc = expansionInfo.getExpandedDims(i: perm);
795 for (int64_t dim : reassoc) {
796 newPerm.push_back(Elt: dim);
797 }
798 }
799 return rewriter.create<TransposeOp>(location: transposeOp.getLoc(), args&: expandedInput,
800 args&: output, args: invertPermutationVector(permutation: newPerm));
801}
802
803// Create an expanded generic op.
804static Operation *createExpandedGenericOp(
805 PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
806 ArrayRef<Value> &expandedOpOperands, ArrayRef<Value> outputs,
807 ExpansionInfo &expansionInfo, ArrayRef<AffineMap> expandedOpIndexingMaps) {
808 // The iterator types of the expanded op are all parallel.
809 SmallVector<utils::IteratorType> iteratorTypes(
810 expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
811
812 for (auto [i, type] : llvm::enumerate(First: linalgOp.getIteratorTypesArray()))
813 for (auto j : expansionInfo.getExpandedDims(i))
814 iteratorTypes[j] = type;
815
816 Operation *fused = rewriter.create<GenericOp>(
817 location: linalgOp.getLoc(), args&: resultTypes, args&: expandedOpOperands, args&: outputs,
818 args&: expandedOpIndexingMaps, args&: iteratorTypes);
819
820 Region &fusedRegion = fused->getRegion(index: 0);
821 Region &originalRegion = linalgOp->getRegion(index: 0);
822 rewriter.cloneRegionBefore(region&: originalRegion, parent&: fusedRegion, before: fusedRegion.begin());
823
824 // Update the index accesses after the expansion.
825 updateExpandedGenericOpRegion(rewriter, loc: linalgOp.getLoc(), fusedRegion,
826 expansionInfo);
827
828 return fused;
829}
830
831// Create an expanded fused op that retains the name for certain ops
832// such as fill, copy and transpose and produce a generic op for
833// rest of linalg ops.
834static Operation *createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp,
835 TypeRange resultTypes,
836 ArrayRef<Value> expandedOpOperands,
837 ArrayRef<Value> outputs,
838 ArrayRef<AffineMap> expandedOpIndexingMaps,
839 ExpansionInfo &expansionInfo) {
840
841 return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
842 .Case<TransposeOp>(caseFn: [&](TransposeOp transposeOp) {
843 return createExpandedTransposeOp(rewriter, transposeOp,
844 expandedInput: expandedOpOperands[0], output: outputs[0],
845 expansionInfo);
846 })
847 .Case<FillOp, CopyOp>(caseFn: [&](Operation *op) {
848 return clone(b&: rewriter, op: linalgOp, newResultTypes: resultTypes,
849 newOperands: llvm::to_vector(Range: llvm::concat<Value>(
850 Ranges: llvm::to_vector(Range&: expandedOpOperands),
851 Ranges: llvm::to_vector(Range&: outputs))));
852 })
853 .Default(defaultFn: [&](Operation *op) {
854 return createExpandedGenericOp(rewriter, linalgOp, resultTypes,
855 expandedOpOperands, outputs,
856 expansionInfo, expandedOpIndexingMaps);
857 });
858}
859
/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
/// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
/// that those conditions have been satisfied. Returns the values replacing the
/// results of `linalgOp`, or `std::nullopt` if the fusion cannot be performed.
static std::optional<SmallVector<Value>>
fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");

  Location loc = linalgOp.getLoc();
  // Normalize the two reshape directions: `expandedShape`,
  // `reassociationIndices` and `src` describe the expansion regardless of
  // whether `reshapeOp` is an expand_shape (consumer side) or a
  // collapse_shape (producer side).
  SmallVector<OpFoldResult> expandedShape;
  SmallVector<AffineMap, 4> reassociationIndices;
  Value src;
  if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(Val: reshapeOp)) {
    // Try to move the dynamic dimensions in output shape before the `linalgOp`
    // to maintain SSA validity
    if (failed(Result: moveValueDefinitions(
            rewriter, values: expandingReshapeOp.getOutputShape(), insertionPoint: linalgOp)))
      return std::nullopt;

    expandedShape = expandingReshapeOp.getMixedOutputShape();
    reassociationIndices = expandingReshapeOp.getReassociationMaps();
    src = expandingReshapeOp.getSrc();
  } else {
    auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(Val: reshapeOp);
    if (!collapsingReshapeOp)
      return std::nullopt;

    // For a collapse_shape the expanded shape is the (mixed) shape of its
    // source tensor.
    expandedShape = tensor::getMixedSizes(
        builder&: rewriter, loc: collapsingReshapeOp->getLoc(), value: collapsingReshapeOp.getSrc());
    reassociationIndices = collapsingReshapeOp.getReassociationMaps();
    src = collapsingReshapeOp.getSrc();
  }

  // Compute how each iteration-space dimension of `linalgOp` is expanded.
  ExpansionInfo expansionInfo;
  if (failed(Result: expansionInfo.compute(linalgOp, fusableOpOperand,
                                    reassociationMaps: reassociationIndices, expandedShape,
                                    rewriter)))
    return std::nullopt;

  // Indexing maps of the expanded op, derived from the original maps.
  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      Range: llvm::map_range(C: linalgOp.getIndexingMapsArray(), F: [&](AffineMap m) {
        return getIndexingMapInExpandedOp(builder&: rewriter, indexingMap: m, expansionInfo);
      }));

  // Set insertion point to the generic op.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(N: linalgOp.getNumDpsInputs());
  for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
    if (opOperand == fusableOpOperand) {
      // The fused operand is replaced directly by the reshape's source.
      expandedOpOperands.push_back(Elt: src);
      continue;
    }
    if (auto opOperandType =
            dyn_cast<RankedTensorType>(Val: opOperand->get().getType())) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
      SmallVector<OpFoldResult> expandedOperandShape;
      RankedTensorType expandedOperandType;
      std::tie(args&: expandedOperandShape, args&: expandedOperandType) =
          getExpandedShapeAndType(originalType: opOperandType, indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        if (failed(Result: reshapeLikeShapesAreCompatible(
                emitError: [&](const Twine &msg) {
                  return rewriter.notifyMatchFailure(arg&: linalgOp, msg);
                },
                collapsedShape: opOperandType.getShape(), expandedShape: expandedOperandType.getShape(),
                reassociationMaps: reassociation,
                /*isExpandingReshape=*/true)))
          return std::nullopt;
        expandedOpOperands.push_back(Elt: rewriter.create<tensor::ExpandShapeOp>(
            location: loc, args&: expandedOperandType, args: opOperand->get(), args&: reassociation,
            args&: expandedOperandShape));
        continue;
      }
    }
    expandedOpOperands.push_back(Elt: opOperand->get());
  }

  // Expand the outputs (inits) analogously to the inputs.
  SmallVector<Value> outputs;
  for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand: &opOperand);
    auto opOperandType = cast<RankedTensorType>(Val: opOperand.get().getType());
    SmallVector<OpFoldResult> expandedOutputShape;
    RankedTensorType expandedOutputType;
    std::tie(args&: expandedOutputShape, args&: expandedOutputType) =
        getExpandedShapeAndType(originalType: opOperandType, indexingMap, expansionInfo);
    if (expandedOutputType != opOperand.get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      if (failed(Result: reshapeLikeShapesAreCompatible(
              emitError: [&](const Twine &msg) {
                return rewriter.notifyMatchFailure(arg&: linalgOp, msg);
              },
              collapsedShape: opOperandType.getShape(), expandedShape: expandedOutputType.getShape(),
              reassociationMaps: reassociation,
              /*isExpandingReshape=*/true)))
        return std::nullopt;
      outputs.push_back(Elt: rewriter.create<tensor::ExpandShapeOp>(
          location: loc, args&: expandedOutputType, args: opOperand.get(), args&: reassociation,
          args&: expandedOutputShape));
    } else {
      outputs.push_back(Elt: opOperand.get());
    }
  }

  // Build the expanded op (a generic, or a named op when it can be preserved).
  TypeRange resultTypes = ValueRange(outputs).getTypes();
  Operation *fusedOp =
      createExpandedOp(rewriter, linalgOp, resultTypes, expandedOpOperands,
                       outputs, expandedOpIndexingMaps, expansionInfo);
  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : linalgOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              indexingMap: linalgOp.getMatchingIndexingMap(
                  opOperand: linalgOp.getDpsInitOperand(i: resultNumber)),
              expansionInfo);
      resultVals.push_back(Elt: rewriter.create<tensor::CollapseShapeOp>(
          location: linalgOp.getLoc(), args: opResult.getType(),
          args: fusedOp->getResult(idx: resultNumber), args&: reassociation));
    } else {
      resultVals.push_back(Elt: fusedOp->getResult(idx: resultNumber));
    }
  }
  // Assuming a single result.
  return resultVals;
}
997
998namespace {
999
1000/// Pattern to fuse a tensor.collapse_shape op with its consumer structured op,
1001/// when the reshape op is collapsing dimensions. The dimensionality of the loop
1002/// in the consumer is expanded.
1003class FoldWithProducerReshapeOpByExpansion
1004 : public OpInterfaceRewritePattern<LinalgOp> {
1005public:
1006 FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
1007 ControlFusionFn foldReshapes,
1008 PatternBenefit benefit = 1)
1009 : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
1010 controlFoldingReshapes(std::move(foldReshapes)) {}
1011
1012 LogicalResult matchAndRewrite(LinalgOp linalgOp,
1013 PatternRewriter &rewriter) const override {
1014 for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
1015 tensor::CollapseShapeOp reshapeOp =
1016 opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
1017 if (!reshapeOp)
1018 continue;
1019 // Fold only if
1020 // - The tensor reshape op is folding.
1021 // - All constraints of fusing with reshape by expansion are met.
1022 if (!isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand: opOperand) ||
1023 (!controlFoldingReshapes(opOperand)))
1024 continue;
1025
1026 std::optional<SmallVector<Value>> replacementValues =
1027 fuseWithReshapeByExpansion(linalgOp, reshapeOp, fusableOpOperand: opOperand, rewriter);
1028 if (!replacementValues)
1029 return failure();
1030 rewriter.replaceOp(op: linalgOp, newValues: *replacementValues);
1031 return success();
1032 }
1033 return failure();
1034 }
1035
1036private:
1037 ControlFusionFn controlFoldingReshapes;
1038};
1039
/// Pattern to fold a tensor.pad op with its producer tensor.collapse_shape op:
/// the pad is rewritten on the expanded (pre-collapse) tensor and a
/// collapse_shape is emitted after it. Collapsed dimension groups (size > 1)
/// may only carry zero padding.
class FoldPadWithProducerReshapeOpByExpansion
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
                                          ControlFusionFn foldReshapes,
                                          PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::CollapseShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!reshapeOp)
      return failure();
    if (!reshapeOp->hasOneUse())
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(arg&: padOp,
                                         msg: "fusion blocked by control function");
    }

    ArrayRef<int64_t> low = padOp.getStaticLow();
    ArrayRef<int64_t> high = padOp.getStaticHigh();
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();

    // Bail out if a truly collapsed group (size > 1) is padded; the pad could
    // not be distributed over the expanded dimensions.
    for (auto [reInd, l, h] : llvm::zip_equal(t&: reassociations, u&: low, args&: high)) {
      if (reInd.size() != 1 && (l != 0 || h != 0))
        return failure();
    }

    // Build the shape and pad amounts of the pad on the expanded tensor:
    // singleton groups take the padded size; groups of size > 1 keep their
    // source sizes and receive the (zero) pads of their collapsed dim.
    SmallVector<OpFoldResult> newLow, newHigh;
    RankedTensorType expandedType = reshapeOp.getSrcType();
    RankedTensorType paddedType = padOp.getResultType();
    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
    for (auto [idx, reInd] : llvm::enumerate(First&: reassociations)) {
      if (reInd.size() == 1) {
        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
      }
      for (size_t i = 0; i < reInd.size(); ++i) {
        newLow.push_back(Elt: padOp.getMixedLowPad()[idx]);
        newHigh.push_back(Elt: padOp.getMixedHighPad()[idx]);
      }
    }

    // Pad the expanded tensor, then collapse back to the original padded type.
    Location loc = padOp->getLoc();
    RankedTensorType expandedPaddedType = paddedType.clone(shape: expandedPaddedShape);
    auto newPadOp = rewriter.create<tensor::PadOp>(
        location: loc, args&: expandedPaddedType, args: reshapeOp.getSrc(), args&: newLow, args&: newHigh,
        args: padOp.getConstantPaddingValue(), args: padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        op: padOp, args: padOp.getResultType(), args: newPadOp.getResult(), args&: reassociations);

    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
1102
/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {

  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
                                      ControlFusionFn foldReshapes,
                                      PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by expansion are met.
    auto producerResult = dyn_cast<OpResult>(Val: reshapeOp.getSrc());
    if (!producerResult) {
      return rewriter.notifyMatchFailure(arg&: reshapeOp,
                                         msg: "source not produced by an operation");
    }

    auto producer = dyn_cast<LinalgOp>(Val: producerResult.getOwner());
    if (!producer) {
      return rewriter.notifyMatchFailure(arg&: reshapeOp,
                                         msg: "producer not a generic op");
    }

    // The expansion is driven by the init (output) operand corresponding to
    // the produced result.
    if (!isFusableWithReshapeByDimExpansion(
            linalgOp: producer,
            fusableOpOperand: producer.getDpsInitOperand(i: producerResult.getResultNumber()))) {
      return rewriter.notifyMatchFailure(
          arg&: reshapeOp, msg: "failed preconditions of fusion with producer generic op");
    }

    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
      return rewriter.notifyMatchFailure(arg&: reshapeOp,
                                         msg: "fusion blocked by control function");
    }

    std::optional<SmallVector<Value>> replacementValues =
        fuseWithReshapeByExpansion(
            linalgOp: producer, reshapeOp,
            fusableOpOperand: producer.getDpsInitOperand(i: producerResult.getResultNumber()),
            rewriter);
    if (!replacementValues) {
      return rewriter.notifyMatchFailure(arg&: reshapeOp,
                                         msg: "fusion by expansion failed");
    }

    // Find the replacement for the reshape op. Since the replacements have the
    // same type as the returns of the original generic op, the consumer reshape
    // op can be replaced by the source of the collapse_shape op that defines
    // the replacement.
    Value reshapeReplacement =
        (*replacementValues)[cast<OpResult>(Val: reshapeOp.getSrc())
                                 .getResultNumber()];
    if (auto collapseOp =
            reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
      reshapeReplacement = collapseOp.getSrc();
    }
    // Replace the reshape first, then the producer, so uses are rewired to the
    // expanded op's (collapsed) results.
    rewriter.replaceOp(op: reshapeOp, newValues: reshapeReplacement);
    rewriter.replaceOp(op: producer, newValues: *replacementValues);
    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
1170} // namespace
1171
1172//===---------------------------------------------------------------------===//
1173// Methods and patterns to fuse reshape with linalg.generic operations by
1174// contraction of dimensions.
1175//===---------------------------------------------------------------------===//
1176
1177/// For a given list of indices in the range of the `indexingMap` that are
1178/// folded, return the indices of the corresponding domain. Return
1179/// `std::nullopt` on failure. Ensures that all the elements of the returned
1180/// reassociation are distinct.
1181static ReassociationIndices
1182getDomainReassociation(AffineMap indexingMap,
1183 ReassociationIndicesRef rangeReassociation) {
1184 assert(indexingMap.isProjectedPermutation() &&
1185 "expected projected permutation");
1186
1187 ReassociationIndices domainReassociation = llvm::to_vector<4>(
1188 Range: llvm::map_range(C&: rangeReassociation, F: [&](int64_t pos) -> int64_t {
1189 return cast<AffineDimExpr>(Val: indexingMap.getResults()[pos]).getPosition();
1190 }));
1191 // The projected permutation semantics ensures that there is no repetition of
1192 // the domain indices.
1193 return domainReassociation;
1194}
1195
1196/// For a given `dimSequence`, check if the sequence is conserved in the
1197/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
1198/// Non-existence of the sequence returns true as well.
1199bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
1200 ReassociationIndicesRef dimSequence) {
1201 assert(!dimSequence.empty() &&
1202 "expected non-empty list for dimension sequence");
1203 assert(indexingMap.isProjectedPermutation() &&
1204 "expected indexing map to be projected permutation");
1205
1206 llvm::SmallDenseSet<unsigned, 4> sequenceElements;
1207 sequenceElements.insert_range(R&: dimSequence);
1208
1209 unsigned dimSequenceStart = dimSequence[0];
1210 for (const auto &expr : enumerate(First: indexingMap.getResults())) {
1211 unsigned dimInMapStart = cast<AffineDimExpr>(Val: expr.value()).getPosition();
1212 // 1. Check if this start of the sequence.
1213 if (dimInMapStart == dimSequenceStart) {
1214 if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
1215 return false;
1216 // 1a. Check if sequence is preserved.
1217 for (const auto &dimInSequence : enumerate(First&: dimSequence)) {
1218 unsigned dimInMap =
1219 cast<AffineDimExpr>(
1220 Val: indexingMap.getResult(idx: expr.index() + dimInSequence.index()))
1221 .getPosition();
1222 if (dimInMap != dimInSequence.value())
1223 return false;
1224 }
1225 // Found the sequence. Projected permutation
1226 // enforces that all AffineDimExprs in the result are unique, so no
1227 // further checks are needed.
1228 return true;
1229 }
1230 // 2. If position in the expr (which is of type AffineDimExpr) is part
1231 // of sequence, return false here. This implies the entire sequence does not
1232 // exist in the indexing map.
1233 if (sequenceElements.count(V: dimInMapStart))
1234 return false;
1235 }
1236 // 3. No element of sequence found. Return true.
1237 return true;
1238}
1239
1240bool mlir::linalg::areDimSequencesPreserved(
1241 ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences) {
1242 return llvm::all_of(Range&: maps, P: [&](AffineMap map) {
1243 return llvm::all_of(Range&: dimSequences, P: [&](ReassociationIndicesRef dimSequence) {
1244 return isDimSequencePreserved(indexingMap: map, dimSequence);
1245 });
1246 });
1247}
1248
// Return the list of dimensions of the iteration domain that can be
// collapsed to allow for fusion with a producer that is an expand_shape
// operation. If all dimensions created by expansion can be collapsed in the
// iteration space then the reshape is defunct.
//
// Example:
//
// ```mlir
// #map = affine_map<(d0, d1) -> (d0, d1)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<?x4xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
// ```
//
// can be fused by collapsing the dimensions of the iteration space.
//
// ```mlir
// #map = affine_map<(d0) -> (d0)>
// %2 = tensor.empty [..] : tensor<?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel"]}
//     ins(%1 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
// %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// ```
//
// In the following example,
//
// ```mlir
// #map0 = affine_map<(d0, d1) -> (d0, d1)>
// #map1 = affine_map<(d0, d1) -> (d1, d0)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<4x?xf32>
// %2 = linalg.generic {
//     indexing_maps = [#map0, #map1],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
// ```
//
// the reshape cannot be fused with the generic op by collapsing the op
// dimensions since the indexing maps will have to contain mods and divs
// to preserve the accesses pattern. When no dimensions of the iteration
// space are collapsible, an empty vector is returned.
static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                 ArrayRef<ReassociationIndices> reassociation) {
  // Some basic checks for this fusion to be valid.
  if (!genericOp.hasPureTensorSemantics())
    return {};

  // All indexing maps must be projected permutations so that folded range
  // dims translate 1:1 to domain dims.
  if (!llvm::all_of(Range: genericOp.getIndexingMapsArray(), P: [](AffineMap map) {
        return map.isProjectedPermutation();
      })) {
    return {};
  }

  // Compute all the loops with the reduction iterator types.
  SmallVector<unsigned> reductionDims;
  genericOp.getReductionDims(res&: reductionDims);

  llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand: fusableOperand);
  auto iteratorTypes = genericOp.getIteratorTypesArray();
  SmallVector<ReassociationIndices> iterationSpaceReassociation;
  for (ReassociationIndicesRef foldedRangeDims : reassociation) {
    assert(!foldedRangeDims.empty() && "unexpected empty reassociation");

    // Ignore dims that are not folded.
    if (foldedRangeDims.size() == 1)
      continue;

    // Translate the folded range dims into iteration-space (domain) dims.
    ReassociationIndices foldedIterationSpaceDims =
        getDomainReassociation(indexingMap, rangeReassociation: foldedRangeDims);

    // Check that the folded iteration dims do not contain already processed
    // dims.
    if (llvm::any_of(Range&: foldedIterationSpaceDims, P: [&](int64_t dim) {
          return processedIterationDims.count(V: dim);
        }))
      continue;

    // Check that all folded iterator types are all parallel or all reductions.
    utils::IteratorType startIteratorType =
        iteratorTypes[foldedIterationSpaceDims[0]];
    if (!isParallelIterator(iteratorType: startIteratorType) &&
        !isReductionIterator(iteratorType: startIteratorType))
      continue;
    if (llvm::any_of(Range&: foldedIterationSpaceDims, P: [&](int64_t dim) {
          return iteratorTypes[dim] != startIteratorType;
        }))
      continue;

    // If the folded dimensions correspond to a "reduction" iterator type,
    // the folded dimensions need to be "in-order". Strictly speaking this is
    // not necessary, for reductions that are associative and commutative, but
    // using a more strict definition of reduction for now.
    if (isReductionIterator(iteratorType: startIteratorType)) {
      bool isContiguous = false;
      for (const auto &startDim : llvm::enumerate(First&: reductionDims)) {
        // Move window in `reductionDims` to start of the folded iteration dims.
        if (startDim.value() != foldedIterationSpaceDims[0])
          continue;
        // If sizes don't match, it is trivially not contiguous. This condition
        // should not be hit.
        if (startDim.index() + foldedIterationSpaceDims.size() >
            reductionDims.size())
          break;
        // Check that the contiguity is maintained.
        isContiguous = true;
        for (const auto &foldedDim :
             llvm::enumerate(First&: foldedIterationSpaceDims)) {
          if (reductionDims[foldedDim.index() + startDim.index()] !=
              foldedDim.value()) {
            isContiguous = false;
            break;
          }
        }
        break;
      }
      if (!isContiguous)
        continue;
    }

    // Check that the sequence is preserved in all indexing maps.
    if (llvm::any_of(Range: genericOp.getIndexingMapsArray(),
                     P: [&](AffineMap indexingMap) {
                       return !isDimSequencePreserved(indexingMap,
                                                      dimSequence: foldedIterationSpaceDims);
                     }))
      continue;

    processedIterationDims.insert_range(R&: foldedIterationSpaceDims);
    iterationSpaceReassociation.emplace_back(
        Args: std::move(foldedIterationSpaceDims));
  }

  return iterationSpaceReassociation;
}
1390
1391/// Helper class to carry state while collapsing the `linalg.generic` op.
1392namespace {
1393class CollapsingInfo {
1394public:
1395 LogicalResult initialize(unsigned origNumLoops,
1396 ArrayRef<ReassociationIndices> foldedIterationDims) {
1397 llvm::SmallDenseSet<int64_t, 4> processedDims;
1398 // Find all the dims that are folded.
1399 for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
1400 if (foldedIterationDim.empty())
1401 continue;
1402 // If the folded dims contain dims already folded, that's illegal
1403 // specification. Repetition within a list is also illegal.
1404 for (auto dim : foldedIterationDim) {
1405 if (dim >= origNumLoops)
1406 return failure();
1407 if (processedDims.count(V: dim))
1408 return failure();
1409 processedDims.insert(V: dim);
1410 }
1411 collapsedOpToOrigOpIterationDim.emplace_back(Args: foldedIterationDim.begin(),
1412 Args: foldedIterationDim.end());
1413 }
1414 if (processedDims.size() > origNumLoops)
1415 return failure();
1416
1417 // Add all the preserved dims of the original op as single
1418 // elements to `collapsedOpToOrigOpIterationDim`.
1419 for (auto dim : llvm::seq<int64_t>(Begin: 0, End: origNumLoops)) {
1420 if (processedDims.count(V: dim))
1421 continue;
1422 collapsedOpToOrigOpIterationDim.emplace_back(Args: ReassociationIndices{dim});
1423 }
1424
1425 llvm::sort(C&: collapsedOpToOrigOpIterationDim,
1426 Comp: [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
1427 return lhs[0] < rhs[0];
1428 });
1429 origOpToCollapsedOpIterationDim.resize(N: origNumLoops);
1430 for (const auto &foldedDims :
1431 llvm::enumerate(First&: collapsedOpToOrigOpIterationDim)) {
1432 for (const auto &dim : enumerate(First&: foldedDims.value()))
1433 origOpToCollapsedOpIterationDim[dim.value()] =
1434 std::make_pair<int64_t, unsigned>(x: foldedDims.index(), y: dim.index());
1435 }
1436 return success();
1437 }
1438
1439 /// Return mapping from collapsed loop domain to original loop domain.
1440 ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
1441 return collapsedOpToOrigOpIterationDim;
1442 }
1443
  /// Return mapping from original loop domain to collapsed loop domain. The
  /// mapping is a pair. First value is the dimension in the collapsed loop that
  /// the original loop is mapped to. Second is the relative position in folded
  /// list of this domain. For example if the original loop domain is 5D, and
  /// the collapsed loop domain folds it into two dimensions, i.e.
  ///
  /// ```
  /// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
  /// ```
  ///
  /// then
  ///
  /// ```
  /// origOpToCollapsedOpMapping[0] = {0, 0};
  /// origOpToCollapsedOpMapping[1] = {0, 1};
  /// origOpToCollapsedOpMapping[2] = {0, 2};
  /// origOpToCollapsedOpMapping[3] = {1, 0};
  /// origOpToCollapsedOpMapping[4] = {1, 1};
  /// ```
  ///
  ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
    return origOpToCollapsedOpIterationDim;
  }
1467
  /// Return the collapsed op iteration domain rank, i.e. the number of
  /// reassociation groups (one collapsed loop per group).
  unsigned getCollapsedOpIterationRank() const {
    return collapsedOpToOrigOpIterationDim.size();
  }
1472
private:
  /// Map from the iteration domain index in collapsed op to the iteration
  /// domain indices in the original op (one reassociation group per collapsed
  /// dimension, sorted by leading original dimension).
  SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;

  /// Map from iteration domain index in the original op to the pair
  /// (collapsed dimension index, position within that folded group).
  SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
1481};
1482} // namespace
1483
1484/// Get the iterator types for the collapsed operation given the original
1485/// iterator types and collapsed dimensions.
1486static SmallVector<utils::IteratorType>
1487getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
1488 const CollapsingInfo &collapsingInfo) {
1489 SmallVector<utils::IteratorType> collapsedIteratorTypes;
1490 for (ReassociationIndicesRef foldedIterDims :
1491 collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1492 assert(!foldedIterDims.empty() &&
1493 "reassociation indices expected to have non-empty sets");
1494 // Just pick the iterator type of the first folded dim. Pre-condition checks
1495 // expected to have checked that iterator types of all folded dimensions are
1496 // the same.
1497 collapsedIteratorTypes.push_back(Elt: iteratorTypes[foldedIterDims[0]]);
1498 }
1499 return collapsedIteratorTypes;
1500}
1501
1502/// Compute the indexing map in the collapsed op that corresponds to the given
1503/// `indexingMap` of the original operation.
1504static AffineMap
1505getCollapsedOpIndexingMap(AffineMap indexingMap,
1506 const CollapsingInfo &collapsingInfo) {
1507 MLIRContext *context = indexingMap.getContext();
1508 assert(indexingMap.isProjectedPermutation() &&
1509 "expected indexing map to be projected permutation");
1510 SmallVector<AffineExpr> resultExprs;
1511 auto origOpToCollapsedOpMapping =
1512 collapsingInfo.getOrigOpToCollapsedOpMapping();
1513 for (auto expr : indexingMap.getResults()) {
1514 unsigned dim = cast<AffineDimExpr>(Val&: expr).getPosition();
1515 // If the dim is not the first of the collapsed dim, do nothing.
1516 if (origOpToCollapsedOpMapping[dim].second != 0)
1517 continue;
1518 // The next n-dims are guaranteed to be collapsed. So just use the
1519 // iteration dimension of the collapsed op.
1520 resultExprs.push_back(
1521 Elt: getAffineDimExpr(position: origOpToCollapsedOpMapping[dim].first, context));
1522 }
1523 return AffineMap::get(dimCount: collapsingInfo.getCollapsedOpIterationRank(), symbolCount: 0,
1524 results: resultExprs, context);
1525}
1526
1527/// Return the `reassociation` indices to use to collapse the operand when the
1528/// iteration space of a generic op is collapsed.
1529static SmallVector<ReassociationIndices>
1530getOperandReassociation(AffineMap indexingMap,
1531 const CollapsingInfo &collapsingInfo) {
1532 unsigned counter = 0;
1533 SmallVector<ReassociationIndices> operandReassociation;
1534 auto origOpToCollapsedOpMapping =
1535 collapsingInfo.getOrigOpToCollapsedOpMapping();
1536 auto collapsedOpToOrigOpMapping =
1537 collapsingInfo.getCollapsedOpToOrigOpMapping();
1538 while (counter < indexingMap.getNumResults()) {
1539 unsigned dim =
1540 cast<AffineDimExpr>(Val: indexingMap.getResult(idx: counter)).getPosition();
1541 // This is the start of a collapsed dimensions of the iteration that
1542 // is gauranteed to be preserved in the indexing map. The number of folded
1543 // dims is obtained from the collapsed op to original op mapping.
1544 unsigned numFoldedDims =
1545 collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1546 .size();
1547 if (origOpToCollapsedOpMapping[dim].second == 0) {
1548 auto range = llvm::seq<unsigned>(Begin: counter, End: counter + numFoldedDims);
1549 operandReassociation.emplace_back(Args: range.begin(), Args: range.end());
1550 }
1551 counter += numFoldedDims;
1552 }
1553 return operandReassociation;
1554}
1555
1556/// Get the new value to use for a given `OpOperand` in the collapsed operation.
1557static Value getCollapsedOpOperand(Location loc, LinalgOp op,
1558 OpOperand *opOperand,
1559 const CollapsingInfo &collapsingInfo,
1560 OpBuilder &builder) {
1561 AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
1562 SmallVector<ReassociationIndices> operandReassociation =
1563 getOperandReassociation(indexingMap, collapsingInfo);
1564
1565 // If the number of entries in the reassociation for the operand is same as
1566 // the number of results of the indexing map, then nothing to do for this
1567 // operand.
1568 Value operand = opOperand->get();
1569 if (operandReassociation.size() == indexingMap.getNumResults())
1570 return operand;
1571
1572 // Insert a reshape to collapse the dimensions.
1573 if (isa<MemRefType>(Val: operand.getType())) {
1574 return builder
1575 .create<memref::CollapseShapeOp>(location: loc, args&: operand, args&: operandReassociation)
1576 .getResult();
1577 }
1578 return builder
1579 .create<tensor::CollapseShapeOp>(location: loc, args&: operand, args&: operandReassociation)
1580 .getResult();
1581}
1582
1583/// Modify the `linalg.index` operations in the original generic op, to its
1584/// value in the collapsed operation.
/// Modify the `linalg.index` operations in the original generic op, to its
/// value in the collapsed operation. `loopRange` holds the sizes of the
/// original (uncollapsed) iteration domain.
static void generateCollapsedIndexingRegion(
    Location loc, Block *block, const CollapsingInfo &collapsingInfo,
    ArrayRef<OpFoldResult> loopRange, RewriterBase &rewriter) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPointToStart(block);

  // Collect all the original index ops.
  auto indexOps = llvm::to_vector(Range: block->getOps<linalg::IndexOp>());

  // For each folded dimension list resolve the original induction variable
  // values in terms of the folded dimension induction variable.
  //   i_{folded} = (i_0 * d1 + i1) * d2 + i2.
  // can be inverted to
  //   i2 = i_{folded} % d2
  //   i1 = (i_{folded} / d2) % d1
  //   i0 = i_{folded} / (d1 * d2)
  llvm::DenseMap<unsigned, Value> indexReplacementVals;
  for (auto foldedDims :
       enumerate(First: collapsingInfo.getCollapsedOpToOrigOpMapping())) {
    ReassociationIndicesRef foldedDimsRef(foldedDims.value());
    // One linalg.index for the collapsed dimension; peel the original
    // indices off it back-to-front with rem/div as derived above.
    Value newIndexVal =
        rewriter.create<linalg::IndexOp>(location: loc, args: foldedDims.index());
    for (auto dim : llvm::reverse(C: foldedDimsRef.drop_front())) {
      Value loopDim =
          getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: loopRange[dim]);
      indexReplacementVals[dim] =
          rewriter.createOrFold<arith::RemSIOp>(location: loc, args&: newIndexVal, args&: loopDim);
      newIndexVal =
          rewriter.createOrFold<arith::DivSIOp>(location: loc, args&: newIndexVal, args&: loopDim);
    }
    // Whatever remains after peeling is the leading original index.
    indexReplacementVals[foldedDims.value().front()] = newIndexVal;
  }

  // Replace the original index ops with the reconstructed values.
  for (auto indexOp : indexOps) {
    auto dim = indexOp.getDim();
    rewriter.replaceOp(op: indexOp, newValues: indexReplacementVals[dim]);
  }
}
1623
1624void collapseOperandsAndResults(LinalgOp op,
1625 const CollapsingInfo &collapsingInfo,
1626 RewriterBase &rewriter,
1627 SmallVectorImpl<Value> &inputOperands,
1628 SmallVectorImpl<Value> &outputOperands,
1629 SmallVectorImpl<Type> &resultTypes) {
1630 Location loc = op->getLoc();
1631 inputOperands =
1632 llvm::map_to_vector(C: op.getDpsInputOperands(), F: [&](OpOperand *opOperand) {
1633 return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1634 builder&: rewriter);
1635 });
1636
1637 // Get the output operands and result types.
1638 resultTypes.reserve(N: op.getNumDpsInits());
1639 outputOperands.reserve(N: op.getNumDpsInits());
1640 for (OpOperand &output : op.getDpsInitsMutable()) {
1641 Value newOutput =
1642 getCollapsedOpOperand(loc, op, opOperand: &output, collapsingInfo, builder&: rewriter);
1643 outputOperands.push_back(Elt: newOutput);
1644 // If the op has "buffer semantics", then the init operands are ranked
1645 // memrefs and the op has no results.
1646 if (!op.hasPureBufferSemantics())
1647 resultTypes.push_back(Elt: newOutput.getType());
1648 }
1649}
1650
1651/// Clone a `LinalgOp` to a collapsed version of same name
/// Clone a `LinalgOp` to a collapsed version of same name.
/// Primary template: no specialization available for this op type, so signal
/// "cannot collapse" by returning a null op.
template <typename OpTy>
OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
                        const CollapsingInfo &collapsingInfo) {
  return nullptr;
}
1657
1658/// Collapse any `LinalgOp` that does not require any specialization such as
1659/// indexing_maps, iterator_types, etc.
1660template <>
1661LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
1662 const CollapsingInfo &collapsingInfo) {
1663 SmallVector<Value> inputOperands, outputOperands;
1664 SmallVector<Type> resultTypes;
1665 collapseOperandsAndResults(op: origOp, collapsingInfo, rewriter, inputOperands,
1666 outputOperands, resultTypes);
1667
1668 return clone(
1669 b&: rewriter, op: origOp, newResultTypes: resultTypes,
1670 newOperands: llvm::to_vector(Range: llvm::concat<Value>(Ranges&: inputOperands, Ranges&: outputOperands)));
1671}
1672
1673/// Collapse a `GenericOp`
1674template <>
1675GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
1676 GenericOp origOp,
1677 const CollapsingInfo &collapsingInfo) {
1678 SmallVector<Value> inputOperands, outputOperands;
1679 SmallVector<Type> resultTypes;
1680 collapseOperandsAndResults(op: origOp, collapsingInfo, rewriter, inputOperands,
1681 outputOperands, resultTypes);
1682 SmallVector<AffineMap> indexingMaps(
1683 llvm::map_range(C: origOp.getIndexingMapsArray(), F: [&](AffineMap map) {
1684 return getCollapsedOpIndexingMap(indexingMap: map, collapsingInfo);
1685 }));
1686
1687 SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
1688 iteratorTypes: origOp.getIteratorTypesArray(), collapsingInfo));
1689
1690 GenericOp collapsedOp = rewriter.create<linalg::GenericOp>(
1691 location: origOp.getLoc(), args&: resultTypes, args&: inputOperands, args&: outputOperands, args&: indexingMaps,
1692 args&: iteratorTypes, args: [](OpBuilder &builder, Location loc, ValueRange args) {});
1693 Block *origOpBlock = &origOp->getRegion(index: 0).front();
1694 Block *collapsedOpBlock = &collapsedOp->getRegion(index: 0).front();
1695 rewriter.mergeBlocks(source: origOpBlock, dest: collapsedOpBlock,
1696 argValues: collapsedOpBlock->getArguments());
1697 return collapsedOp;
1698}
1699
1700LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
1701 RewriterBase &rewriter) {
1702 if (GenericOp genericOp = dyn_cast<GenericOp>(Val: op.getOperation())) {
1703 return cloneToCollapsedOp(rewriter, origOp: genericOp, collapsingInfo);
1704 } else {
1705 return cloneToCollapsedOp(rewriter, origOp: op, collapsingInfo);
1706 }
1707}
1708
1709/// Implementation of fusion with reshape operation by collapsing dimensions.
1710FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
1711 LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
1712 RewriterBase &rewriter) {
1713 // Bail on trivial no-op cases.
1714 if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1715 llvm::all_of(Range&: foldedIterationDims, P: [](ReassociationIndicesRef foldedDims) {
1716 return foldedDims.size() <= 1;
1717 }))
1718 return failure();
1719
1720 CollapsingInfo collapsingInfo;
1721 if (failed(
1722 Result: collapsingInfo.initialize(origNumLoops: op.getNumLoops(), foldedIterationDims))) {
1723 return rewriter.notifyMatchFailure(
1724 arg&: op, msg: "illegal to collapse specified dimensions");
1725 }
1726
1727 bool hasPureBufferSemantics = op.hasPureBufferSemantics();
1728 if (hasPureBufferSemantics &&
1729 !llvm::all_of(Range: op->getOpOperands(), P: [&](OpOperand &opOperand) -> bool {
1730 MemRefType memRefToCollapse =
1731 dyn_cast<MemRefType>(Val: opOperand.get().getType());
1732 if (!memRefToCollapse)
1733 return true;
1734
1735 AffineMap indexingMap = op.getMatchingIndexingMap(opOperand: &opOperand);
1736 SmallVector<ReassociationIndices> operandReassociation =
1737 getOperandReassociation(indexingMap, collapsingInfo);
1738 return memref::CollapseShapeOp::isGuaranteedCollapsible(
1739 srcType: memRefToCollapse, reassociation: operandReassociation);
1740 }))
1741 return rewriter.notifyMatchFailure(arg&: op,
1742 msg: "memref is not guaranteed collapsible");
1743
1744 // Bail on non-canonical ranges.
1745 SmallVector<Range> loopRanges = op.createLoopRanges(b&: rewriter, loc: op.getLoc());
1746 auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
1747 if (auto attr = llvm::dyn_cast_if_present<Attribute>(Val&: ofr))
1748 return cast<IntegerAttr>(Val&: attr).getInt() == value;
1749 llvm::APInt actual;
1750 return matchPattern(value: cast<Value>(Val&: ofr), pattern: m_ConstantInt(bind_value: &actual)) &&
1751 actual.getSExtValue() == value;
1752 };
1753 if (!llvm::all_of(Range&: loopRanges, P: [&](Range range) {
1754 return opFoldIsConstantValue(range.offset, 0) &&
1755 opFoldIsConstantValue(range.stride, 1);
1756 })) {
1757 return rewriter.notifyMatchFailure(
1758 arg&: op, msg: "expected all loop ranges to have zero start and unit stride");
1759 }
1760
1761 LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
1762
1763 Location loc = op->getLoc();
1764 SmallVector<OpFoldResult> loopBound =
1765 llvm::map_to_vector(C&: loopRanges, F: [](Range range) { return range.size; });
1766
1767 if (collapsedOp.hasIndexSemantics()) {
1768 // Collect the loop range of the generic op.
1769 OpBuilder::InsertionGuard g(rewriter);
1770 rewriter.setInsertionPoint(collapsedOp);
1771 generateCollapsedIndexingRegion(loc, block: &collapsedOp->getRegion(index: 0).front(),
1772 collapsingInfo, loopRange: loopBound, rewriter);
1773 }
1774
1775 // Insert expanding reshape for the result to get back the original result
1776 // type.
1777 SmallVector<Value> results;
1778 for (const auto &originalResult : llvm::enumerate(First: op->getResults())) {
1779 Value collapsedOpResult = collapsedOp->getResult(idx: originalResult.index());
1780 auto originalResultType =
1781 cast<ShapedType>(Val: originalResult.value().getType());
1782 auto collapsedOpResultType = cast<ShapedType>(Val: collapsedOpResult.getType());
1783 if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1784 AffineMap indexingMap =
1785 op.getIndexingMapMatchingResult(result: originalResult.value());
1786 SmallVector<ReassociationIndices> reassociation =
1787 getOperandReassociation(indexingMap, collapsingInfo);
1788 assert(
1789 indexingMap.isProjectedPermutation() &&
1790 "Expected indexing map to be a projected permutation for collapsing");
1791 SmallVector<OpFoldResult> resultShape =
1792 applyPermutationMap(map: indexingMap, source: ArrayRef(loopBound));
1793 Value result;
1794 if (isa<MemRefType>(Val: collapsedOpResult.getType())) {
1795 MemRefType expandShapeResultType = MemRefType::get(
1796 shape: originalResultType.getShape(), elementType: originalResultType.getElementType());
1797 result = rewriter.create<memref::ExpandShapeOp>(
1798 location: loc, args&: expandShapeResultType, args&: collapsedOpResult, args&: reassociation,
1799 args&: resultShape);
1800 } else {
1801 result = rewriter.create<tensor::ExpandShapeOp>(
1802 location: loc, args&: originalResultType, args&: collapsedOpResult, args&: reassociation,
1803 args&: resultShape);
1804 }
1805 results.push_back(Elt: result);
1806 } else {
1807 results.push_back(Elt: collapsedOpResult);
1808 }
1809 }
1810 return CollapseResult{.results: results, .collapsedOp: collapsedOp};
1811}
1812
1813namespace {
1814
1815/// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
1816/// contracting dimensions of the loop.
/// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
/// contracting dimensions of the loop.
class FoldWithProducerReshapeOpByCollapsing
    : public OpRewritePattern<GenericOp> {
public:
  // TODO : support fusion with all linalg ops, not just generic.
  FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
                                        ControlFusionFn foldReshapes,
                                        PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Scan the operands for one produced by a tensor.expand_shape; fuse with
    // the first candidate that passes the preconditions and control function.
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      tensor::ExpandShapeOp reshapeOp =
          opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
      if (!reshapeOp)
        continue;

      SmallVector<ReassociationIndices> collapsableIterationDims =
          getCollapsableIterationSpaceDims(genericOp, fusableOperand: &opOperand,
                                           reassociation: reshapeOp.getReassociationIndices());
      if (collapsableIterationDims.empty() ||
          !controlFoldingReshapes(&opOperand)) {
        continue;
      }

      std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
          op: genericOp, foldedIterationDims: collapsableIterationDims, rewriter);
      if (!collapseResult) {
        return rewriter.notifyMatchFailure(
            arg&: genericOp, msg: "failed to do the fusion by collapsing transformation");
      }

      rewriter.replaceOp(op: genericOp, newValues: collapseResult->results);
      return success();
    }
    return failure();
  }

private:
  /// Callback that lets the caller veto folding for a given operand.
  ControlFusionFn controlFoldingReshapes;
};
1859
1860/// Pattern to fold a tensor.collapse_shape op with its producer generic op
1861/// by expanding the dimensionality of the loop in the producer op.
/// Pattern to fold a tensor.collapse_shape op with its producer generic op
/// by collapsing the dimensions of the iteration space in the producer op.
struct FoldReshapeWithGenericOpByCollapsing
    : public OpRewritePattern<tensor::CollapseShapeOp> {

  FoldReshapeWithGenericOpByCollapsing(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by collapsing are
    // met.
    auto producerResult = dyn_cast<OpResult>(Val: reshapeOp.getSrc());
    if (!producerResult) {
      return rewriter.notifyMatchFailure(arg&: reshapeOp,
                                         msg: "source not produced by an operation");
    }

    // TODO : support fusion with all linalg producers, not just generic.
    auto producer = dyn_cast<GenericOp>(Val: producerResult.getOwner());
    if (!producer) {
      return rewriter.notifyMatchFailure(arg&: reshapeOp,
                                         msg: "producer not a generic op");
    }

    // The fusable operand is the producer init that feeds this result.
    SmallVector<ReassociationIndices> collapsableIterationDims =
        getCollapsableIterationSpaceDims(
            genericOp: producer,
            fusableOperand: producer.getDpsInitOperand(i: producerResult.getResultNumber()),
            reassociation: reshapeOp.getReassociationIndices());
    if (collapsableIterationDims.empty()) {
      return rewriter.notifyMatchFailure(
          arg&: reshapeOp, msg: "failed preconditions of fusion with producer generic op");
    }

    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
      return rewriter.notifyMatchFailure(arg&: reshapeOp,
                                         msg: "fusion blocked by control function");
    }

    // Set the insertion point after `producer` because there could be uses
    // of `producer` between it and the `tensor.collapse_shape` op.
    rewriter.setInsertionPointAfter(producer);
    std::optional<CollapseResult> collapseResult =
        collapseOpIterationDims(op: producer, foldedIterationDims: collapsableIterationDims, rewriter);
    if (!collapseResult) {
      return rewriter.notifyMatchFailure(
          arg&: producer, msg: "failed to do the fusion by collapsing transformation");
    }

    rewriter.replaceOp(op: producer, newValues: collapseResult->results);
    return success();
  }

private:
  /// Callback that lets the caller veto folding for the reshape source.
  ControlFusionFn controlFoldingReshapes;
};
1920
/// Pattern to fold a tensor.expand_shape into a following tensor.pad, i.e.
/// rewrite `pad(expand_shape(x))` as `expand_shape(pad(x))`. Only legal when
/// no expanded (multi-index) group is padded, so padding commutes with the
/// reshape.
class FoldPadWithProducerReshapeOpByCollapsing
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
                                           ControlFusionFn foldReshapes,
                                           PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::ExpandShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
    if (!reshapeOp)
      return failure();
    if (!reshapeOp->hasOneUse())
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(arg&: padOp,
                                         msg: "fusion blocked by control function");
    }

    ArrayRef<int64_t> low = padOp.getStaticLow();
    ArrayRef<int64_t> high = padOp.getStaticHigh();
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();

    // Bail out if any expanded group (size > 1) carries padding: the pad
    // could not be expressed on the collapsed tensor.
    for (auto reInd : reassociations) {
      if (reInd.size() == 1)
        continue;
      if (llvm::any_of(Range&: reInd, P: [&](int64_t ind) {
            return low[ind] != 0 || high[ind] != 0;
          })) {
        return failure();
      }
    }

    // Compute the low/high padding for the collapsed pad and the output
    // sizes of the new trailing expand_shape. Only singleton groups can be
    // padded, and their padded size is low + high + original size.
    SmallVector<OpFoldResult> newLow, newHigh;
    RankedTensorType collapsedType = reshapeOp.getSrcType();
    RankedTensorType paddedType = padOp.getResultType();
    SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
    SmallVector<OpFoldResult> expandedPaddedSizes(
        getMixedValues(staticValues: reshapeOp.getStaticOutputShape(),
                       dynamicValues: reshapeOp.getOutputShape(), b&: rewriter));
    AffineExpr d0, d1, d2;
    bindDims(ctx: rewriter.getContext(), exprs&: d0, exprs&: d1, exprs&: d2);
    auto addMap = AffineMap::get(dimCount: 3, symbolCount: 0, result: {d0 + d1 + d2});
    Location loc = reshapeOp->getLoc();
    for (auto [idx, reInd] : llvm::enumerate(First&: reassociations)) {
      OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
      OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
      if (reInd.size() == 1) {
        collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
        OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
            b&: rewriter, loc, map: addMap, operands: {l, h, expandedPaddedSizes[reInd[0]]});
        expandedPaddedSizes[reInd[0]] = paddedSize;
      }
      newLow.push_back(Elt: l);
      newHigh.push_back(Elt: h);
    }

    // Pad the collapsed source, then expand back to the padded result type.
    RankedTensorType collapsedPaddedType =
        paddedType.clone(shape: collapsedPaddedShape);
    auto newPadOp = rewriter.create<tensor::PadOp>(
        location: loc, args&: collapsedPaddedType, args: reshapeOp.getSrc(), args&: newLow, args&: newHigh,
        args: padOp.getConstantPaddingValue(), args: padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        op: padOp, args: padOp.getResultType(), args: newPadOp.getResult(), args&: reassociations,
        args&: expandedPaddedSizes);

    return success();
  }

private:
  /// Callback that lets the caller veto folding for the pad source.
  ControlFusionFn controlFoldingReshapes;
};
1999
2000/// Pattern to collapse dimensions.
/// Pattern to collapse dimensions of a linalg op as directed by a
/// user-provided callback that selects the reassociation groups.
template <typename LinalgType>
class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
public:
  CollapseLinalgDimensions(MLIRContext *context,
                           GetCollapsableDimensionsFn collapseDimensions,
                           PatternBenefit benefit = 1)
      : OpRewritePattern<LinalgType>(context, benefit),
        controlCollapseDimension(std::move(collapseDimensions)) {}

  LogicalResult matchAndRewrite(LinalgType op,
                                PatternRewriter &rewriter) const override {
    // Ask the callback which dimensions to collapse; empty means "skip".
    SmallVector<ReassociationIndices> collapsableIterationDims =
        controlCollapseDimension(op);
    if (collapsableIterationDims.empty())
      return failure();

    // Check if the specified list of dimensions to collapse is a valid list,
    // i.e. each folded sequence appears contiguously in every indexing map.
    if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
                                  collapsableIterationDims)) {
      return rewriter.notifyMatchFailure(
          op, "specified dimensions cannot be collapsed");
    }

    std::optional<CollapseResult> collapseResult =
        collapseOpIterationDims(op, collapsableIterationDims, rewriter);
    if (!collapseResult) {
      return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
    }
    rewriter.replaceOp(op, collapseResult->results);
    return success();
  }

private:
  /// Callback returning the reassociation groups to collapse for an op.
  GetCollapsableDimensionsFn controlCollapseDimension;
};
2036
2037} // namespace
2038
2039//===---------------------------------------------------------------------===//
2040// Methods and patterns that fuse constants with linalg.generic operations.
2041//===---------------------------------------------------------------------===//
2042
2043namespace {
2044/// Pattern to fold a generic op with a splat constant/scalar constant. Does not
2045/// handle cases where the constant is not single-valued.
/// Pattern to fold a generic op with a splat constant/scalar constant. Does not
/// handle cases where the constant is not single-valued. The constant operand
/// is dropped and its value is inlined into the payload as an arith.constant.
class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
public:
  FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasPureTensorSemantics())
      return failure();
    for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
      Operation *def = opOperand->get().getDefiningOp();
      TypedAttr constantAttr;
      // Recognize splat dense constants and scalar int/float constants,
      // capturing the single value in `constantAttr`.
      auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
        {
          DenseElementsAttr splatAttr;
          if (matchPattern(op: def, pattern: m_Constant<DenseElementsAttr>(bind_value: &splatAttr)) &&
              splatAttr.isSplat() &&
              splatAttr.getType().getElementType().isIntOrFloat()) {
            constantAttr = splatAttr.getSplatValue<TypedAttr>();
            return true;
          }
        }
        {
          IntegerAttr intAttr;
          if (matchPattern(op: def, pattern: m_Constant<IntegerAttr>(bind_value: &intAttr))) {
            constantAttr = intAttr;
            return true;
          }
        }
        {
          FloatAttr floatAttr;
          if (matchPattern(op: def, pattern: m_Constant<FloatAttr>(bind_value: &floatAttr))) {
            constantAttr = floatAttr;
            return true;
          }
        }
        return false;
      };

      auto resultValue = dyn_cast<OpResult>(Val: opOperand->get());
      if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
        continue;

      // The operands and the indexing_maps of the fused operation the same as
      // the operands and indexing_maps of the generic operations with the
      // values at the constant index dropped.
      SmallVector<AffineMap> fusedIndexMaps;
      SmallVector<Value> fusedOperands;
      SmallVector<Location> fusedLocs{genericOp.getLoc()};
      fusedIndexMaps.reserve(N: genericOp->getNumOperands());
      fusedOperands.reserve(N: genericOp.getNumDpsInputs());
      fusedLocs.reserve(N: fusedLocs.size() + genericOp.getNumDpsInputs());
      for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
        if (inputOperand == opOperand)
          continue;
        Value inputValue = inputOperand->get();
        fusedIndexMaps.push_back(
            Elt: genericOp.getMatchingIndexingMap(opOperand: inputOperand));
        fusedOperands.push_back(Elt: inputValue);
        fusedLocs.push_back(Elt: inputValue.getLoc());
      }
      for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
        fusedIndexMaps.push_back(
            Elt: genericOp.getMatchingIndexingMap(opOperand: &outputOperand));

      // Check if the operation shapes to loops map is computable.
      if (!inversePermutation(
              map: concatAffineMaps(maps: fusedIndexMaps, context: rewriter.getContext()))) {
        return rewriter.notifyMatchFailure(
            arg&: genericOp, msg: "fused op loop bound computation failed");
      }

      // Create a constant scalar value from the splat constant.
      Value scalarConstant =
          rewriter.create<arith::ConstantOp>(location: def->getLoc(), args&: constantAttr);

      SmallVector<Value> outputOperands = genericOp.getOutputs();
      auto fusedOp = rewriter.create<GenericOp>(
          location: rewriter.getFusedLoc(locs: fusedLocs), args: genericOp->getResultTypes(),
          /*inputs=*/args&: fusedOperands,
          /*outputs=*/args&: outputOperands,
          args: rewriter.getAffineMapArrayAttr(values: fusedIndexMaps),
          args: genericOp.getIteratorTypes(),
          /*doc=*/args: nullptr,
          /*library_call=*/args: nullptr);

      // Map the block argument corresponding to the replaced argument with the
      // scalar constant.
      Region &region = genericOp->getRegion(index: 0);
      Block &entryBlock = *region.begin();
      IRMapping mapping;
      mapping.map(from: entryBlock.getArgument(i: opOperand->getOperandNumber()),
                  to: scalarConstant);
      Region &fusedRegion = fusedOp->getRegion(index: 0);
      rewriter.cloneRegionBefore(region, parent&: fusedRegion, before: fusedRegion.begin(),
                                 mapping);
      rewriter.replaceOp(op: genericOp, newValues: fusedOp->getResults());
      return success();
    }
    return failure();
  }
};
2148
2149} // namespace
2150
2151//===---------------------------------------------------------------------===//
2152// Miscellaneous patterns that help fusion.
2153//===---------------------------------------------------------------------===//
2154
2155namespace {
2156/// Forces `outs` operands of linalg operations to use `tensor.empty` if the
2157/// value of the `outs` operand is not used within the op. This is only
2158/// implemented for `linalg.generic` operations for now, but should hold for all
2159/// linalg structured ops.
/// Forces `outs` operands of linalg operations to use `tensor.empty` if the
/// value of the `outs` operand is not used within the op. This is only
/// implemented for `linalg.generic` operations for now, but should hold for all
/// linalg structured ops.
struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Modify in place: operands are swapped without rebuilding the op.
    rewriter.startOpModification(op);
    bool modifiedOutput = false;
    Location loc = op.getLoc();
    for (OpOperand &opOperand : op.getDpsInitsMutable()) {
      // Only inits whose value the payload never reads can be replaced.
      if (!op.payloadUsesValueFromOperand(opOperand: &opOperand)) {
        Value operandVal = opOperand.get();
        auto operandType = dyn_cast<RankedTensorType>(Val: operandVal.getType());
        if (!operandType)
          continue;

        // If outs is sparse, leave it to the sparsifier.
        if (sparse_tensor::getSparseTensorEncoding(type: operandVal.getType()))
          continue;

        // If outs is already an `empty` operation, nothing to do.
        auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
        if (definingOp)
          continue;
        modifiedOutput = true;
        // Replace the init with a fresh tensor.empty of the same shape.
        SmallVector<OpFoldResult> mixedSizes =
            tensor::getMixedSizes(builder&: rewriter, loc, value: operandVal);
        Value emptyTensor = rewriter.create<tensor::EmptyOp>(
            location: loc, args&: mixedSizes, args: operandType.getElementType());
        op->setOperand(idx: opOperand.getOperandNumber(), value: emptyTensor);
      }
    }
    // Roll back the in-place modification bookkeeping if nothing changed.
    if (!modifiedOutput) {
      rewriter.cancelOpModification(op);
      return failure();
    }
    rewriter.finalizeOpModification(op);
    return success();
  }
};
2199
2200/// Fold linalg.fill into linalg.generic
2201struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
2202 using OpRewritePattern<GenericOp>::OpRewritePattern;
2203
2204 LogicalResult matchAndRewrite(GenericOp genericOp,
2205 PatternRewriter &rewriter) const override {
2206 if (!genericOp.hasPureTensorSemantics())
2207 return failure();
2208 bool fillFound = false;
2209 Block &payload = genericOp.getRegion().front();
2210 for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2211 if (!genericOp.payloadUsesValueFromOperand(opOperand))
2212 continue;
2213 FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
2214 if (!fillOp)
2215 continue;
2216 fillFound = true;
2217 Value fillVal = fillOp.value();
2218 auto resultType =
2219 cast<RankedTensorType>(Val: fillOp.result().getType()).getElementType();
2220 Value convertedVal =
2221 convertScalarToDtype(b&: rewriter, loc: fillOp.getLoc(), operand: fillVal, toType: resultType,
2222 /*isUnsignedCast =*/false);
2223 rewriter.replaceAllUsesWith(
2224 from: payload.getArgument(i: opOperand->getOperandNumber()), to: convertedVal);
2225 }
2226 return success(IsSuccess: fillFound);
2227 }
2228};
2229} // namespace
2230
2231void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
2232 RewritePatternSet &patterns,
2233 const ControlFusionFn &controlFoldingReshapes) {
2234 patterns.add<FoldReshapeWithGenericOpByExpansion>(arg: patterns.getContext(),
2235 args: controlFoldingReshapes);
2236 patterns.add<FoldPadWithProducerReshapeOpByExpansion>(arg: patterns.getContext(),
2237 args: controlFoldingReshapes);
2238 patterns.add<FoldWithProducerReshapeOpByExpansion>(arg: patterns.getContext(),
2239 args: controlFoldingReshapes);
2240}
2241
2242void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
2243 RewritePatternSet &patterns,
2244 const ControlFusionFn &controlFoldingReshapes) {
2245 patterns.add<FoldWithProducerReshapeOpByCollapsing>(arg: patterns.getContext(),
2246 args: controlFoldingReshapes);
2247 patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
2248 arg: patterns.getContext(), args: controlFoldingReshapes);
2249 patterns.add<FoldReshapeWithGenericOpByCollapsing>(arg: patterns.getContext(),
2250 args: controlFoldingReshapes);
2251}
2252
2253void mlir::linalg::populateElementwiseOpsFusionPatterns(
2254 RewritePatternSet &patterns,
2255 const ControlFusionFn &controlElementwiseOpsFusion) {
2256 auto *context = patterns.getContext();
2257 patterns.add<FuseElementwiseOps>(arg&: context, args: controlElementwiseOpsFusion);
2258 patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
2259 RemoveOutsDependency>(arg&: context);
2260 // Add the patterns that clean up dead operands and results.
2261 populateEraseUnusedOperandsAndResultsPatterns(patterns);
2262}
2263
2264void mlir::linalg::populateCollapseDimensions(
2265 RewritePatternSet &patterns,
2266 const GetCollapsableDimensionsFn &controlCollapseDimensions) {
2267 patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
2268 CollapseLinalgDimensions<linalg::CopyOp>>(
2269 arg: patterns.getContext(), args: controlCollapseDimensions);
2270}
2271
2272//===---------------------------------------------------------------------===//
2273// Passes
2274//===---------------------------------------------------------------------===//
2275
2276namespace {
2277
2278/// Pass that fuses generic ops on tensors. Used only for testing.
2279// TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
2280// patterns added here heavily depends on the cost function used. Having an
2281// opinionated pass of this form is not recommended. Deprecate this pass in
2282// favor of test passes that check the functionality of each of the patterns
2283// added here individually.
2284struct LinalgElementwiseOpFusionPass
2285 : public impl::LinalgElementwiseOpFusionPassBase<
2286 LinalgElementwiseOpFusionPass> {
2287 using impl::LinalgElementwiseOpFusionPassBase<
2288 LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
2289 void runOnOperation() override {
2290 Operation *op = getOperation();
2291 MLIRContext *context = op->getContext();
2292 RewritePatternSet patterns(context);
2293
2294 // Add folding with reshape by expansion patterns.
2295 ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
2296 Operation *producer = fusedOperand->get().getDefiningOp();
2297 return producer && producer->hasOneUse();
2298 };
2299
2300 // Add elementwise op fusion patterns.
2301 populateElementwiseOpsFusionPatterns(patterns, controlElementwiseOpsFusion: defaultControlFn);
2302 populateFoldReshapeOpsByExpansionPatterns(patterns, controlFoldingReshapes: defaultControlFn);
2303 tensor::populateBubbleUpExpandShapePatterns(patterns);
2304
2305 // General canonicalization patterns.
2306 affine::AffineApplyOp::getCanonicalizationPatterns(results&: patterns, context);
2307 GenericOp::getCanonicalizationPatterns(results&: patterns, context);
2308 tensor::ExpandShapeOp::getCanonicalizationPatterns(results&: patterns, context);
2309 tensor::CollapseShapeOp::getCanonicalizationPatterns(results&: patterns, context);
2310 context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
2311 results&: patterns);
2312
2313 // Add constant folding patterns.
2314 populateConstantFoldLinalgOperations(patterns, controlFn: defaultControlFn);
2315
2316 // Use TopDownTraversal for compile time reasons.
2317 (void)applyPatternsGreedily(op, patterns: std::move(patterns),
2318 config: GreedyRewriteConfig().setUseTopDownTraversal());
2319 }
2320};
2321
2322} // namespace
2323

// source code of mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp