//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect elementwise op fusion pass on
// tensor operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/TypeSwitch.h"
#include <optional>
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//

/// Return the indexing map to use for an operand of the `producer` in the
/// fused operation, given the indexing map `producerResultIndexMap` of the
/// producer result consumed by the fused operand, and the consumer's indexing
/// map `fusedConsumerArgIndexMap` for that operand.
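///
/// As an illustration (hypothetical maps): with
///   producerResultIndexMap   = (d0, d1) -> (d1, d0)
///   fusedConsumerArgIndexMap = (d0, d1) -> (d0, d1)
/// and a producer operand map of (d0, d1) -> (d0, d1), the inverse of the
/// result map is (d0, d1) -> (d1, d0), so the returned map is
/// (d0, d1) -> (d1, d0): the producer operand is read transposed in the
/// fused loops.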
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loops are the same as the consumer loops. For each
  // producer arg the indexing map to be computed is a map from consumer loop
  // -> producer arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute the inverse to get a map from tensor index -> producer loop.
  // The inverse is a map from producer result tensor index -> producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);

  // Compose argMap with invProducerResultIndexMap to get a map from
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop/fused loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}

// Checks if the given operand can be dropped, i.e. whether the remaining
// operands of the fused producer & consumer can still compute the loop
// bounds of the fused op.
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
    GenericOp producer, GenericOp consumer,
    ArrayRef<OpOperand *> opOperandsToIgnore) {
  SmallVector<AffineMap> indexingMaps;

  SmallVector<GenericOp> ops = {producer, consumer};
  for (auto &op : ops) {
    for (auto &opOperand : op->getOpOperands()) {
      if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
        continue;
      }
      indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
    }
  }
  if (indexingMaps.empty()) {
    // If there are no indexing maps, the operand can only be dropped
    // if neither op has loops.
    return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
  }

  // The concatenation of the remaining indexing maps must be invertible, so
  // that the loop bounds of the op can still be computed after dropping the
  // selected operand. inversePermutation returns an empty AffineMap in case
  // the concatenated indexing maps are not invertible.
  return inversePermutation(concatAffineMaps(
             indexingMaps, producer.getContext())) != AffineMap();
}
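// E.g. (illustrative): if after dropping an operand the remaining indexing
// maps were (d0, d1) -> (d0) and (d0, d1) -> (d0), their concatenation
// (d0, d1) -> (d0, d0) is not invertible, so the extent of d1 could no
// longer be recovered and the operand would have to be kept.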

/// Returns the set of indices of the producer's results that would be
/// preserved after the fusion.
/// * Note: the actual fusion implementation may not agree with the result of
/// this method. This function gives a prediction based on an optimized
/// fusion.
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
    GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
  llvm::SmallDenseSet<int> preservedProducerResults;
  llvm::SmallVector<OpOperand *> opOperandsToIgnore;

  // The fusedOperand will be removed during the fusion.
  opOperandsToIgnore.emplace_back(fusedOperand);

  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
    auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
    opOperandsToIgnore.emplace_back(outputOperand);
    if (producer.payloadUsesValueFromOperand(outputOperand) ||
        !isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
                                                  opOperandsToIgnore) ||
        llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
          return user != consumer.getOperation();
        })) {
      preservedProducerResults.insert(producerResult.index());

      // In case the operand can't be dropped.
      (void)opOperandsToIgnore.pop_back_val();
    }
  }
  return preservedProducerResults;
}

/// Conditions for elementwise fusion of generic operations.
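///
/// For example (illustrative IR, assuming #id is the identity map
/// affine_map<(d0) -> (d0)>), the use of %0 below is fusable: the producer
/// is all-parallel, has pure tensor semantics, and its result indexing map
/// is a permutation:
///
///   %0 = linalg.generic {indexing_maps = [#id, #id],
///                        iterator_types = ["parallel"]}
///       ins(%a : tensor<?xf32>) outs(%i0 : tensor<?xf32>) {...}
///   %1 = linalg.generic {indexing_maps = [#id, #id],
///                        iterator_types = ["parallel"]}
///       ins(%0 : tensor<?xf32>) outs(%i1 : tensor<?xf32>) {...}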
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
  if (!fusedOperand)
    return false;

  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());

  // Check that producer and consumer are generic ops.
  if (!producer || !consumer)
    return false;

  // The consumer can have mixed semantics, so just check that the operand
  // itself has tensor type. The producer must have full tensor semantics to
  // avoid potential aliasing between producer and consumer memrefs.
  if (!producer.hasPureTensorSemantics() ||
      !isa<RankedTensorType>(fusedOperand->get().getType()))
    return false;

  // Verify that the producer has all "parallel" iterator types.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  // TODO: allow fusing the producer of an output operand.
  if (!consumer.isDpsInput(fusedOperand))
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  AffineMap producerResultIndexMap =
      producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
  if (!producerResultIndexMap.isPermutation())
    return false;

  // Ensure that the fusion does not remove size information required to
  // get the loop bounds. For non-reduction generics, this is trivially the
  // case due to the output operand. For reductions, we need to check that
  // after the fusion, each loop dimension has at least one input that
  // defines it.
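  // E.g. (illustrative): if the consumer reduces d1 and the fused operand is
  // the only consumer operand whose indexing map mentions d1, then after the
  // fusion the extent of d1 must still be recoverable from one of the
  // producer's inputs; otherwise the fusion is rejected.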
  if (consumer.getNumReductionLoops()) {
    BitVector coveredDims(consumer.getNumLoops(), false);

    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = dyn_cast<AffineDimExpr>(result))
          coveredDims[dimExpr.getPosition()] = true;
    };

    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
      Value operand = std::get<0>(pair);
      if (operand == fusedOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    for (OpOperand *operand : producer.getDpsInputOperands()) {
      AffineMap newIndexingMap =
          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
              operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }

  return true;
}

/// Generate the region of the fused tensor operation. The region of the
/// fused op must be empty.
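/// After this runs, the fused block arguments are laid out as: the consumer
/// inputs before the fused operand, all producer inputs, the remaining
/// consumer inputs, the preserved producer outputs, and finally the consumer
/// outputs (see the numbered steps below).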
static void generateFusedElementwiseOpRegion(
    RewriterBase &rewriter, GenericOp fusedOp,
    AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
    unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  OpBuilder::InsertionGuard guard(rewriter);
  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
  IRMapping mapper;

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
    unsigned numFusedOpLoops =
        std::max(producer.getNumLoops(), consumer.getNumLoops());
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return rewriter.create<IndexOp>(producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = rewriter.create<affine::AffineApplyOp>(
          producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           fusedOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Replacing consumerIdx requires getting the cloned, yielded, value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in the producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumDpsInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 5. Remaining consumer input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumDpsInputs())
           .drop_front(fusedOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 6. All of the producer's output operands.
  for (const auto &bbArg : llvm::enumerate(
           producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
    if (!preservedProducerResults.count(bbArg.index()))
      continue;
    mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
                                                      bbArg.value().getLoc()));
  }

  // 7. All of the consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumDpsInits()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 8. Clone all producer operations except for the yield and index operations
  // to the fused operation.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  unsigned producerResultNumber =
      cast<OpResult>(fusedOperand->get()).getResultNumber();
  Value replacement =
      mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));

  // Sanity checks: if the replacement is not already in the mapper, then it
  // must be produced outside the producer block.
  if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
    if (auto bb = dyn_cast<BlockArgument>(replacement))
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op.
  for (auto &op : consumerBlock.without_terminator())
    rewriter.clone(op, mapper);

  // 11. Add the final yield, built from the remapped yielded values of both
  // ops.
  auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
  SmallVector<Value> fusedYieldValues;
  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
                           consumerYieldOp.getNumOperands());
  for (const auto &producerYieldVal :
       llvm::enumerate(producerYieldOp.getOperands())) {
    if (preservedProducerResults.count(producerYieldVal.index()))
      fusedYieldValues.push_back(
          mapper.lookupOrDefault(producerYieldVal.value()));
  }
  for (auto consumerYieldVal : consumerYieldOp.getOperands())
    fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
  rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}

FailureOr<mlir::linalg::ElementwiseOpFusionResult>
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
                                 OpOperand *fusedOperand) {
  assert(areElementwiseOpsFusable(fusedOperand) &&
         "expected elementwise operation pre-conditions to pass");
  auto producerResult = cast<OpResult>(fusedOperand->get());
  auto producer = cast<GenericOp>(producerResult.getOwner());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // Find the results of the producer that have uses outside of the consumer
  // after the fusion.
  llvm::SmallDenseSet<int> preservedProducerResults =
      mlir::linalg::getPreservedProducerResults(producer, consumer,
                                                fusedOperand);

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedInputOperands, fusedOutputOperands;
  SmallVector<Type> fusedResultTypes;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedInputOperands.reserve(producer.getNumDpsInputs() +
                             consumer.getNumDpsInputs());
  fusedOutputOperands.reserve(preservedProducerResults.size() +
                              consumer.getNumDpsInits());
  fusedResultTypes.reserve(preservedProducerResults.size() +
                           consumer.getNumDpsInits());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // In the following, numbering matches that of
  // `generateFusedElementwiseOpRegion`.
  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
  auto consumerInputs = consumer.getDpsInputOperands();
  auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
    return operand == fusedOperand;
  });
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }
  // 4. Splice in the producer's input operands/maps.
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  for (OpOperand *opOperand : producer.getDpsInputOperands()) {
    fusedInputOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }

  // 6. Collect all of the producer outputs.
  for (const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
    if (!preservedProducerResults.count(opOperand.index()))
      continue;

    fusedOutputOperands.push_back(opOperand.value().get());
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        &opOperand.value(), producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
    fusedResultTypes.push_back(opOperand.value().get().getType());
  }

  // 7. All of the consumer's output operands (skip operands: added by the
  // builder).
  for (OpOperand &opOperand : consumer.getDpsInitsMutable()) {
    fusedOutputOperands.push_back(opOperand.get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
    Type resultType = opOperand.get().getType();
    if (!isa<MemRefType>(resultType))
      fusedResultTypes.push_back(resultType);
  }

  // Generate the fused op.
  auto fusedOp = rewriter.create<GenericOp>(
      consumer.getLoc(), fusedResultTypes, fusedInputOperands,
      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.getIteratorTypes(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    // The fused op has invalid indexing maps. Typically this means something
    // is off in the input, but going ahead here would result in verification
    // errors. So clean up and abort.
    rewriter.eraseOp(fusedOp);
    return rewriter.notifyMatchFailure(
        fusedOp, "fused op failed loop bound computation check");
  }

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getMatchingIndexingMap(fusedOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(
      rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
      consumer.getNumLoops(), preservedProducerResults);
  ElementwiseOpFusionResult result;
  result.fusedOp = fusedOp;
  int resultNum = 0;
  for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
    if (preservedProducerResults.count(index))
      result.replacements[producerResult] = fusedOp->getResult(resultNum++);
  for (auto consumerResult : consumer->getResults())
    result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
  return result;
}

namespace {
/// Pattern to fuse a generic op with the producer of one of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      if (!areElementwiseOpsFusable(&opOperand))
        continue;
      if (!controlFn(&opOperand))
        continue;

      // Remember the producer of the operand.
      Operation *producer = opOperand.get().getDefiningOp();

      // Perform the fusion.
      FailureOr<ElementwiseOpFusionResult> fusionResult =
          fuseElementwiseOps(rewriter, &opOperand);
      if (failed(fusionResult))
        return rewriter.notifyMatchFailure(genericOp, "fusion failed");

      // Replace the original values with the fused op's results.
      for (auto [origVal, replacement] : fusionResult->replacements) {
        rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
          // Only replace consumer uses.
          return use.get().getDefiningOp() != producer;
        });
      }
      rewriter.eraseOp(genericOp);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFn;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------------===//

/// Conditions for folding a structured linalg operation with a reshape op by
/// expanding the iteration space dimensionality for tensor operations. These
/// are preconditions assumed by `fuseWithReshapeByExpansion`, which implements
/// the following fusion pattern.
///
/// Consider
///
///  %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
/// The reshape can be folded into the `linalgOp` if its loop dimensionality
/// is increased to match the result (operand) of the tensor.expand_shape.
/// The indexing_map of the fused tensor in the `linalgOp` and the
/// reassociation map help compute the indexing maps of the modified op.
/// For the above example, based on the reassociation map it
/// can be concluded that
///
/// - The loop used to access the first dimension of the fused tensor is split
///   into two.
/// - The loop used to access the second dimension of the fused tensor is kept
///   as is.
/// - The loop used to access the third dimension of the fused tensor is split
///   into three.
///
/// i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing maps of the
/// modified op, then
///
///  d0 -> e0, e1
///  d1 -> e2, e3, e4
///  d2 -> e5
///
/// substituting this, the structured op can be rewritten as
///
///  %d = linalg.generic
///         ins(%0, %1 : tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?xf32>)
///         indexing_maps =
///           [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///            affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///            affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
/// Since the operands to the linalg generic are now expanded, reshapes can be
/// introduced to make the types consistent
///
///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///       : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
/// The added reshapes are again expanding patterns, so they will get fused
/// with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - The linalg op has pure tensor semantics.
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
  return linalgOp.hasPureTensorSemantics() &&
         llvm::all_of(linalgOp.getIndexingMaps().getValue(),
                      [](Attribute attr) {
                        return cast<AffineMapAttr>(attr)
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         operandMap.getNumResults() > 0;
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from the original dimensions of the op to the
  // dimensions of the expanded op, given the `indexingMap` of the fused
  // operand/result of the generic op, the `reassociationMaps` of the reshape
  // op and the shape of the expanded op.
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<OpFoldResult> expandedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimensions of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from the extent of each loop in the original operation to the
  /// extents of the corresponding loops in the expanded operation.
  SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
  /// Extents of the loops in the original operation.
  SmallVector<OpFoldResult> originalLoopExtent;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<OpFoldResult> expandedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);
  originalLoopExtent = llvm::map_to_vector(
      linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
      [](Range r) { return r.size; });

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<OpFoldResult> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute the reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}

/// Return the indexing map to use in the expanded op for a given `indexingMap`
/// of the original operation.
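/// E.g. (illustrative): if d0 expands to (e0, e1) and d1 expands to
/// (e2, e3, e4), the map (d0, d1) -> (d1, d0) becomes
/// (e0, e1, e2, e3, e4) -> (e2, e3, e4, e0, e1).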
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = cast<AffineDimExpr>(expr).getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}

/// Return the shape and type of the operand/result to use in the expanded op
/// given the type in the original op.
static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
                        const ExpansionInfo &expansionInfo) {
  SmallVector<OpFoldResult> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    ArrayRef<OpFoldResult> dimExpansion =
        expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  SmallVector<int64_t> expandedStaticShape;
  std::tie(expandedStaticShape, std::ignore) =
      decomposeMixedValues(expandedShape);
  return {expandedShape, RankedTensorType::get(expandedStaticShape,
                                               originalType.getElementType())};
}

/// Returns the reassociation maps to use in the `tensor.expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `tensor.collapse_shape` used to collapse the result of the expanded
/// op to get the value that can replace all uses of the results of the
/// original op.
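/// E.g. (illustrative): for indexingMap = (d0, d1) -> (d1, d0), with d1
/// expanded into three dims and d0 into two, the returned reassociation is
/// [[0, 1, 2], [3, 4]].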
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}

/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. The
/// extents needed for the linearization are kept as `OpFoldResult`s, so both
/// static and dynamic shapes are handled.
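/// E.g. (illustrative): if dimension d expands into (e0, e1, e2) with extents
/// (s0, s1, s2), the original index is recovered as
///   idx(d) = (idx(e0) * s1 + idx(e1)) * s2 + idx(e2)
/// which is what the folded affine.apply chain below computes.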
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.getDim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.getDim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<OpFoldResult> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
    OpFoldResult newIndex =
        rewriter.create<IndexOp>(loc, expandedDims.front()).getResult();
    for (auto [expandedShape, expandedIndex] :
         llvm::zip(expandedDimsShape, expandedIndices)) {
      AffineExpr idx, acc, shape;
      bindDims(rewriter.getContext(), idx, acc);
      bindSymbols(rewriter.getContext(), shape);
      newIndex = affine::makeComposedFoldedAffineApply(
          rewriter, indexOp.getLoc(), idx + acc * shape,
          ArrayRef<OpFoldResult>{expandedIndex, newIndex, expandedShape});
    }
    Value newIndexVal =
        getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
    rewriter.replaceOp(indexOp, newIndexVal);
  }
}

// Create an expanded transpose op.
// The reassociation map is already permuted, hence we invert-permute it and
// then flatten it. Then we invert the flattened permutation again to get the
// final expanded transpose permutation. For example,
//
//  permutation = [2, 0, 1]
//  reassociation_map for expansion = [[0, 1], [2], [3, 4, 5]]
//
//  inverse permutation = [1, 2, 0]
//  applied to the reassociation_map and then flattened becomes
//  flattened permutation = [2, 3, 4, 5, 0, 1]
//  the final permutation is the inverse of the flattened permutation.
//
// Becomes
//
//  permutation = [4, 5, 0, 1, 2, 3]

static Operation *createExpandedTransposeOp(PatternRewriter &rewriter,
                                            TransposeOp transposeOp,
                                            Value expandedInput, Value output,
                                            ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> newPerm;
  for (int64_t perm : invertPermutationVector(transposeOp.getPermutation())) {
    auto reassoc = expansionInfo.getExpandedDims(perm);
    for (int64_t dim : reassoc) {
      newPerm.push_back(dim);
    }
  }
  return rewriter.create<TransposeOp>(transposeOp.getLoc(), expandedInput,
                                      output, invertPermutationVector(newPerm));
}

// Create an expanded generic op.
static Operation *createExpandedGenericOp(
    PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
    ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs,
    ExpansionInfo &expansionInfo, ArrayRef<AffineMap> expandedOpIndexingMaps) {
  // Initialize the iterator types of the expanded op to parallel, then copy
  // the original iterator type over each expanded dimension.
  SmallVector<utils::IteratorType> iteratorTypes(
      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);

  for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
    for (auto j : expansionInfo.getExpandedDims(i))
      iteratorTypes[j] = type;

  Operation *fused = rewriter.create<GenericOp>(
      linalgOp.getLoc(), resultTypes, expandedOpOperands, outputs,
      expandedOpIndexingMaps, iteratorTypes);

  Region &fusedRegion = fused->getRegion(0);
  Region &originalRegion = linalgOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, linalgOp.getLoc(), fusedRegion,
                                expansionInfo);

  return fused;
}

// Create an expanded fused op that retains the name for certain ops
// such as fill, copy and transpose, and produces a generic op for the
// rest of the linalg ops.
static Operation *createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp,
                                   TypeRange resultTypes,
                                   ArrayRef<Value> expandedOpOperands,
                                   ArrayRef<Value> outputs,
                                   ArrayRef<AffineMap> expandedOpIndexingMaps,
                                   ExpansionInfo &expansionInfo) {

  return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
      .Case<TransposeOp>([&](TransposeOp transposeOp) {
        return createExpandedTransposeOp(rewriter, transposeOp,
                                         expandedOpOperands[0], outputs[0],
                                         expansionInfo);
      })
      .Case<FillOp, CopyOp>([&](Operation *op) {
        return clone(rewriter, linalgOp, resultTypes,
                     llvm::to_vector(llvm::concat<Value>(
                         llvm::to_vector(expandedOpOperands),
                         llvm::to_vector(outputs))));
      })
      .Default([&](Operation *op) {
        return createExpandedGenericOp(rewriter, linalgOp, resultTypes,
                                       expandedOpOperands, outputs,
                                       expansionInfo, expandedOpIndexingMaps);
      });
}

/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape
/// op and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
static std::optional<SmallVector<Value>>
fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");

  Location loc = linalgOp.getLoc();
  SmallVector<OpFoldResult> expandedShape;
  SmallVector<AffineMap, 4> reassociationMaps;
  Value src;
  if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
    // Try to move the dynamic dimensions in the output shape before the
    // `linalgOp` to maintain SSA validity.
    if (failed(moveValueDefinitions(
            rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
      return std::nullopt;

    expandedShape = expandingReshapeOp.getMixedOutputShape();
    reassociationMaps = expandingReshapeOp.getReassociationMaps();
    src = expandingReshapeOp.getSrc();
  } else {
    auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
    if (!collapsingReshapeOp)
      return std::nullopt;

    expandedShape = tensor::getMixedSizes(
        rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
    reassociationMaps = collapsingReshapeOp.getReassociationMaps();
    src = collapsingReshapeOp.getSrc();
  }

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
                                   reassociationMaps, expandedShape,
                                   rewriter)))
    return std::nullopt;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  // Set the insertion point to the generic op.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
  for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(src);
      continue;
    }
    if (auto opOperandType =
            dyn_cast<RankedTensorType>(opOperand->get().getType())) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
      SmallVector<OpFoldResult> expandedOperandShape;
      RankedTensorType expandedOperandType;
      std::tie(expandedOperandShape, expandedOperandType) =
          getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        if (failed(reshapeLikeShapesAreCompatible(
                [&](const Twine &msg) {
                  return rewriter.notifyMatchFailure(linalgOp, msg);
                },
                opOperandType.getShape(), expandedOperandType.getShape(),
                reassociation,
                /*isExpandingReshape=*/true)))
          return std::nullopt;
        expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
            loc, expandedOperandType, opOperand->get(), reassociation,
            expandedOperandShape));
        continue;
      }
    }
    expandedOpOperands.push_back(opOperand->get());
  }

  SmallVector<Value> outputs;
  for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
    auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
    SmallVector<OpFoldResult> expandedOutputShape;
    RankedTensorType expandedOutputType;
    std::tie(expandedOutputShape, expandedOutputType) =
        getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
    if (expandedOutputType != opOperand.get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      if (failed(reshapeLikeShapesAreCompatible(
              [&](const Twine &msg) {
                return rewriter.notifyMatchFailure(linalgOp, msg);
              },
              opOperandType.getShape(), expandedOutputType.getShape(),
              reassociation,
              /*isExpandingReshape=*/true)))
        return std::nullopt;
      outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
          loc, expandedOutputType, opOperand.get(), reassociation,
          expandedOutputShape));
    } else {
      outputs.push_back(opOperand.get());
    }
  }

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  Operation *fusedOp =
      createExpandedOp(rewriter, linalgOp, resultTypes, expandedOpOperands,
                       outputs, expandedOpIndexingMaps, expansionInfo);
  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : linalgOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              linalgOp.getMatchingIndexingMap(
                  linalgOp.getDpsInitOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
          linalgOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  return resultVals;
}

namespace {

/// Pattern to fuse a tensor.collapse_shape op with its consumer structured op,
/// when the reshape op is collapsing dimensions. The dimensionality of the
/// loop in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
    : public OpInterfaceRewritePattern<LinalgOp> {
public:
  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
      tensor::CollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
      if (!reshapeOp)
        continue;
      // Fold only if
      // - all constraints of fusing with reshape by expansion are met,
      // - the control function allows folding this reshape.
      if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
          (!controlFoldingReshapes(opOperand)))
        continue;

      std::optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(linalgOp, *replacementValues);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.pad op with its producer tensor.collapse_shape op
/// by expanding the pad to the un-collapsed shape. Non-zero padding is only
/// supported on dimensions that are not folded by the collapse.
class FoldPadWithProducerReshapeOpByExpansion
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
                                          ControlFusionFn foldReshapes,
                                          PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::CollapseShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!reshapeOp)
      return failure();
    if (!reshapeOp->hasOneUse())
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(padOp,
                                         "fusion blocked by control function");
    }

    ArrayRef<int64_t> low = padOp.getStaticLow();
    ArrayRef<int64_t> high = padOp.getStaticHigh();
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();

    // Non-zero padding on a dimension that results from folding multiple
    // dimensions is not supported.
    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
      if (reInd.size() != 1 && (l != 0 || h != 0))
        return failure();
    }

    SmallVector<OpFoldResult> newLow, newHigh;
    RankedTensorType expandedType = reshapeOp.getSrcType();
    RankedTensorType paddedType = padOp.getResultType();
    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      if (reInd.size() == 1) {
        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
      }
      for (size_t i = 0; i < reInd.size(); ++i) {
        newLow.push_back(padOp.getMixedLowPad()[idx]);
        newHigh.push_back(padOp.getMixedHighPad()[idx]);
      }
    }

    Location loc = padOp->getLoc();
    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);

    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {

  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
                                      ControlFusionFn foldReshapes,
                                      PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by expansion are
    // met.
    auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
    if (!producerResult) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "source not produced by an operation");
    }

    auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
    if (!producer) {
      return rewriter.notifyMatchFailure(reshapeOp, "producer not a linalg op");
    }

    if (!isFusableWithReshapeByDimExpansion(
            producer,
            producer.getDpsInitOperand(producerResult.getResultNumber()))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed preconditions of fusion with producer generic op");
    }

    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion blocked by control function");
    }

    std::optional<SmallVector<Value>> replacementValues =
        fuseWithReshapeByExpansion(
            producer, reshapeOp,
            producer.getDpsInitOperand(producerResult.getResultNumber()),
            rewriter);
    if (!replacementValues) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion by expansion failed");
    }

    // Find the replacement for the reshape op. Since the replacements have the
    // same types as the results of the original generic op, the consumer
    // reshape op can be replaced by the source of the collapse_shape op that
    // defines the replacement.
    Value reshapeReplacement =
        (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
                                 .getResultNumber()];
    if (auto collapseOp =
            reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
      reshapeReplacement = collapseOp.getSrc();
    }
    rewriter.replaceOp(reshapeOp, reshapeReplacement);
    rewriter.replaceOp(producer, *replacementValues);
    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns to fuse reshape ops with linalg.generic operations by
// contraction of dimensions.
//===---------------------------------------------------------------------===//

/// For a given list of indices in the range of the `indexingMap` that are
/// folded, return the indices of the corresponding domain. The projected
/// permutation semantics ensure that all the elements of the returned
/// reassociation are distinct.
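/// E.g. (illustrative): for indexingMap = (d0, d1, d2) -> (d2, d0, d1) and
/// rangeReassociation = [0, 1], the returned domain reassociation is [2, 0].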
static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
                       ReassociationIndicesRef rangeReassociation) {
  assert(indexingMap.isProjectedPermutation() &&
         "expected projected permutation");

  ReassociationIndices domainReassociation = llvm::to_vector<4>(
      llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
        return cast<AffineDimExpr>(indexingMap.getResults()[pos]).getPosition();
      }));
  // The projected permutation semantics ensures that there is no repetition of
  // the domain indices.
  return domainReassociation;
}

/// For a given `dimSequence`, check if the sequence is preserved in the
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
/// If the sequence does not occur in the map at all, true is returned as well.
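/// E.g. (illustrative): for indexingMap = (d0, d1, d2) -> (d0, d2, d1), the
/// sequence [1, 2] is not preserved (d2 appears before d1), while the
/// sequence [2] is preserved, as is [0, 1] for (d0, d1, d2) -> (d0, d1).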
bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
                                          ReassociationIndicesRef dimSequence) {
  assert(!dimSequence.empty() &&
         "expected non-empty list for dimension sequence");
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");

  llvm::SmallDenseSet<unsigned, 4> sequenceElements;
  sequenceElements.insert_range(dimSequence);

  unsigned dimSequenceStart = dimSequence[0];
  for (const auto &expr : enumerate(indexingMap.getResults())) {
    unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
    // 1. Check if this is the start of the sequence.
    if (dimInMapStart == dimSequenceStart) {
      if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
        return false;
      // 1a. Check if the sequence is preserved.
      for (const auto &dimInSequence : enumerate(dimSequence)) {
        unsigned dimInMap =
            cast<AffineDimExpr>(
                indexingMap.getResult(expr.index() + dimInSequence.index()))
                .getPosition();
        if (dimInMap != dimInSequence.value())
          return false;
      }
      // Found the sequence. Projected permutation
      // enforces that all AffineDimExprs in the result are unique, so no
      // further checks are needed.
      return true;
    }
    // 2. If the position in the expr (which is of type AffineDimExpr) is part
    // of the sequence, return false here. This implies the entire sequence
    // does not exist in the indexing map.
    if (sequenceElements.count(dimInMapStart))
      return false;
  }
  // 3. No element of the sequence found. Return true.
  return true;
}

bool mlir::linalg::areDimSequencesPreserved(
    ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences) {
  return llvm::all_of(maps, [&](AffineMap map) {
    return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
      return isDimSequencePreserved(map, dimSequence);
    });
  });
}

// Return the list of dimensions of the iteration domain that can be
// collapsed to allow for fusion with a producer that is an expand_shape
// operation. If all dimensions created by expansion can be collapsed in the
// iteration space then the reshape is defunct.
//
// Example:
//
// ```mlir
// #map = affine_map<(d0, d1) -> (d0, d1)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<?x4xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
// ```
//
// can be fused by collapsing the dimensions of the iteration space.
//
// ```mlir
// #map = affine_map<(d0) -> (d0)>
// %2 = tensor.empty [..] : tensor<?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel"]}
//     ins(%0 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
// %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// ```
//
// In the following example,
//
// ```mlir
// #map0 = affine_map<(d0, d1) -> (d0, d1)>
// #map1 = affine_map<(d0, d1) -> (d1, d0)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<4x?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map0, #map1],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
// ```
//
// the reshape cannot be fused with the generic op by collapsing the op
// dimensions, since the indexing maps would have to contain mods and divs
// to preserve the access pattern. When no dimensions of the iteration
// space are collapsible, an empty vector is returned.
1296static SmallVector<ReassociationIndices>
1297getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
1298 ArrayRef<ReassociationIndices> reassociation) {
1299 // Some basic checks for this fusion to be valid.
1300 if (!genericOp.hasPureTensorSemantics())
1301 return {};
1302
1303 if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
1304 return map.isProjectedPermutation();
1305 })) {
1306 return {};
1307 }
1308
1309 // Compute all the loops with the reduction iterator types.
1310 SmallVector<unsigned> reductionDims;
1311 genericOp.getReductionDims(reductionDims);
1312
1313 llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
1314 AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
1315 auto iteratorTypes = genericOp.getIteratorTypesArray();
1316 SmallVector<ReassociationIndices> iterationSpaceReassociation;
1317 for (ReassociationIndicesRef foldedRangeDims : reassociation) {
1318 assert(!foldedRangeDims.empty() && "unexpected empty reassociation");
1319
1320 // Ignore dims that are not folded.
1321 if (foldedRangeDims.size() == 1)
1322 continue;
1323
1324 ReassociationIndices foldedIterationSpaceDims =
1325 getDomainReassociation(indexingMap, rangeReassociation: foldedRangeDims);
1326
1327 // Check that the folded iteration dims do not contain already processed
1328 // dims.
1329 if (llvm::any_of(Range&: foldedIterationSpaceDims, P: [&](int64_t dim) {
1330 return processedIterationDims.count(V: dim);
1331 }))
1332 continue;
1333
1334 // Check that all folded iterator types are all parallel or all reductions.
1335 utils::IteratorType startIteratorType =
1336 iteratorTypes[foldedIterationSpaceDims[0]];
1337 if (!isParallelIterator(startIteratorType) &&
1338 !isReductionIterator(startIteratorType))
1339 continue;
1340 if (llvm::any_of(Range&: foldedIterationSpaceDims, P: [&](int64_t dim) {
1341 return iteratorTypes[dim] != startIteratorType;
1342 }))
1343 continue;
1344
1345 // If the folded dimensions correspond to a "reduction" iterator type,
1346 // the folded dimensions need to be "in-order". Strictly speaking this is
1347 // not necessary, for reductions that are associative and commutative, but
1348 // using a more strict definition of reduction for now.
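    // For example, with reductionDims = [1, 2, 3], folding dims [2, 3] keeps
    // the reduction order and is accepted, while folding dims [1, 3] is not
    // contiguous and is rejected by the check below.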
    if (isReductionIterator(startIteratorType)) {
      bool isContiguous = false;
      for (const auto &startDim : llvm::enumerate(reductionDims)) {
        // Move window in `reductionDims` to start of the folded iteration
        // dims.
        if (startDim.value() != foldedIterationSpaceDims[0])
          continue;
        // If the sizes don't match, it is trivially not contiguous. This
        // condition should not be hit.
        if (startDim.index() + foldedIterationSpaceDims.size() >
            reductionDims.size())
          break;
        // Check that the contiguity is maintained.
        isContiguous = true;
        for (const auto &foldedDim :
             llvm::enumerate(foldedIterationSpaceDims)) {
          if (reductionDims[foldedDim.index() + startDim.index()] !=
              foldedDim.value()) {
            isContiguous = false;
            break;
          }
        }
        break;
      }
      if (!isContiguous)
        continue;
    }

    // Check that the sequence is preserved in all indexing maps.
    if (llvm::any_of(genericOp.getIndexingMapsArray(),
                     [&](AffineMap indexingMap) {
                       return !isDimSequencePreserved(indexingMap,
                                                      foldedIterationSpaceDims);
                     }))
      continue;

    processedIterationDims.insert_range(foldedIterationSpaceDims);
    iterationSpaceReassociation.emplace_back(
        std::move(foldedIterationSpaceDims));
  }

  return iterationSpaceReassociation;
}

/// Helper class to carry state while collapsing the `linalg.generic` op.
namespace {
class CollapsingInfo {
public:
  LogicalResult initialize(unsigned origNumLoops,
                           ArrayRef<ReassociationIndices> foldedIterationDims) {
    llvm::SmallDenseSet<int64_t, 4> processedDims;
    // Find all the dims that are folded.
    for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
      if (foldedIterationDim.empty())
        continue;
      // If the folded dims contain dims already folded, that's an illegal
      // specification. Repetition within a list is also illegal.
      for (auto dim : foldedIterationDim) {
        if (dim >= origNumLoops)
          return failure();
        if (processedDims.count(dim))
          return failure();
        processedDims.insert(dim);
      }
      collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
                                                   foldedIterationDim.end());
    }
    if (processedDims.size() > origNumLoops)
      return failure();

    // Add all the preserved dims of the original op as single
    // elements to `collapsedOpToOrigOpIterationDim`.
    for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
      if (processedDims.count(dim))
        continue;
      collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
    }

    llvm::sort(collapsedOpToOrigOpIterationDim,
               [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
                 return lhs[0] < rhs[0];
               });
    origOpToCollapsedOpIterationDim.resize(origNumLoops);
    for (const auto &foldedDims :
         llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
      for (const auto &dim : enumerate(foldedDims.value()))
        origOpToCollapsedOpIterationDim[dim.value()] =
            std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
    }
    return success();
  }

  /// Return mapping from collapsed loop domain to original loop domain.
  ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
    return collapsedOpToOrigOpIterationDim;
  }

  /// Return mapping from original loop domain to collapsed loop domain. The
  /// mapping is a pair. The first value is the dimension in the collapsed loop
  /// that the original loop is mapped to. The second is the relative position
  /// in the folded list of this domain. For example, if the original loop
  /// domain is 5D and the collapsed loop domain folds all of it, i.e.
  ///
  /// ```
  /// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
  /// ```
  ///
  /// then
  ///
  /// ```
  /// origOpToCollapsedOpMapping[0] = {0, 0};
  /// origOpToCollapsedOpMapping[1] = {0, 1};
  /// origOpToCollapsedOpMapping[2] = {0, 2};
  /// origOpToCollapsedOpMapping[3] = {1, 0};
  /// origOpToCollapsedOpMapping[4] = {1, 1};
  /// ```
  ///
  ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
    return origOpToCollapsedOpIterationDim;
  }

  /// Return the collapsed op iteration domain rank.
  unsigned getCollapsedOpIterationRank() const {
    return collapsedOpToOrigOpIterationDim.size();
  }

private:
  /// Map from the iteration domain index in the collapsed op to the iteration
  /// domain indices in the original op.
  SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;

  /// Map from the iteration domain index in the original op to the iteration
  /// domain index in the collapsed op.
  SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
};
} // namespace

/// Get the iterator types for the collapsed operation given the original
/// iterator types and collapsed dimensions.
static SmallVector<utils::IteratorType>
getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
                            const CollapsingInfo &collapsingInfo) {
  SmallVector<utils::IteratorType> collapsedIteratorTypes;
  for (ReassociationIndicesRef foldedIterDims :
       collapsingInfo.getCollapsedOpToOrigOpMapping()) {
    assert(!foldedIterDims.empty() &&
           "reassociation indices expected to have non-empty sets");
    // Just pick the iterator type of the first folded dim. Pre-condition
    // checks are expected to have verified that the iterator types of all
    // folded dimensions are the same.
    collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
  }
  return collapsedIteratorTypes;
}

/// Compute the indexing map in the collapsed op that corresponds to the given
/// `indexingMap` of the original operation.
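/// For example, with `collapsedOpToOrigOpMapping = [[0, 1], [2]]`, the map
/// `affine_map<(d0, d1, d2) -> (d0, d1, d2)>` of the original op becomes
/// `affine_map<(d0, d1) -> (d0, d1)>` in the collapsed op.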
static AffineMap
getCollapsedOpIndexingMap(AffineMap indexingMap,
                          const CollapsingInfo &collapsingInfo) {
  MLIRContext *context = indexingMap.getContext();
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");
  SmallVector<AffineExpr> resultExprs;
  auto origOpToCollapsedOpMapping =
      collapsingInfo.getOrigOpToCollapsedOpMapping();
  for (auto expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    // If the dim is not the first of its collapsed group, do nothing.
    if (origOpToCollapsedOpMapping[dim].second != 0)
      continue;
    // The next n dims are guaranteed to be collapsed. So just use the
    // iteration dimension of the collapsed op.
    resultExprs.push_back(
        getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
  }
  return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
                        resultExprs, context);
}

/// Return the `reassociation` indices to use to collapse the operand when the
/// iteration space of a generic op is collapsed.
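/// For example, with `collapsedOpToOrigOpMapping = [[0, 1], [2]]` and an
/// operand indexing map `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the
/// operand reassociation is `[[0], [1, 2]]`: result 0 (`d2`) collapses on its
/// own, while results 1 and 2 (`d0`, `d1`) collapse together.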
static SmallVector<ReassociationIndices>
getOperandReassociation(AffineMap indexingMap,
                        const CollapsingInfo &collapsingInfo) {
  unsigned counter = 0;
  SmallVector<ReassociationIndices> operandReassociation;
  auto origOpToCollapsedOpMapping =
      collapsingInfo.getOrigOpToCollapsedOpMapping();
  auto collapsedOpToOrigOpMapping =
      collapsingInfo.getCollapsedOpToOrigOpMapping();
  while (counter < indexingMap.getNumResults()) {
    unsigned dim =
        cast<AffineDimExpr>(indexingMap.getResult(counter)).getPosition();
    // This is the start of a collapsed group of iteration dimensions that is
    // guaranteed to be preserved in the indexing map. The number of folded
    // dims is obtained from the collapsed op to original op mapping.
    unsigned numFoldedDims =
        collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
            .size();
    if (origOpToCollapsedOpMapping[dim].second == 0) {
      auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
      operandReassociation.emplace_back(range.begin(), range.end());
    }
    counter += numFoldedDims;
  }
  return operandReassociation;
}

/// Get the new value to use for a given `OpOperand` in the collapsed
/// operation.
static Value getCollapsedOpOperand(Location loc, LinalgOp op,
                                   OpOperand *opOperand,
                                   const CollapsingInfo &collapsingInfo,
                                   OpBuilder &builder) {
  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
  SmallVector<ReassociationIndices> operandReassociation =
      getOperandReassociation(indexingMap, collapsingInfo);

  // If the number of entries in the reassociation for the operand is the same
  // as the number of results of the indexing map, then nothing to do for this
  // operand.
  Value operand = opOperand->get();
  if (operandReassociation.size() == indexingMap.getNumResults())
    return operand;

  // Insert a reshape to collapse the dimensions.
  if (isa<MemRefType>(operand.getType())) {
    return builder
        .create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
        .getResult();
  }
  return builder
      .create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
      .getResult();
}

/// Rewrite the `linalg.index` operations in the original generic op to compute
/// their values in terms of the collapsed iteration space.
static void generateCollapsedIndexingRegion(
    Location loc, Block *block, const CollapsingInfo &collapsingInfo,
    ArrayRef<OpFoldResult> loopRange, RewriterBase &rewriter) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPointToStart(block);

  // Collect all the original index ops.
  auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());

  // For each folded dimension list resolve the original induction variable
  // values in terms of the folded dimension induction variable.
  //   i_{folded} = (i_0 * d1 + i_1) * d2 + i_2
  // can be inverted to
  //   i_2 = i_{folded} % d2
  //   i_1 = (i_{folded} / d2) % d1
  //   i_0 = i_{folded} / (d1 * d2)
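  // For example, with (d1, d2) = (4, 5) and (i_0, i_1, i_2) = (2, 3, 4):
  //   i_{folded} = (2 * 4 + 3) * 5 + 4 = 59
  //   i_2 = 59 % 5 = 4, i_1 = (59 / 5) % 4 = 3, i_0 = 59 / 20 = 2.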
  llvm::DenseMap<unsigned, Value> indexReplacementVals;
  for (auto foldedDims :
       enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
    ReassociationIndicesRef foldedDimsRef(foldedDims.value());
    Value newIndexVal =
        rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
    for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
      Value loopDim =
          getValueOrCreateConstantIndexOp(rewriter, loc, loopRange[dim]);
      indexReplacementVals[dim] =
          rewriter.createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
      newIndexVal =
          rewriter.createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
    }
    indexReplacementVals[foldedDims.value().front()] = newIndexVal;
  }

  for (auto indexOp : indexOps) {
    auto dim = indexOp.getDim();
    rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
  }
}

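/// Collapse the operands and results of `op` according to `collapsingInfo`,
/// filling `inputOperands`, `outputOperands`, and `resultTypes` for use when
/// building the collapsed op. For ops with pure buffer semantics the init
/// operands are memrefs and no result types are produced.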
void collapseOperandsAndResults(LinalgOp op,
                                const CollapsingInfo &collapsingInfo,
                                RewriterBase &rewriter,
                                SmallVectorImpl<Value> &inputOperands,
                                SmallVectorImpl<Value> &outputOperands,
                                SmallVectorImpl<Type> &resultTypes) {
  Location loc = op->getLoc();
  inputOperands =
      llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
        return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
                                     rewriter);
      });

  // Get the output operands and result types.
  resultTypes.reserve(op.getNumDpsInits());
  outputOperands.reserve(op.getNumDpsInits());
  for (OpOperand &output : op.getDpsInitsMutable()) {
    Value newOutput =
        getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
    outputOperands.push_back(newOutput);
    // If the op has "buffer semantics", then the init operands are ranked
    // memrefs and the op has no results.
    if (!op.hasPureBufferSemantics())
      resultTypes.push_back(newOutput.getType());
  }
}

/// Clone a `LinalgOp` to a collapsed version of the same name.
template <typename OpTy>
OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
                        const CollapsingInfo &collapsingInfo) {
  return nullptr;
}

/// Collapse any `LinalgOp` that does not require any specialization such as
/// indexing_maps, iterator_types, etc.
template <>
LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
                                      const CollapsingInfo &collapsingInfo) {
  SmallVector<Value> inputOperands, outputOperands;
  SmallVector<Type> resultTypes;
  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
                             outputOperands, resultTypes);

  return clone(
      rewriter, origOp, resultTypes,
      llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
}

/// Collapse a `GenericOp`.
template <>
GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
                                        GenericOp origOp,
                                        const CollapsingInfo &collapsingInfo) {
  SmallVector<Value> inputOperands, outputOperands;
  SmallVector<Type> resultTypes;
  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
                             outputOperands, resultTypes);
  SmallVector<AffineMap> indexingMaps(
      llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
        return getCollapsedOpIndexingMap(map, collapsingInfo);
      }));

  SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
      origOp.getIteratorTypesArray(), collapsingInfo));

  GenericOp collapsedOp = rewriter.create<linalg::GenericOp>(
      origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps,
      iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
  Block *origOpBlock = &origOp->getRegion(0).front();
  Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
                       collapsedOpBlock->getArguments());
  return collapsedOp;
}

LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
                           RewriterBase &rewriter) {
  if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
    return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
  } else {
    return cloneToCollapsedOp(rewriter, op, collapsingInfo);
  }
}

/// Implementation of fusion with a reshape operation by collapsing dimensions.
FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
    LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
    RewriterBase &rewriter) {
  // Bail on trivial no-op cases.
  if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
      llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
        return foldedDims.size() <= 1;
      }))
    return failure();

  bool hasPureBufferSemantics = op.hasPureBufferSemantics();
  if (hasPureBufferSemantics &&
      !llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
        MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
        if (!memRefToCollapse)
          return true;

        return memref::CollapseShapeOp::isGuaranteedCollapsible(
            memRefToCollapse, foldedIterationDims);
      }))
    return rewriter.notifyMatchFailure(op,
                                       "memref is not guaranteed collapsible");

  CollapsingInfo collapsingInfo;
  if (failed(
          collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
    return rewriter.notifyMatchFailure(
        op, "illegal to collapse specified dimensions");
  }

  // Bail on non-canonical ranges.
  SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
  auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
      return cast<IntegerAttr>(attr).getInt() == value;
    llvm::APInt actual;
    return matchPattern(cast<Value>(ofr), m_ConstantInt(&actual)) &&
           actual.getSExtValue() == value;
  };
  if (!llvm::all_of(loopRanges, [&](Range range) {
        return opFoldIsConstantValue(range.offset, 0) &&
               opFoldIsConstantValue(range.stride, 1);
      })) {
    return rewriter.notifyMatchFailure(
        op, "expected all loop ranges to have zero start and unit stride");
  }

  LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);

  Location loc = op->getLoc();
  SmallVector<OpFoldResult> loopBound =
      llvm::map_to_vector(loopRanges, [](Range range) { return range.size; });

  if (collapsedOp.hasIndexSemantics()) {
    // Collect the loop range of the generic op.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(collapsedOp);
    generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
                                    collapsingInfo, loopBound, rewriter);
  }

  // Insert expanding reshape for the result to get back the original result
  // type.
  SmallVector<Value> results;
  for (const auto &originalResult : llvm::enumerate(op->getResults())) {
    Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
    auto originalResultType =
        cast<ShapedType>(originalResult.value().getType());
    auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
    if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
      AffineMap indexingMap =
          op.getIndexingMapMatchingResult(originalResult.value());
      SmallVector<ReassociationIndices> reassociation =
          getOperandReassociation(indexingMap, collapsingInfo);
      assert(
          indexingMap.isProjectedPermutation() &&
          "Expected indexing map to be a projected permutation for collapsing");
      SmallVector<OpFoldResult> resultShape =
          applyPermutationMap(indexingMap, ArrayRef(loopBound));
      Value result;
      if (isa<MemRefType>(collapsedOpResult.getType())) {
        MemRefType expandShapeResultType = MemRefType::get(
            originalResultType.getShape(), originalResultType.getElementType());
        result = rewriter.create<memref::ExpandShapeOp>(
            loc, expandShapeResultType, collapsedOpResult, reassociation,
            resultShape);
      } else {
        result = rewriter.create<tensor::ExpandShapeOp>(
            loc, originalResultType, collapsedOpResult, reassociation,
            resultShape);
      }
      results.push_back(result);
    } else {
      results.push_back(collapsedOpResult);
    }
  }
  return CollapseResult{results, collapsedOp};
}

namespace {

/// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
/// contracting dimensions of the loop.
class FoldWithProducerReshapeOpByCollapsing
    : public OpRewritePattern<GenericOp> {
public:
  // TODO: support fusion with all linalg ops, not just generic.
  FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
                                        ControlFusionFn foldReshapes,
                                        PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      tensor::ExpandShapeOp reshapeOp =
          opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
      if (!reshapeOp)
        continue;

      SmallVector<ReassociationIndices> collapsableIterationDims =
          getCollapsableIterationSpaceDims(genericOp, &opOperand,
                                           reshapeOp.getReassociationIndices());
      if (collapsableIterationDims.empty() ||
          !controlFoldingReshapes(&opOperand)) {
        continue;
      }

      std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
          genericOp, collapsableIterationDims, rewriter);
      if (!collapseResult) {
        return rewriter.notifyMatchFailure(
            genericOp, "failed to do the fusion by collapsing transformation");
      }

      rewriter.replaceOp(genericOp, collapseResult->results);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.collapse_shape op with its producer generic op
/// by collapsing the dimensions of the iteration space in the producer op.
struct FoldReshapeWithGenericOpByCollapsing
    : public OpRewritePattern<tensor::CollapseShapeOp> {

  FoldReshapeWithGenericOpByCollapsing(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by collapsing are
    // met.
    auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
    if (!producerResult) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "source not produced by an operation");
    }

    // TODO: support fusion with all linalg producers, not just generic.
    auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
    if (!producer) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "producer not a generic op");
    }

    SmallVector<ReassociationIndices> collapsableIterationDims =
        getCollapsableIterationSpaceDims(
            producer,
            producer.getDpsInitOperand(producerResult.getResultNumber()),
            reshapeOp.getReassociationIndices());
    if (collapsableIterationDims.empty()) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed preconditions of fusion with producer generic op");
    }

    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion blocked by control function");
    }

    // Set the insertion point after `producer` because there could be uses
    // of `producer` between it and the `tensor.collapse_shape` op.
    rewriter.setInsertionPointAfter(producer);
    std::optional<CollapseResult> collapseResult =
        collapseOpIterationDims(producer, collapsableIterationDims, rewriter);
    if (!collapseResult) {
      return rewriter.notifyMatchFailure(
          producer, "failed to do the fusion by collapsing transformation");
    }

    rewriter.replaceOp(producer, collapseResult->results);
    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

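/// Pattern to fold a tensor.pad op with a tensor.expand_shape producer by
/// moving the pad before the expansion. This only applies when none of the
/// expanded dimensions are padded. A sketch of the rewrite (shapes are
/// illustrative):
///
/// ```mlir
/// %0 = tensor.expand_shape %arg0 [[0], [1, 2]]
///     : tensor<?x16xf32> into tensor<?x4x4xf32>
/// %1 = tensor.pad %0 low[1, 0, 0] high[1, 0, 0] { ... }
///     : tensor<?x4x4xf32> to tensor<?x4x4xf32>
/// ```
///
/// is rewritten to
///
/// ```mlir
/// %0 = tensor.pad %arg0 low[1, 0] high[1, 0] { ... }
///     : tensor<?x16xf32> to tensor<?x16xf32>
/// %1 = tensor.expand_shape %0 [[0], [1, 2]]
///     : tensor<?x16xf32> into tensor<?x4x4xf32>
/// ```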
class FoldPadWithProducerReshapeOpByCollapsing
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
                                           ControlFusionFn foldReshapes,
                                           PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::ExpandShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
    if (!reshapeOp)
      return failure();
    if (!reshapeOp->hasOneUse())
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(padOp,
                                         "fusion blocked by control function");
    }

    ArrayRef<int64_t> low = padOp.getStaticLow();
    ArrayRef<int64_t> high = padOp.getStaticHigh();
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();

    for (auto reInd : reassociations) {
      if (reInd.size() == 1)
        continue;
      if (llvm::any_of(reInd, [&](int64_t ind) {
            return low[ind] != 0 || high[ind] != 0;
          })) {
        return failure();
      }
    }

    SmallVector<OpFoldResult> newLow, newHigh;
    RankedTensorType collapsedType = reshapeOp.getSrcType();
    RankedTensorType paddedType = padOp.getResultType();
    SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
    SmallVector<OpFoldResult> expandedPaddedSizes(
        getMixedValues(reshapeOp.getStaticOutputShape(),
                       reshapeOp.getOutputShape(), rewriter));
    AffineExpr d0, d1, d2;
    bindDims(rewriter.getContext(), d0, d1, d2);
    auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
    Location loc = reshapeOp->getLoc();
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
      OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
      if (reInd.size() == 1) {
        collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
        OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
            rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
        expandedPaddedSizes[reInd[0]] = paddedSize;
      }
      newLow.push_back(l);
      newHigh.push_back(h);
    }

    RankedTensorType collapsedPaddedType =
        paddedType.clone(collapsedPaddedShape);
    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
        expandedPaddedSizes);

    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to collapse dimensions.
template <typename LinalgType>
class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
public:
  CollapseLinalgDimensions(MLIRContext *context,
                           GetCollapsableDimensionsFn collapseDimensions,
                           PatternBenefit benefit = 1)
      : OpRewritePattern<LinalgType>(context, benefit),
        controlCollapseDimension(std::move(collapseDimensions)) {}

  LogicalResult matchAndRewrite(LinalgType op,
                                PatternRewriter &rewriter) const override {
    SmallVector<ReassociationIndices> collapsableIterationDims =
        controlCollapseDimension(op);
    if (collapsableIterationDims.empty())
      return failure();

    // Check if the specified list of dimensions to collapse is a valid list.
    if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
                                  collapsableIterationDims)) {
      return rewriter.notifyMatchFailure(
          op, "specified dimensions cannot be collapsed");
    }

    std::optional<CollapseResult> collapseResult =
        collapseOpIterationDims(op, collapsableIterationDims, rewriter);
    if (!collapseResult) {
      return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
    }
    rewriter.replaceOp(op, collapseResult->results);
    return success();
  }

private:
  GetCollapsableDimensionsFn controlCollapseDimension;
};

} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse constants with linalg.generic operations.
//===---------------------------------------------------------------------===//

namespace {
/// Pattern to fold a generic op with a splat constant/scalar constant. Does
/// not handle cases where the constant is not single-valued.
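/// For example (a sketch; attributes and the payload are elided):
///
/// ```mlir
/// %cst = arith.constant dense<2.0> : tensor<4xf32>
/// %0 = linalg.generic ... ins(%arg0, %cst : tensor<4xf32>, tensor<4xf32>) ...
/// ```
///
/// is rewritten so that the splat value is materialized as a scalar
/// `arith.constant 2.000000e+00 : f32` used directly in the payload, dropping
/// the constant tensor operand and its indexing map.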
class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
public:
  FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasPureTensorSemantics())
      return failure();
    for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
      Operation *def = opOperand->get().getDefiningOp();
      TypedAttr constantAttr;
      auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
        {
          DenseElementsAttr splatAttr;
          if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
              splatAttr.isSplat() &&
              splatAttr.getType().getElementType().isIntOrFloat()) {
            constantAttr = splatAttr.getSplatValue<TypedAttr>();
            return true;
          }
        }
        {
          IntegerAttr intAttr;
          if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
            constantAttr = intAttr;
            return true;
          }
        }
        {
          FloatAttr floatAttr;
          if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
            constantAttr = floatAttr;
            return true;
          }
        }
        return false;
      };

      auto resultValue = dyn_cast<OpResult>(opOperand->get());
      if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
        continue;

      // The operands and indexing_maps of the fused operation are the same as
      // the operands and indexing_maps of the generic operation, with the
      // values at the constant index dropped.
      SmallVector<AffineMap> fusedIndexMaps;
      SmallVector<Value> fusedOperands;
      SmallVector<Location> fusedLocs{genericOp.getLoc()};
      fusedIndexMaps.reserve(genericOp->getNumOperands());
      fusedOperands.reserve(genericOp.getNumDpsInputs());
      fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
      for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
        if (inputOperand == opOperand)
          continue;
        Value inputValue = inputOperand->get();
        fusedIndexMaps.push_back(
            genericOp.getMatchingIndexingMap(inputOperand));
        fusedOperands.push_back(inputValue);
        fusedLocs.push_back(inputValue.getLoc());
      }
      for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
        fusedIndexMaps.push_back(
            genericOp.getMatchingIndexingMap(&outputOperand));

      // Check if the operation shapes to loops map is computable.
      if (!inversePermutation(
              concatAffineMaps(fusedIndexMaps, rewriter.getContext()))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      // Create a constant scalar value from the splat constant.
      Value scalarConstant =
          rewriter.create<arith::ConstantOp>(def->getLoc(), constantAttr);

      SmallVector<Value> outputOperands = genericOp.getOutputs();
      auto fusedOp = rewriter.create<GenericOp>(
          rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
          /*inputs=*/fusedOperands,
          /*outputs=*/outputOperands,
          rewriter.getAffineMapArrayAttr(fusedIndexMaps),
          genericOp.getIteratorTypes(),
          /*doc=*/nullptr,
          /*library_call=*/nullptr);

      // Map the block argument corresponding to the replaced argument with the
      // scalar constant.
      Region &region = genericOp->getRegion(0);
      Block &entryBlock = *region.begin();
      IRMapping mapping;
      mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
                  scalarConstant);
      Region &fusedRegion = fusedOp->getRegion(0);
      rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
                                 mapping);
      rewriter.replaceOp(genericOp, fusedOp->getResults());
      return success();
    }
    return failure();
  }
};

} // namespace

//===---------------------------------------------------------------------===//
// Miscellaneous patterns that help fusion.
//===---------------------------------------------------------------------===//

namespace {
/// Forces `outs` operands of linalg operations to use `tensor.empty` if the
/// value of the `outs` operand is not used within the op. This is only
/// implemented for `linalg.generic` operations for now, but should hold for
/// all linalg structured ops.
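/// For example (sketch), when the payload ignores the `outs` value:
///
/// ```mlir
/// %0 = linalg.generic ... outs(%arg1 : tensor<?xf32>) { ... }
/// ```
///
/// the `outs` operand is replaced with a `tensor.empty` of the same shape,
/// removing the false dependence on `%arg1` and unblocking further fusion.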
struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.startOpModification(op);
    bool modifiedOutput = false;
    Location loc = op.getLoc();
    for (OpOperand &opOperand : op.getDpsInitsMutable()) {
      if (!op.payloadUsesValueFromOperand(&opOperand)) {
        Value operandVal = opOperand.get();
        auto operandType = dyn_cast<RankedTensorType>(operandVal.getType());
        if (!operandType)
          continue;

        // If outs is sparse, leave it to the sparsifier.
        if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
          continue;

        // If outs is already an `empty` operation, nothing to do.
        auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
        if (definingOp)
          continue;
        modifiedOutput = true;
        SmallVector<OpFoldResult> mixedSizes =
            tensor::getMixedSizes(rewriter, loc, operandVal);
        Value emptyTensor = rewriter.create<tensor::EmptyOp>(
            loc, mixedSizes, operandType.getElementType());
        op->setOperand(opOperand.getOperandNumber(), emptyTensor);
      }
    }
    if (!modifiedOutput) {
      rewriter.cancelOpModification(op);
      return failure();
    }
    rewriter.finalizeOpModification(op);
    return success();
  }
};

/// Fold linalg.fill into linalg.generic.
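/// If an input operand of a `linalg.generic` is produced by a `linalg.fill`,
/// the corresponding block argument can be replaced by the fill value
/// (converted to the payload element type), since every element read from
/// that operand equals the fill value.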
struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasPureTensorSemantics())
      return failure();
    bool fillFound = false;
    Block &payload = genericOp.getRegion().front();
    for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
      if (!genericOp.payloadUsesValueFromOperand(opOperand))
        continue;
      FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
      if (!fillOp)
        continue;
      fillFound = true;
      Value fillVal = fillOp.value();
      auto resultType =
          cast<RankedTensorType>(fillOp.result().getType()).getElementType();
      Value convertedVal =
          convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
                               /*isUnsignedCast=*/false);
      rewriter.replaceAllUsesWith(
          payload.getArgument(opOperand->getOperandNumber()), convertedVal);
    }
    return success(fillFound);
  }
};
} // namespace

void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlFoldingReshapes) {
  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                    controlFoldingReshapes);
  patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                        controlFoldingReshapes);
  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
}

void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlFoldingReshapes) {
  patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
                                                      controlFoldingReshapes);
  patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
      patterns.getContext(), controlFoldingReshapes);
  patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
                                                     controlFoldingReshapes);
}

void mlir::linalg::populateElementwiseOpsFusionPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlElementwiseOpsFusion) {
  auto *context = patterns.getContext();
  patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
  patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
               RemoveOutsDependency>(context);
  // Add the patterns that clean up dead operands and results.
  populateEraseUnusedOperandsAndResultsPatterns(patterns);
}

void mlir::linalg::populateCollapseDimensions(
    RewritePatternSet &patterns,
    const GetCollapsableDimensionsFn &controlCollapseDimensions) {
  patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
               CollapseLinalgDimensions<linalg::CopyOp>>(
      patterns.getContext(), controlCollapseDimensions);
}
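// Example usage of `populateCollapseDimensions` (a sketch; the control
// function below is a hypothetical placeholder that callers define):
//
//   RewritePatternSet patterns(context);
//   GetCollapsableDimensionsFn control = [](linalg::LinalgOp op) {
//     SmallVector<ReassociationIndices> dims;
//     // Populate `dims` with the contiguous loop ranges of `op` to fold,
//     // e.g. {{0, 1}} to fold the first two loops.
//     return dims;
//   };
//   linalg::populateCollapseDimensions(patterns, control);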

//===---------------------------------------------------------------------===//
// Passes
//===---------------------------------------------------------------------===//

namespace {

/// Pass that fuses generic ops on tensors. Used only for testing.
// TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
// patterns added here heavily depends on the cost function used. Having an
// opinionated pass of this form is not recommended. Deprecate this pass in
// favor of test passes that check the functionality of each of the patterns
// added here individually.
struct LinalgElementwiseOpFusionPass
    : public impl::LinalgElementwiseOpFusionPassBase<
          LinalgElementwiseOpFusionPass> {
  using impl::LinalgElementwiseOpFusionPassBase<
      LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();
    RewritePatternSet patterns(context);

    // Add folding with reshape by expansion patterns.
    ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
      Operation *producer = fusedOperand->get().getDefiningOp();
      return producer && producer->hasOneUse();
    };

    // Add elementwise op fusion patterns.
    populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
    populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
    tensor::populateBubbleUpExpandShapePatterns(patterns);

    // General canonicalization patterns.
    affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
    GenericOp::getCanonicalizationPatterns(patterns, context);
    tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
    tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
    context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
        patterns);

    // Add constant folding patterns.
    populateConstantFoldLinalgOperations(patterns, defaultControlFn);

    // Use TopDownTraversal for compile time reasons.
    (void)applyPatternsGreedily(op, std::move(patterns),
                                GreedyRewriteConfig().setUseTopDownTraversal());
  }
};

} // namespace