//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect elementwise operation fusion on
// tensors pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <optional>
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//

/// Return the indexing map to use in the fused operation for an operand of the
/// `producer`, given the indexing map of the producer's result as seen from the
/// consumer (`fusedConsumerArgIndexMap`).
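///
/// As a purely illustrative example (not taken from an actual test case):
/// suppose the producer's result indexing map is `(d0, d1) -> (d1, d0)`, the
/// producer argument's map is the identity `(d0, d1) -> (d0, d1)`, and the
/// consumer reads the fused operand with the identity map
/// `(c0, c1) -> (c0, c1)`. Inverting the result map gives
/// `(t0, t1) -> (t1, t0)`; composing it with the argument map yields
/// `(t0, t1) -> (t1, t0)`, and composing that with the consumer map gives
/// `(c0, c1) -> (c1, c0)`, i.e. the producer argument is accessed transposed
/// with respect to the fused (consumer) loops.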
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loop is the same as the consumer loop. For each producer
  // arg the indexing map to be computed is a map from consumer loop ->
  // producer arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute the inverse to get map from tensor index -> producer loop.
  // The inverse is a map from producer result tensor index -> producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);

  // Compose argMap with invProducerResultIndexMap to get a map from
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop/fused loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}

/// Returns a set of indices of the producer's results which would
/// be preserved after the fusion.
llvm::SmallDenseSet<int>
ElementwiseOpFusionResult::getPreservedProducerResults(GenericOp producer,
                                                       GenericOp consumer) {
  llvm::SmallDenseSet<int> preservedProducerResults;
  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
    auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
    if (producer.payloadUsesValueFromOperand(outputOperand) ||
        !producer.canOpOperandsBeDropped(outputOperand) ||
        llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
          return user != consumer.getOperation();
        })) {
      preservedProducerResults.insert(producerResult.index());
    }
  }
  return preservedProducerResults;
}

/// Conditions for elementwise fusion of generic operations.
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
  if (!fusedOperand)
    return false;

  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());

  // Check producer and consumer are generic ops.
  if (!producer || !consumer)
    return false;

  // Consumer can have mixed semantics, just check that the operand itself has
  // tensor type. Producer must have full tensor semantics to avoid potential
  // aliasing between producer and consumer memrefs.
  if (!producer.hasPureTensorSemantics() ||
      !isa<RankedTensorType>(fusedOperand->get().getType()))
    return false;

  // Verify that the producer has all "parallel" iterator types.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  // TODO: allow fusing the producer of an output operand.
  if (!consumer.isDpsInput(fusedOperand))
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // Finally, the indexing map for the producer result must be invertible. For
  // now just verify that it is a permutation.
  AffineMap producerResultIndexMap =
      producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
  if (!producerResultIndexMap.isPermutation())
    return false;

  // Ensure that the fusion does not remove size information required to
  // get the loop bounds. For non-reduction generics, this is trivially the
  // case due to the output operand. For reductions, we need to check that after
  // the fusion, each loop dimension has at least one input that defines it.
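  // As a purely illustrative example: if the consumer reduces over d1 with
  // indexing maps ins: (d0, d1) -> (d0, d1), outs: (d0, d1) -> (d0), and the
  // only operand referencing d1 is the fused operand, then after fusion the
  // bound of d1 must still be recoverable from one of the producer's inputs
  // (or another consumer input); otherwise the fusion is rejected below.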
  if (consumer.getNumReductionLoops()) {
    BitVector coveredDims(consumer.getNumLoops(), false);

    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = dyn_cast<AffineDimExpr>(result))
          coveredDims[dimExpr.getPosition()] = true;
    };

    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
      Value operand = std::get<0>(pair);
      if (operand == fusedOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    for (OpOperand *operand : producer.getDpsInputOperands()) {
      AffineMap newIndexingMap =
          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
              operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }

  return true;
}

/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void generateFusedElementwiseOpRegion(
    RewriterBase &rewriter, GenericOp fusedOp,
    AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
    unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  OpBuilder::InsertionGuard guard(rewriter);
  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
  IRMapping mapper;

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
    unsigned numFusedOpLoops =
        std::max(producer.getNumLoops(), consumer.getNumLoops());
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return rewriter.create<IndexOp>(producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = rewriter.create<affine::AffineApplyOp>(
          producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           fusedOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Replacing consumerIdx requires getting the cloned, yielded, value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumDpsInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumDpsInputs())
           .drop_front(fusedOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 6. All of the producer's output operands.
  for (const auto &bbArg : llvm::enumerate(
           producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
    if (!preservedProducerResults.count(bbArg.index()))
      continue;
    mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
                                                      bbArg.value().getLoc()));
  }

  // 7. All of consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumDpsInits()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 8. Clone all producer operations except for the yield and index operations
  // to the fused operation.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  unsigned producerResultNumber =
      cast<OpResult>(fusedOperand->get()).getResultNumber();
  Value replacement =
      mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));

  // Sanity checks: if the replacement is not already in the mapper, then it
  // must be produced outside the producer.
  if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
    if (auto bb = dyn_cast<BlockArgument>(replacement))
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op.
  for (auto &op : consumerBlock.without_terminator())
    rewriter.clone(op, mapper);

  // 11. Include the final yield (which remaps the values of all the yielded
  // operands).
  auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
  SmallVector<Value> fusedYieldValues;
  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
                           consumerYieldOp.getNumOperands());
  for (const auto &producerYieldVal :
       llvm::enumerate(producerYieldOp.getOperands())) {
    if (preservedProducerResults.count(producerYieldVal.index()))
      fusedYieldValues.push_back(
          mapper.lookupOrDefault(producerYieldVal.value()));
  }
  for (auto consumerYieldVal : consumerYieldOp.getOperands())
    fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
  rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}

FailureOr<mlir::linalg::ElementwiseOpFusionResult>
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
                                 OpOperand *fusedOperand) {
  assert(areElementwiseOpsFusable(fusedOperand) &&
         "expected elementwise operation pre-conditions to pass");
  auto producerResult = cast<OpResult>(fusedOperand->get());
  auto producer = cast<GenericOp>(producerResult.getOwner());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // Find the results of the producer that have uses outside of the consumer.
  llvm::SmallDenseSet<int> preservedProducerResults =
      ElementwiseOpFusionResult::getPreservedProducerResults(producer,
                                                             consumer);

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedInputOperands, fusedOutputOperands;
  SmallVector<Type> fusedResultTypes;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedInputOperands.reserve(producer.getNumDpsInputs() +
                             consumer.getNumDpsInputs());
  fusedOutputOperands.reserve(preservedProducerResults.size() +
                              consumer.getNumDpsInits());
  fusedResultTypes.reserve(preservedProducerResults.size() +
                           consumer.getNumDpsInits());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // In the following, numbering matches that of
  // `generateFusedElementwiseOpRegion`.
  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
  auto consumerInputs = consumer.getDpsInputOperands();
  auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
    return operand == fusedOperand;
  });
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }
  // 4. Splice in producer's input operands/maps.
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  for (OpOperand *opOperand : producer.getDpsInputOperands()) {
    fusedInputOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer's input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }

  // 6. Collect all of the producer outputs.
  for (const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
    if (!preservedProducerResults.count(opOperand.index()))
      continue;

    fusedOutputOperands.push_back(opOperand.value().get());
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        &opOperand.value(), producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
    fusedResultTypes.push_back(opOperand.value().get().getType());
  }

  // 7. All of consumer's output operands (skip operands added by the builder).
  for (OpOperand &opOperand : consumer.getDpsInitsMutable()) {
    fusedOutputOperands.push_back(opOperand.get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
    Type resultType = opOperand.get().getType();
    if (!isa<MemRefType>(resultType))
      fusedResultTypes.push_back(resultType);
  }

  // Generate the fused op.
  auto fusedOp = rewriter.create<GenericOp>(
      consumer.getLoc(), fusedResultTypes, fusedInputOperands,
      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.getIteratorTypes(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    // Fused op has invalid indexing maps. Typically this means something is off
    // in the input, but going ahead here would result in verification errors.
    // So cleanup and abort.
    rewriter.eraseOp(fusedOp);
    return rewriter.notifyMatchFailure(
        fusedOp, "fused op failed loop bound computation check");
  }

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getMatchingIndexingMap(fusedOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(
      rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
      consumer.getNumLoops(), preservedProducerResults);
  ElementwiseOpFusionResult result;
  result.fusedOp = fusedOp;
  int resultNum = 0;
  for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
    if (preservedProducerResults.count(index))
      result.replacements[producerResult] = fusedOp->getResult(resultNum++);
  for (auto consumerResult : consumer->getResults())
    result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
  return result;
}

namespace {
/// Pattern to fuse a generic op with the producer of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      if (!areElementwiseOpsFusable(&opOperand))
        continue;
      if (!controlFn(&opOperand))
        continue;

      Operation *producer = opOperand.get().getDefiningOp();

      // Do not fuse a sparse-in/dense-out operation, as the
      // result is too often not sparsifiable anymore.
      if (sparse_tensor::hasAnySparseOperand(producer) &&
          !sparse_tensor::hasAnySparseResult(producer))
        return failure();

      // Fuse the operand with its producer.
      FailureOr<ElementwiseOpFusionResult> fusionResult =
          fuseElementwiseOps(rewriter, &opOperand);
      if (failed(fusionResult))
        return rewriter.notifyMatchFailure(genericOp, "fusion failed");

      // Perform the replacements.
      for (auto [origVal, replacement] : fusionResult->replacements) {
        rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
          // Only replace consumer uses.
          return use.get().getDefiningOp() != producer;
        });
      }
      rewriter.eraseOp(genericOp);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFn;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------------===//

/// Conditions for folding a structured linalg operation with a reshape op by
/// expanding the iteration space dimensionality for tensor operations. These
/// are preconditions assumed by `fuseWithReshapeByExpansion` which implements
/// the following fusion pattern.
///
///  Consider
///
///  %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
///  The reshape can be folded into the `linalgOp` if its loop dimensionality
///  is increased to match the result (operand) of the tensor.expand_shape.
///  The indexing_map of the fused tensor in the `linalgOp` and the
///  reassociation map helps compute the indexing maps of the modified op.
///  For the above example, based on the reassociation map it
///  can be concluded that
///
///  - The loop used to access the first dimension of the fused tensor is split
///    into two.
///  - The loop used to access the second dimension of the fused tensor is kept
///    as is.
///  - The loop used to access the third dimension of the fused tensor is split
///    into three.
///
///  i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
///  modified op, then
///
///   d0 -> e0, e1
///   d1 -> e2, e3, e4
///   d2 -> e5
///
///  substituting this, the structured op can be rewritten as
///
///  %d = linalg.generic ins(%0, %1 : )
///        indexing_maps =
///          [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
///  Since the modified op now operates on a 6-D iteration space, reshapes can
///  be introduced to make the operands consistent
///
///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///       : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
///  The added reshapes are again expanding patterns, so they will get fused
///  with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  // - All the loops for the reshaped operand are parallel loops.
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
  return linalgOp.hasPureTensorSemantics() &&
         llvm::all_of(linalgOp.getIndexingMaps().getValue(),
                      [](Attribute attr) {
                        return cast<AffineMapAttr>(attr)
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         operandMap.getNumResults() > 0 &&
         llvm::all_of(operandMap.getResults(), [&](AffineExpr expr) {
           return isParallelIterator(
               iteratorTypes[cast<AffineDimExpr>(expr).getPosition()]);
         });
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
  // the expanded op.
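  //
  // As a purely illustrative example: with `indexingMap = (d0, d1) -> (d1, d0)`,
  // reassociation maps that expand result dimension 0 into two dimensions and
  // keep result dimension 1 as is, and `expandedShape = [4, 8, 16]`, dimension
  // d1 of the original op maps to two expanded dimensions of sizes {4, 8} and
  // d0 maps to one expanded dimension of size {16}, giving three loops in the
  // expanded op.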
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<int64_t> expandedShape,
                        ArrayRef<int64_t> collapsedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimension of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from extent of loops in the original operation, to the extent of
  /// loops in the expanded operation.
  SmallVector<SmallVector<int64_t>> expandedShapeMap;
  /// Extent of the loop in the original operation.
  SmallVector<int64_t> originalLoopExtent;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape,
                                     ArrayRef<int64_t> collapsedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

  SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
  originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<int64_t> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute the reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}

/// Expanding the body of a linalg operation requires adaptations of the
/// accessed loop indices. Specifically, indices accessed in the original
/// operation need to be replaced with linearizations of the indices in the
/// expanded op. That requires the shape of the expanded dimensions to be
/// static (at least all but the most significant). For now check that these
/// are all statically sized. Note that this could be extended to handle the
/// dynamic case, but the implementation below uses `affine.apply` which seems
/// to have issues when the shapes are not static.
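///
/// For example (purely illustrative): if an original dimension is expanded
/// into dimensions of extents (?, 8), the original index can still be
/// recovered as `idx_outer * 8 + idx_inner` because only the outermost (most
/// significant) extent is dynamic; if the inner extent were dynamic, the
/// static linearization below would not be expressible.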
static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
                                          const ExpansionInfo &expansionInfo,
                                          PatternRewriter &rewriter) {
  if (!linalgOp.hasIndexSemantics())
    return success();
  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
    if (expandedShape.size() == 1)
      continue;
    for (int64_t shape : expandedShape.drop_front()) {
      if (ShapedType::isDynamic(shape)) {
        return rewriter.notifyMatchFailure(
            linalgOp, "cannot expand due to index semantics and dynamic dims");
      }
    }
  }
  return success();
}

/// Return the indexing map to use in the expanded op given the `indexingMap`
/// of the original operation.
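///
/// As a purely illustrative example: if `indexingMap = (d0, d1) -> (d1, d0)`,
/// d0 expands into {e0, e1} and d1 expands into {e2}, the expanded map is
/// `(e0, e1, e2) -> (e2, e0, e1)`: each original dimension in the result is
/// replaced, in place, by the list of expanded dimensions it maps to.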
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = cast<AffineDimExpr>(expr).getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}

/// Return the type of the operand/result to use in the expanded op given the
/// type in the original op.
static RankedTensorType getExpandedType(RankedTensorType originalType,
                                        AffineMap indexingMap,
                                        const ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  return RankedTensorType::get(expandedShape, originalType.getElementType());
}

/// Returns the reassociation maps to use in the `tensor.expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `tensor.collapse_shape` used to collapse the result of the expanded
/// op to get the value that can replace all uses of the results of the original
/// op.
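///
/// As a purely illustrative example: for an operand whose indexing map is
/// `(d0, d1) -> (d1, d0)` with d1 expanded into two dimensions and d0 into one,
/// the returned reassociation is `[[0, 1], [2]]`, i.e. the first two dimensions
/// of the expanded operand collapse back into the operand's first dimension.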
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}

/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
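///
/// As a purely illustrative example: if an original dimension is expanded into
/// dimensions (e0, e1) with static inner extent 8, the original index is
/// rebuilt as `linalg.index e0 * 8 + linalg.index e1`, materialized through
/// `affine.apply` ops below.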
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.getDim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.getDim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<int64_t> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
      assert(!ShapedType::isDynamic(std::get<0>(it)));
      AffineExpr idx, acc;
      bindDims(rewriter.getContext(), idx, acc);
      newIndex = rewriter.create<affine::AffineApplyOp>(
          indexOp.getLoc(), idx + acc * std::get<0>(it),
          ValueRange{std::get<1>(it), newIndex});
    }
    rewriter.replaceOp(indexOp, newIndex);
  }
}

/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
/// and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
static std::optional<SmallVector<Value>>
fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");
  // Check if reshape is expanding or collapsing.
  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
  bool isExpanding = (expandingReshapeOp != nullptr);
  RankedTensorType expandedType = isExpanding
                                      ? expandingReshapeOp.getResultType()
                                      : collapsingReshapeOp.getSrcType();
  RankedTensorType collapsedType = isExpanding
                                       ? expandingReshapeOp.getSrcType()
                                       : collapsingReshapeOp.getResultType();

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(
          linalgOp, fusableOpOperand,
          isExpanding ? expandingReshapeOp.getReassociationMaps()
                      : collapsingReshapeOp.getReassociationMaps(),
          expandedType.getShape(), collapsedType.getShape(), rewriter)))
    return std::nullopt;

  if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
    return std::nullopt;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  // Set insertion point to the generic op.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
  for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
                                               : collapsingReshapeOp.getSrc());
      continue;
    }
    if (auto opOperandType =
            dyn_cast<RankedTensorType>(opOperand->get().getType())) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
      RankedTensorType expandedOperandType =
          getExpandedType(opOperandType, indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        if (failed(reshapeLikeShapesAreCompatible(
                [&](const Twine &msg) {
                  return rewriter.notifyMatchFailure(linalgOp, msg);
                },
                opOperandType.getShape(), expandedOperandType.getShape(),
                reassociation,
                /*isExpandingReshape=*/true)))
          return std::nullopt;
        expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
            linalgOp.getLoc(), expandedOperandType, opOperand->get(),
            reassociation));
        continue;
      }
    }
    expandedOpOperands.push_back(opOperand->get());
  }

  Location loc = linalgOp.getLoc();
  SmallVector<Value> outputs;
  for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
    auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
    RankedTensorType expandedOutputType =
        getExpandedType(opOperandType, indexingMap, expansionInfo);
    if (expandedOutputType != opOperand.get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      if (failed(reshapeLikeShapesAreCompatible(
              [&](const Twine &msg) {
                return rewriter.notifyMatchFailure(linalgOp, msg);
              },
              opOperandType.getShape(), expandedOutputType.getShape(),
              reassociation,
              /*isExpandingReshape=*/true)))
        return std::nullopt;
      outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
          linalgOp.getLoc(), expandedOutputType, opOperand.get(),
          reassociation));
    } else {
      outputs.push_back(opOperand.get());
    }
  }

  // The iterator types of the expanded op are all parallel.
  SmallVector<utils::IteratorType> iteratorTypes(
      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
  for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
    for (auto j : expansionInfo.getExpandedDims(i))
      iteratorTypes[j] = type;

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  auto fusedOp =
      rewriter.create<GenericOp>(linalgOp.getLoc(), resultTypes,
                                 /*inputs=*/expandedOpOperands, outputs,
                                 expandedOpIndexingMaps, iteratorTypes);
  Region &fusedRegion = fusedOp->getRegion(0);
  Region &originalRegion = linalgOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);

  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : linalgOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              linalgOp.getMatchingIndexingMap(
                  linalgOp.getDpsInitOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
          linalgOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  // Assuming a single result.
  return resultVals;
}

namespace {

/// Pattern to fuse a tensor.collapse_shape op with its consumer structured op,
/// when the reshape op is collapsing dimensions. The dimensionality of the loop
/// in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
    : public OpInterfaceRewritePattern<LinalgOp> {
public:
  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
      tensor::CollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
      if (!reshapeOp)
        continue;
      // Fold only if
      // - The tensor reshape op is folding.
      // - All constraints of fusing with reshape by expansion are met.
      if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
          (!controlFoldingReshapes(opOperand)))
        continue;

      std::optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(linalgOp, *replacementValues);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {

  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
                                      ControlFusionFn foldReshapes,
                                      PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by expansion are met.
    auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
    if (!producerResult) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "source not produced by an operation");
    }

    auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
    if (!producer) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "producer not a generic op");
    }

    if (!isFusableWithReshapeByDimExpansion(
            producer,
            producer.getDpsInitOperand(producerResult.getResultNumber()))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed preconditions of fusion with producer generic op");
    }

    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion blocked by control function");
    }

    std::optional<SmallVector<Value>> replacementValues =
        fuseWithReshapeByExpansion(
            producer, reshapeOp,
            producer.getDpsInitOperand(producerResult.getResultNumber()),
            rewriter);
    if (!replacementValues) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion by expansion failed");
    }

    // Find the replacement for the reshape op. Since the replacements have the
    // same type as the returns of the original generic op, the consumer reshape
    // op can be replaced by the source of the collapse_shape op that defines
    // the replacement.
    Value reshapeReplacement =
        (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
                                 .getResultNumber()];
    if (auto collapseOp =
            reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
      reshapeReplacement = collapseOp.getSrc();
    }
    rewriter.replaceOp(reshapeOp, reshapeReplacement);
    rewriter.replaceOp(producer, *replacementValues);
    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns to fuse reshape with linalg.generic operations by
// contraction of dimensions.
//===---------------------------------------------------------------------===//

/// For a given list of indices in the range of the `indexingMap` that are
/// folded, return the indices of the corresponding domain. Since `indexingMap`
/// is a projected permutation, all the elements of the returned reassociation
/// are distinct.
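///
/// As a purely illustrative example: for
/// `indexingMap = (d0, d1, d2) -> (d2, d0)` and range reassociation `[0, 1]`
/// (i.e. the first two results are folded), the returned domain reassociation
/// is `[2, 0]`.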
static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
                       ReassociationIndicesRef rangeReassociation) {
  assert(indexingMap.isProjectedPermutation() &&
         "expected projected permutation");

  ReassociationIndices domainReassociation = llvm::to_vector<4>(
      llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
        return cast<AffineDimExpr>(indexingMap.getResults()[pos]).getPosition();
      }));
  // The projected permutation semantics ensures that there is no repetition of
  // the domain indices.
  return domainReassociation;
}

/// For a given `dimSequence`, check if the sequence is conserved in the
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
/// Non-existence of the sequence returns true as well.
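///
/// As a purely illustrative example: for `indexingMap = (d0, d1, d2, d3) ->
/// (d0, d2, d3)`, the sequence [2, 3] is preserved (d2 is immediately followed
/// by d3), the sequence [1, 2] is not (d2 appears without d1 before it), and
/// the sequence [1] is trivially preserved since d1 does not appear at all.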
bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
                                          ReassociationIndicesRef dimSequence) {
  assert(!dimSequence.empty() &&
         "expected non-empty list for dimension sequence");
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");

  llvm::SmallDenseSet<unsigned, 4> sequenceElements;
  sequenceElements.insert(dimSequence.begin(), dimSequence.end());

  unsigned dimSequenceStart = dimSequence[0];
  for (const auto &expr : enumerate(indexingMap.getResults())) {
    unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
    // 1. Check if this is the start of the sequence.
    if (dimInMapStart == dimSequenceStart) {
      if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
        return false;
      // 1a. Check if sequence is preserved.
      for (const auto &dimInSequence : enumerate(dimSequence)) {
        unsigned dimInMap =
            cast<AffineDimExpr>(
                indexingMap.getResult(expr.index() + dimInSequence.index()))
                .getPosition();
        if (dimInMap != dimInSequence.value())
          return false;
      }
      // Found the sequence. Projected permutation
      // enforces that all AffineDimExprs in the result are unique, so no
      // further checks are needed.
      return true;
    }
    // 2. If the position in the expr (which is of type AffineDimExpr) is part
    // of the sequence, return false here. This implies the entire sequence
    // does not exist in the indexing map.
    if (sequenceElements.count(dimInMapStart))
      return false;
  }
  // 3. No element of the sequence found. Return true.
  return true;
}

bool mlir::linalg::areDimSequencesPreserved(
    ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences) {
  return llvm::all_of(maps, [&](AffineMap map) {
    return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
      return isDimSequencePreserved(map, dimSequence);
    });
  });
}

// Return the list of dimensions of the iteration domain that can be
// collapsed to allow for fusion with a producer that is an expand_shape
// operation. If all dimensions created by expansion can be collapsed in the
// iteration space then the reshape is defunct.
//
// Example:
//
// ```mlir
// #map = affine_map<(d0, d1) -> (d0, d1)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<?x4xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
// ```
//
// can be fused by collapsing the dimensions of the iteration space.
//
// ```mlir
// #map = affine_map<(d0) -> (d0)>
// %2 = tensor.empty [..] : tensor<?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel"]}
//     ins(%0 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
// %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// ```
//
// In the following example,
//
// ```mlir
// #map0 = affine_map<(d0, d1) -> (d0, d1)>
// #map1 = affine_map<(d0, d1) -> (d1, d0)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<4x?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map0, #map1],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
// ```
//
// the reshape cannot be fused with the generic op by collapsing the op
// dimensions since the indexing maps will have to contain mods and divs
// to preserve the access pattern. When no dimensions of the iteration
// space are collapsible, an empty vector is returned.
static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                 ArrayRef<ReassociationIndices> reassociation) {
  // Some basic checks for this fusion to be valid.
  if (!genericOp.hasPureTensorSemantics() || genericOp.getNumDpsInits() != 1)
    return {};

  if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
        return map.isProjectedPermutation();
      })) {
    return {};
  }

  // Compute all the loops with the reduction iterator types.
  SmallVector<unsigned> reductionDims;
  genericOp.getReductionDims(reductionDims);

  llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
  AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
  auto iteratorTypes = genericOp.getIteratorTypesArray();
  SmallVector<ReassociationIndices> iterationSpaceReassociation;
  for (ReassociationIndicesRef foldedRangeDims : reassociation) {
    assert(!foldedRangeDims.empty() && "unexpected empty reassociation");

    // Ignore dims that are not folded.
    if (foldedRangeDims.size() == 1)
      continue;

    ReassociationIndices foldedIterationSpaceDims =
        getDomainReassociation(indexingMap, foldedRangeDims);

    // Check that the folded iteration dims do not contain already processed
    // dims.
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return processedIterationDims.count(dim);
        }))
      continue;

    // Check that the folded iterator types are all parallel or all reductions.
    utils::IteratorType startIteratorType =
        iteratorTypes[foldedIterationSpaceDims[0]];
    if (!isParallelIterator(startIteratorType) &&
        !isReductionIterator(startIteratorType))
      continue;
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return iteratorTypes[dim] != startIteratorType;
        }))
      continue;

    // If the folded dimensions correspond to a "reduction" iterator type,
    // the folded dimensions need to be "in-order". Strictly speaking this is
    // not necessary, for reductions that are associative and commutative, but
    // using a more strict definition of reduction for now.
    if (isReductionIterator(startIteratorType)) {
      bool isContiguous = false;
      for (const auto &startDim : llvm::enumerate(reductionDims)) {
        // Move window in `reductionDims` to start of the folded iteration dims.
        if (startDim.value() != foldedIterationSpaceDims[0])
          continue;
        // If the sizes don't match, it is trivially not contiguous. This
        // condition should not be hit.
        if (startDim.index() + foldedIterationSpaceDims.size() >
            reductionDims.size())
          break;
        // Check that the contiguity is maintained.
        isContiguous = true;
        for (const auto &foldedDim :
             llvm::enumerate(foldedIterationSpaceDims)) {
          if (reductionDims[foldedDim.index() + startDim.index()] !=
              foldedDim.value()) {
            isContiguous = false;
            break;
          }
        }
        break;
      }
      if (!isContiguous)
        continue;
    }

    // Check that the sequence is preserved in all indexing maps.
    if (llvm::any_of(genericOp.getIndexingMapsArray(),
                     [&](AffineMap indexingMap) {
                       return !isDimSequencePreserved(indexingMap,
                                                      foldedIterationSpaceDims);
                     }))
      continue;

    processedIterationDims.insert(foldedIterationSpaceDims.begin(),
                                  foldedIterationSpaceDims.end());
    iterationSpaceReassociation.emplace_back(
        std::move(foldedIterationSpaceDims));
  }

  return iterationSpaceReassociation;
}

/// Helper class to carry state while collapsing the `linalg.generic` op.
namespace {
class CollapsingInfo {
public:
  LogicalResult initialize(unsigned origNumLoops,
                           ArrayRef<ReassociationIndices> foldedIterationDims) {
    llvm::SmallDenseSet<int64_t, 4> processedDims;
    // Find all the dims that are folded.
    for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
      if (foldedIterationDim.empty())
        continue;
      // If the folded dims contain dims already folded, that's an illegal
      // specification. Repetition within a list is also illegal.
      for (auto dim : foldedIterationDim) {
        if (dim >= origNumLoops)
          return failure();
        if (processedDims.count(dim))
          return failure();
        processedDims.insert(dim);
      }
      collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
                                                   foldedIterationDim.end());
    }
    if (processedDims.size() > origNumLoops)
      return failure();

    // Add all the preserved dims of the original op as single
    // elements to `collapsedOpToOrigOpIterationDim`.
    for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
      if (processedDims.count(dim))
        continue;
      collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
    }

    llvm::sort(collapsedOpToOrigOpIterationDim,
               [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
                 return lhs[0] < rhs[0];
               });
    origOpToCollapsedOpIterationDim.resize(origNumLoops);
    for (const auto &foldedDims :
         llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
      for (const auto &dim : enumerate(foldedDims.value()))
        origOpToCollapsedOpIterationDim[dim.value()] =
            std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
    }
    return success();
  }

  /// Return mapping from collapsed loop domain to original loop domain.
  ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
    return collapsedOpToOrigOpIterationDim;
  }

  /// Return mapping from original loop domain to collapsed loop domain. The
  /// mapping is a pair. The first value is the dimension in the collapsed loop
  /// that the original loop is mapped to. The second is the relative position
  /// in the folded list of this domain. For example, if the original loop
  /// domain is 5D and the collapsed loop domain folds all of it, i.e.
  ///
  /// ```
  /// collapsedOpToOrigOpMapping = [[0, 1, 2] [3, 4]]
  /// ```
  ///
  /// then
  ///
  /// ```
  /// origOpToCollapsedOpMapping[0] = {0, 0};
  /// origOpToCollapsedOpMapping[1] = {0, 1};
  /// origOpToCollapsedOpMapping[2] = {0, 2};
  /// origOpToCollapsedOpMapping[3] = {1, 0};
  /// origOpToCollapsedOpMapping[4] = {1, 1};
  /// ```
  ///
  ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
    return origOpToCollapsedOpIterationDim;
  }

  /// Return the collapsed op iteration domain rank.
  unsigned getCollapsedOpIterationRank() const {
    return collapsedOpToOrigOpIterationDim.size();
  }

private:
  /// Map from the iteration domain index in collapsed op to the iteration
  /// domain indices in the original op.
  SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;

  /// Map from iteration domain index in the original op to the iteration domain
  /// index in the collapsed op.
  SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
};
} // namespace
1317
1318/// Get the iterator types for the collapsed operation given the original
1319/// iterator types and collapsed dimensions.
1320static SmallVector<utils::IteratorType>
1321getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
1322 const CollapsingInfo &collapsingInfo) {
1323 SmallVector<utils::IteratorType> collapsedIteratorTypes;
1324 for (ReassociationIndicesRef foldedIterDims :
1325 collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1326 assert(!foldedIterDims.empty() &&
1327 "reassociation indices expected to have non-empty sets");
1328    // Just pick the iterator type of the first folded dim. Pre-condition checks
1329    // are expected to have verified that the iterator types of all folded
1330    // dimensions are the same.
1331 collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
1332 }
1333 return collapsedIteratorTypes;
1334}
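
// Illustrative example: with original iterator types {parallel, parallel,
// reduction} and collapsedOpToOrigOpMapping = [[0, 1], [2]], the collapsed
// iterator types are {parallel, reduction}; the precondition that dims folded
// together share the same iterator type makes picking the first entry of each
// group sufficient.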
1335
1336/// Compute the indexing map in the collapsed op that corresponds to the given
1337/// `indexingMap` of the original operation.
1338static AffineMap
1339getCollapsedOpIndexingMap(AffineMap indexingMap,
1340 const CollapsingInfo &collapsingInfo) {
1341 MLIRContext *context = indexingMap.getContext();
1342 assert(indexingMap.isProjectedPermutation() &&
1343 "expected indexing map to be projected permutation");
1344 SmallVector<AffineExpr> resultExprs;
1345 auto origOpToCollapsedOpMapping =
1346 collapsingInfo.getOrigOpToCollapsedOpMapping();
1347 for (auto expr : indexingMap.getResults()) {
1348 unsigned dim = cast<AffineDimExpr>(Val&: expr).getPosition();
1349    // If the dim is not the first dim of its collapsed group, do nothing.
1350 if (origOpToCollapsedOpMapping[dim].second != 0)
1351 continue;
1352 // The next n-dims are guaranteed to be collapsed. So just use the
1353 // iteration dimension of the collapsed op.
1354 resultExprs.push_back(
1355 Elt: getAffineDimExpr(position: origOpToCollapsedOpMapping[dim].first, context));
1356 }
1357 return AffineMap::get(dimCount: collapsingInfo.getCollapsedOpIterationRank(), symbolCount: 0,
1358 results: resultExprs, context);
1359}
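
// Illustrative example: for an original indexing map
//   (d0, d1, d2, d3) -> (d0, d1, d3)
// and collapsedOpToOrigOpMapping = [[0, 1], [2], [3]], only the leading dim of
// each folded group is kept, so the collapsed map (over three loops) is
//   (d0, d1, d2) -> (d0, d2)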
1360
1361/// Return the `reassociation` indices to use to collapse the operand when the
1362/// iteration space of a generic op is collapsed.
1363static SmallVector<ReassociationIndices>
1364getOperandReassociation(AffineMap indexingMap,
1365 const CollapsingInfo &collapsingInfo) {
1366 unsigned counter = 0;
1367 SmallVector<ReassociationIndices> operandReassociation;
1368 auto origOpToCollapsedOpMapping =
1369 collapsingInfo.getOrigOpToCollapsedOpMapping();
1370 auto collapsedOpToOrigOpMapping =
1371 collapsingInfo.getCollapsedOpToOrigOpMapping();
1372 while (counter < indexingMap.getNumResults()) {
1373 unsigned dim =
1374 cast<AffineDimExpr>(Val: indexingMap.getResult(idx: counter)).getPosition();
1375    // This is the start of a group of collapsed dimensions of the iteration
1376    // space that is guaranteed to be preserved in the indexing map. The number
1377    // of folded dims is obtained from the collapsed op to original op mapping.
1378 unsigned numFoldedDims =
1379 collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1380 .size();
1381 if (origOpToCollapsedOpMapping[dim].second == 0) {
1382 auto range = llvm::seq<unsigned>(Begin: counter, End: counter + numFoldedDims);
1383 operandReassociation.emplace_back(Args: range.begin(), Args: range.end());
1384 }
1385 counter += numFoldedDims;
1386 }
1387 return operandReassociation;
1388}
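
// Illustrative example: for an operand with indexing map
//   (d0, d1, d2, d3) -> (d2, d0, d1)
// and collapsedOpToOrigOpMapping = [[0, 1], [2], [3]], the operand
// reassociation is [[0], [1, 2]]: result 0 (d2) is left alone, while results 1
// and 2 (d0, d1) are collapsed together because their loops are folded into one.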
1389
1390/// Get the new value to use for a given `OpOperand` in the collapsed operation.
1391static Value getCollapsedOpOperand(Location loc, LinalgOp op,
1392 OpOperand *opOperand,
1393 const CollapsingInfo &collapsingInfo,
1394 OpBuilder &builder) {
1395 AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
1396 SmallVector<ReassociationIndices> operandReassociation =
1397 getOperandReassociation(indexingMap, collapsingInfo);
1398
1399  // If the number of entries in the reassociation for the operand is the same
1400  // as the number of results of the indexing map, then there is nothing to do
1401  // for this operand.
1402 Value operand = opOperand->get();
1403 if (operandReassociation.size() == indexingMap.getNumResults())
1404 return operand;
1405
1406 // Insert a reshape to collapse the dimensions.
1407 if (isa<MemRefType>(Val: operand.getType())) {
1408 return builder
1409 .create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
1410 .getResult();
1411 }
1412 return builder
1413 .create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
1414 .getResult();
1415}
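
// Illustrative sketch (hypothetical types): an operand of type tensor<?x4x?xf32>
// with operand reassociation [[0, 1], [2]] is collapsed by inserting
//   %collapsed = tensor.collapse_shape %operand [[0, 1], [2]]
//       : tensor<?x4x?xf32> into tensor<?x?xf32>
// memref operands get a memref.collapse_shape instead.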
1416
1417/// Rewrite the `linalg.index` operations in the original generic op to compute
1418/// their values in terms of the collapsed operation's iteration space.
1419void generateCollapsedIndexingRegion(Location loc, Block *block,
1420 const CollapsingInfo &collapsingInfo,
1421 ValueRange loopRange,
1422 RewriterBase &rewriter) {
1423 OpBuilder::InsertionGuard g(rewriter);
1424 rewriter.setInsertionPointToStart(block);
1425
1426 // Collect all the original index ops.
1427 auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
1428
1429  // For each folded dimension list, resolve the original induction variable
1430  // values in terms of the folded dimension's induction variable:
1431  //   i_{folded} = (i_0 * d_1 + i_1) * d_2 + i_2
1432  // which can be inverted to
1433  //   i_2 = i_{folded} % d_2
1434  //   i_1 = (i_{folded} / d_2) % d_1
1435  //   i_0 = i_{folded} / (d_1 * d_2)
1436 llvm::DenseMap<unsigned, Value> indexReplacementVals;
1437 for (auto foldedDims :
1438 enumerate(First: collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1439 ReassociationIndicesRef foldedDimsRef(foldedDims.value());
1440 Value newIndexVal =
1441 rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
1442 for (auto dim : llvm::reverse(C: foldedDimsRef.drop_front())) {
1443 indexReplacementVals[dim] =
1444 rewriter.create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
1445 newIndexVal =
1446 rewriter.create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
1447 }
1448 indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1449 }
1450
1451 for (auto indexOp : indexOps) {
1452 auto dim = indexOp.getDim();
1453 rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
1454 }
1455}
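
// Illustrative sketch of the generated delinearization (value names are
// hypothetical): for a folded group [0, 1, 2] with loop sizes %d0, %d1, %d2 the
// block gains
//   %folded = linalg.index 0 : index
//   %i2 = arith.remui %folded, %d2 : index
//   %q  = arith.divui %folded, %d2 : index
//   %i1 = arith.remui %q, %d1 : index
//   %i0 = arith.divui %q, %d1 : index
// and every original `linalg.index <dim>` is replaced by the matching %i<dim>.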
1456
1457void collapseOperandsAndResults(LinalgOp op,
1458 const CollapsingInfo &collapsingInfo,
1459 RewriterBase &rewriter,
1460 SmallVectorImpl<Value> &inputOperands,
1461 SmallVectorImpl<Value> &outputOperands,
1462 SmallVectorImpl<Type> &resultTypes) {
1463 Location loc = op->getLoc();
1464 inputOperands =
1465 llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
1466 return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1467 rewriter);
1468 });
1469
1470 // Get the output operands and result types.
1471 resultTypes.reserve(N: op.getNumDpsInits());
1472 outputOperands.reserve(N: op.getNumDpsInits());
1473 for (OpOperand &output : op.getDpsInitsMutable()) {
1474 Value newOutput =
1475 getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
1476 outputOperands.push_back(newOutput);
1477 // If the op has "buffer semantics", then the init operands are ranked
1478 // memrefs and the op has no results.
1479 if (!op.hasPureBufferSemantics())
1480 resultTypes.push_back(newOutput.getType());
1481 }
1482}
1483
1484/// Clone a `LinalgOp` to a collapsed version of the same name.
1485template <typename OpTy>
1486OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
1487 const CollapsingInfo &collapsingInfo) {
1488 return nullptr;
1489}
1490
1491/// Collapse any `LinalgOp` that does not require any specialization such as
1492/// indexing_maps, iterator_types, etc.
1493template <>
1494LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
1495 const CollapsingInfo &collapsingInfo) {
1496 SmallVector<Value> inputOperands, outputOperands;
1497 SmallVector<Type> resultTypes;
1498 collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1499 outputOperands, resultTypes);
1500
1501 return clone(
1502 rewriter, origOp, resultTypes,
1503 llvm::to_vector(Range: llvm::concat<Value>(Ranges&: inputOperands, Ranges&: outputOperands)));
1504}
1505
1506/// Collapse a `GenericOp`
1507template <>
1508GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
1509 GenericOp origOp,
1510 const CollapsingInfo &collapsingInfo) {
1511 SmallVector<Value> inputOperands, outputOperands;
1512 SmallVector<Type> resultTypes;
1513 collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1514 outputOperands, resultTypes);
1515 SmallVector<AffineMap> indexingMaps(
1516 llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
1517 return getCollapsedOpIndexingMap(map, collapsingInfo);
1518 }));
1519
1520 SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
1521 origOp.getIteratorTypesArray(), collapsingInfo));
1522
1523 GenericOp collapsedOp = rewriter.create<linalg::GenericOp>(
1524 origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps,
1525 iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
1526 Block *origOpBlock = &origOp->getRegion(0).front();
1527 Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
1528 rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
1529 collapsedOpBlock->getArguments());
1530 return collapsedOp;
1531}
1532
1533LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
1534 RewriterBase &rewriter) {
1535 if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
1536 return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
1537 } else {
1538 return cloneToCollapsedOp(rewriter, op, collapsingInfo);
1539 }
1540}
1541
1542/// Implementation of fusion with reshape operation by collapsing dimensions.
1543FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
1544 LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
1545 RewriterBase &rewriter) {
1546 // Bail on trivial no-op cases.
1547 if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1548 llvm::all_of(Range&: foldedIterationDims, P: [](ReassociationIndicesRef foldedDims) {
1549 return foldedDims.size() <= 1;
1550 }))
1551 return failure();
1552
1553 bool hasPureBufferSemantics = op.hasPureBufferSemantics();
1554 if (hasPureBufferSemantics &&
1555 !llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
1556 MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
1557 if (!memRefToCollapse)
1558 return true;
1559
1560 return memref::CollapseShapeOp::isGuaranteedCollapsible(
1561 memRefToCollapse, foldedIterationDims);
1562 }))
1563 return rewriter.notifyMatchFailure(op,
1564 "memref is not guaranteed collapsible");
1565
1566 CollapsingInfo collapsingInfo;
1567 if (failed(
1568 collapsingInfo.initialize(origNumLoops: op.getNumLoops(), foldedIterationDims))) {
1569 return rewriter.notifyMatchFailure(
1570 op, "illegal to collapse specified dimensions");
1571 }
1572
1573 // Bail on non-canonical ranges.
1574 SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
1575 auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
1576 if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1577 return cast<IntegerAttr>(attr).getInt() == value;
1578 llvm::APInt actual;
1579 return matchPattern(ofr.get<Value>(), m_ConstantInt(&actual)) &&
1580 actual.getSExtValue() == value;
1581 };
1582 if (!llvm::all_of(Range&: loopRanges, P: [&](Range range) {
1583 return opFoldIsConstantValue(range.offset, 0) &&
1584 opFoldIsConstantValue(range.stride, 1);
1585 })) {
1586 return rewriter.notifyMatchFailure(
1587 op, "expected all loop ranges to have zero start and unit stride");
1588 }
1589
1590 LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
1591
1592 Location loc = op->getLoc();
1593 if (collapsedOp.hasIndexSemantics()) {
1594    // Collect the loop bounds of the op as index values.
1595 OpBuilder::InsertionGuard g(rewriter);
1596 rewriter.setInsertionPoint(collapsedOp);
1597 SmallVector<Value> loopBound =
1598 llvm::map_to_vector(loopRanges, [&](Range range) {
1599 return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
1600 });
1601 generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
1602 collapsingInfo, loopBound, rewriter);
1603 }
1604
1605  // Insert expanding reshapes for the results to get back the original result
1606  // types.
1607 SmallVector<Value> results;
1608 for (const auto &originalResult : llvm::enumerate(op->getResults())) {
1609 Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
1610 auto originalResultType =
1611 cast<ShapedType>(originalResult.value().getType());
1612 auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
1613 if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1614 AffineMap indexingMap =
1615 op.getIndexingMapMatchingResult(originalResult.value());
1616 SmallVector<ReassociationIndices> reassociation =
1617 getOperandReassociation(indexingMap, collapsingInfo);
1618 if (isa<MemRefType>(collapsedOpResult.getType())) {
1619 Value result = rewriter.create<memref::ExpandShapeOp>(
1620 loc, originalResultType, collapsedOpResult, reassociation);
1621 results.push_back(result);
1622 } else {
1623 Value result = rewriter.create<tensor::ExpandShapeOp>(
1624 loc, originalResultType, collapsedOpResult, reassociation);
1625 results.push_back(result);
1626 }
1627 } else {
1628 results.push_back(collapsedOpResult);
1629 }
1630 }
1631 return CollapseResult{results, collapsedOp};
1632}
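
// Illustrative usage sketch (hypothetical caller, mirroring the patterns
// below): collapse loops 0 and 1 of `op` from inside a rewrite.
//
//   SmallVector<ReassociationIndices> foldedDims = {{0, 1}};
//   FailureOr<CollapseResult> collapsed =
//       collapseOpIterationDims(op, foldedDims, rewriter);
//   if (failed(collapsed))
//     return failure();
//   rewriter.replaceOp(op, collapsed->results);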
1633
1634namespace {
1635
1636/// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
1637/// collapsing dimensions of the iteration space.
1638class FoldWithProducerReshapeOpByCollapsing
1639 : public OpRewritePattern<GenericOp> {
1640public:
1641 FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
1642 ControlFusionFn foldReshapes,
1643 PatternBenefit benefit = 1)
1644 : OpRewritePattern<GenericOp>(context, benefit),
1645 controlFoldingReshapes(std::move(foldReshapes)) {}
1646
1647 LogicalResult matchAndRewrite(GenericOp genericOp,
1648 PatternRewriter &rewriter) const override {
1649 for (OpOperand &opOperand : genericOp->getOpOperands()) {
1650 tensor::ExpandShapeOp reshapeOp =
1651 opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
1652 if (!reshapeOp)
1653 continue;
1654
1655 SmallVector<ReassociationIndices> collapsableIterationDims =
1656 getCollapsableIterationSpaceDims(genericOp, &opOperand,
1657 reshapeOp.getReassociationIndices());
1658 if (collapsableIterationDims.empty() ||
1659 !controlFoldingReshapes(&opOperand)) {
1660 continue;
1661 }
1662
1663 std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
1664 genericOp, collapsableIterationDims, rewriter);
1665 if (!collapseResult) {
1666 return rewriter.notifyMatchFailure(
1667 genericOp, "failed to do the fusion by collapsing transformation");
1668 }
1669
1670 rewriter.replaceOp(genericOp, collapseResult->results);
1671 return success();
1672 }
1673 return failure();
1674 }
1675
1676private:
1677 ControlFusionFn controlFoldingReshapes;
1678};
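
// Illustrative sketch of the rewrite above (payload elided, shapes and names
// hypothetical, assembly abbreviated):
//
//   %e = tensor.expand_shape %a [[0, 1], [2]]
//       : tensor<?x?xf32> into tensor<?x8x?xf32>
//   %r = linalg.generic ... ins(%e, ...) outs(...) ...
//
// becomes a linalg.generic that consumes %a directly, with the loops indexing
// the expanded dimensions collapsed back into one, followed by a
// tensor.expand_shape on any result whose rank changed.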
1679
1680/// Pattern to collapse dimensions.
1681template <typename LinalgType>
1682class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
1683public:
1684 CollapseLinalgDimensions(MLIRContext *context,
1685 GetCollapsableDimensionsFn collapseDimensions,
1686 PatternBenefit benefit = 1)
1687 : OpRewritePattern<LinalgType>(context, benefit),
1688 controlCollapseDimension(std::move(collapseDimensions)) {}
1689
1690 LogicalResult matchAndRewrite(LinalgType op,
1691 PatternRewriter &rewriter) const override {
1692 SmallVector<ReassociationIndices> collapsableIterationDims =
1693 controlCollapseDimension(op);
1694 if (collapsableIterationDims.empty())
1695 return failure();
1696
1697 // Check if the specified list of dimensions to collapse is a valid list.
1698 if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
1699 collapsableIterationDims)) {
1700 return rewriter.notifyMatchFailure(
1701 op, "specified dimensions cannot be collapsed");
1702 }
1703
1704 std::optional<CollapseResult> collapseResult =
1705 collapseOpIterationDims(op, collapsableIterationDims, rewriter);
1706 if (!collapseResult) {
1707 return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
1708 }
1709 rewriter.replaceOp(op, collapseResult->results);
1710 return success();
1711 }
1712
1713private:
1714 GetCollapsableDimensionsFn controlCollapseDimension;
1715};
1716
1717} // namespace
1718
1719//===---------------------------------------------------------------------===//
1720// Methods and patterns that fuse constants with linalg.generic operations.
1721//===---------------------------------------------------------------------===//
1722
1723namespace {
1724/// Pattern to fold a generic op with a splat constant/scalar constant. Does not
1725/// handle cases where the constant is not single-valued.
1726class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
1727public:
1728 FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
1729 : OpRewritePattern<GenericOp>(context, benefit) {}
1730
1731 LogicalResult matchAndRewrite(GenericOp genericOp,
1732 PatternRewriter &rewriter) const override {
1733 if (!genericOp.hasPureTensorSemantics())
1734 return failure();
1735 for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
1736 Operation *def = opOperand->get().getDefiningOp();
1737 TypedAttr constantAttr;
1738 auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
1739 {
1740 DenseElementsAttr splatAttr;
1741 if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
1742 splatAttr.isSplat() &&
1743 splatAttr.getType().getElementType().isIntOrFloat()) {
1744 constantAttr = splatAttr.getSplatValue<TypedAttr>();
1745 return true;
1746 }
1747 }
1748 {
1749 IntegerAttr intAttr;
1750 if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
1751 constantAttr = intAttr;
1752 return true;
1753 }
1754 }
1755 {
1756 FloatAttr floatAttr;
1757 if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
1758 constantAttr = floatAttr;
1759 return true;
1760 }
1761 }
1762 return false;
1763 };
1764
1765 auto resultValue = dyn_cast<OpResult>(opOperand->get());
1766 if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
1767 continue;
1768
1769      // The operands and the indexing_maps of the fused operation are the same
1770      // as the operands and indexing_maps of the generic operation, with the
1771      // entries at the constant operand's index dropped.
1772 SmallVector<AffineMap> fusedIndexMaps;
1773 SmallVector<Value> fusedOperands;
1774 SmallVector<Location> fusedLocs{genericOp.getLoc()};
1775 fusedIndexMaps.reserve(genericOp->getNumOperands());
1776 fusedOperands.reserve(genericOp.getNumDpsInputs());
1777 fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
1778 for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
1779 if (inputOperand == opOperand)
1780 continue;
1781 Value inputValue = inputOperand->get();
1782 fusedIndexMaps.push_back(
1783 genericOp.getMatchingIndexingMap(inputOperand));
1784 fusedOperands.push_back(inputValue);
1785 fusedLocs.push_back(inputValue.getLoc());
1786 }
1787 for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
1788 fusedIndexMaps.push_back(
1789 genericOp.getMatchingIndexingMap(&outputOperand));
1790
1791 // Check if the operation shapes to loops map is computable.
1792 if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
1793 return rewriter.notifyMatchFailure(
1794 genericOp, "fused op loop bound computation failed");
1795 }
1796
1797 // Create a constant scalar value from the splat constant.
1798 Value scalarConstant =
1799 rewriter.create<arith::ConstantOp>(def->getLoc(), constantAttr);
1800
1801 SmallVector<Value> outputOperands = genericOp.getOutputs();
1802 auto fusedOp = rewriter.create<GenericOp>(
1803 rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
1804 /*inputs=*/fusedOperands,
1805 /*outputs=*/outputOperands,
1806 rewriter.getAffineMapArrayAttr(fusedIndexMaps),
1807 genericOp.getIteratorTypes(),
1808 /*doc=*/nullptr,
1809 /*library_call=*/nullptr);
1810
1811 // Map the block argument corresponding to the replaced argument with the
1812 // scalar constant.
1813 Region &region = genericOp->getRegion(0);
1814 Block &entryBlock = *region.begin();
1815 IRMapping mapping;
1816 mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
1817 scalarConstant);
1818 Region &fusedRegion = fusedOp->getRegion(0);
1819 rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
1820 mapping);
1821 rewriter.replaceOp(genericOp, fusedOp->getResults());
1822 return success();
1823 }
1824 return failure();
1825 }
1826};
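
// Illustrative sketch: an input produced by a splat constant such as
//   %cst = arith.constant dense<1.000000e+00> : tensor<4x8xf32>
// is dropped from the fused op's operands; the block argument that used to
// carry it is remapped to a scalar `arith.constant 1.000000e+00 : f32`
// materialized at the constant's location, so the payload reads the scalar
// directly.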
1827
1828} // namespace
1829
1830//===---------------------------------------------------------------------===//
1831// Miscellaneous patterns that help fusion.
1832//===---------------------------------------------------------------------===//
1833
1834namespace {
1835/// Forces `outs` operands of linalg operations to use `tensor.empty` if the
1836/// value of the `outs` operand is not used within the op. This is only
1837/// implemented for `linalg.generic` operations for now, but should hold for all
1838/// linalg structured ops.
1839struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
1840 using OpRewritePattern<GenericOp>::OpRewritePattern;
1841
1842 LogicalResult matchAndRewrite(GenericOp op,
1843 PatternRewriter &rewriter) const override {
1844 rewriter.startOpModification(op: op);
1845 bool modifiedOutput = false;
1846 Location loc = op.getLoc();
1847 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
1848 if (!op.payloadUsesValueFromOperand(&opOperand)) {
1849 Value operandVal = opOperand.get();
1850 auto operandType = dyn_cast<RankedTensorType>(operandVal.getType());
1851 if (!operandType)
1852 continue;
1853
1854 // If outs is sparse, leave it to the sparsifier.
1855 if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
1856 continue;
1857
1858 // If outs is already an `empty` operation, nothing to do.
1859 auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
1860 if (definingOp)
1861 continue;
1862 modifiedOutput = true;
1863 SmallVector<OpFoldResult> mixedSizes =
1864 tensor::getMixedSizes(rewriter, loc, operandVal);
1865 Value emptyTensor = rewriter.create<tensor::EmptyOp>(
1866 loc, mixedSizes, operandType.getElementType());
1867 op->setOperand(opOperand.getOperandNumber(), emptyTensor);
1868 }
1869 }
1870 if (!modifiedOutput) {
1871 rewriter.cancelOpModification(op: op);
1872 return failure();
1873 }
1874 rewriter.finalizeOpModification(op: op);
1875 return success();
1876 }
1877};
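
// Illustrative sketch: if a generic op's payload never reads the block argument
// tied to an `outs` tensor, that operand is swapped for a fresh
//   %empty = tensor.empty(%d0, %d1) : tensor<?x?xf32>
// of matching shape (sizes hypothetical here), removing a false dependence on
// the previous value of the init tensor and unblocking further fusion.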
1878
1879/// Fold linalg.fill into linalg.generic
1880struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
1881 using OpRewritePattern<GenericOp>::OpRewritePattern;
1882
1883 LogicalResult matchAndRewrite(GenericOp genericOp,
1884 PatternRewriter &rewriter) const override {
1885 if (!genericOp.hasPureTensorSemantics())
1886 return failure();
1887 bool fillFound = false;
1888 Block &payload = genericOp.getRegion().front();
1889 for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
1890 if (!genericOp.payloadUsesValueFromOperand(opOperand))
1891 continue;
1892 FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
1893 if (!fillOp)
1894 continue;
1895 fillFound = true;
1896 Value fillVal = fillOp.value();
1897 auto resultType =
1898 cast<RankedTensorType>(fillOp.result().getType()).getElementType();
1899 Value convertedVal =
1900 convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
1901 /*isUnsignedCast =*/false);
1902 rewriter.replaceAllUsesWith(
1903 payload.getArgument(opOperand->getOperandNumber()), convertedVal);
1904 }
1905 return success(isSuccess: fillFound);
1906 }
1907};
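
// Illustrative sketch: for an input defined by
//   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>)
//       -> tensor<?x?xf32>
// the corresponding payload argument is replaced by %cst (converted to the fill
// result's element type if needed); the now-unused %fill operand can then be
// dropped by the dead-operand cleanup patterns registered below.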
1908} // namespace
1909
1910void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
1911 RewritePatternSet &patterns,
1912 const ControlFusionFn &controlFoldingReshapes) {
1913 patterns.add<FoldReshapeWithGenericOpByExpansion>(arg: patterns.getContext(),
1914 args: controlFoldingReshapes);
1915 patterns.add<FoldWithProducerReshapeOpByExpansion>(arg: patterns.getContext(),
1916 args: controlFoldingReshapes);
1917}
1918
1919void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
1920 RewritePatternSet &patterns,
1921 const ControlFusionFn &controlFoldingReshapes) {
1922 patterns.add<FoldWithProducerReshapeOpByCollapsing>(arg: patterns.getContext(),
1923 args: controlFoldingReshapes);
1924}
1925
1926void mlir::linalg::populateElementwiseOpsFusionPatterns(
1927 RewritePatternSet &patterns,
1928 const ControlFusionFn &controlElementwiseOpsFusion) {
1929 auto *context = patterns.getContext();
1930 patterns.add<FuseElementwiseOps>(arg&: context, args: controlElementwiseOpsFusion);
1931 patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
1932 RemoveOutsDependency>(arg&: context);
1933 // Add the patterns that clean up dead operands and results.
1934 populateEraseUnusedOperandsAndResultsPatterns(patterns);
1935}
1936
1937void mlir::linalg::populateCollapseDimensions(
1938 RewritePatternSet &patterns,
1939 const GetCollapsableDimensionsFn &controlCollapseDimensions) {
1940 patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
1941 CollapseLinalgDimensions<linalg::CopyOp>>(
1942 patterns.getContext(), controlCollapseDimensions);
1943}
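
// Illustrative usage sketch (hypothetical `ctx` and `funcOp`): a control
// callback that always requests collapsing loops 0 and 1 of any op it is
// offered.
//
//   RewritePatternSet patterns(ctx);
//   GetCollapsableDimensionsFn control = [](linalg::LinalgOp op) {
//     return SmallVector<ReassociationIndices>{ReassociationIndices{0, 1}};
//   };
//   linalg::populateCollapseDimensions(patterns, control);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));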
1944
1945//===---------------------------------------------------------------------===//
1946// Passes
1947//===---------------------------------------------------------------------===//
1948
1949namespace {
1950
1951/// Pass that fuses generic ops on tensors. Used only for testing.
1952// TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
1953// patterns added here heavily depends on the cost function used. Having an
1954// opinionated pass of this form is not recommended. Deprecate this pass in
1955// favor of test passes that check the functionality of each of the patterns
1956// added here individually.
1957struct LinalgElementwiseOpFusionPass
1958 : public impl::LinalgElementwiseOpFusionPassBase<
1959 LinalgElementwiseOpFusionPass> {
1960 using impl::LinalgElementwiseOpFusionPassBase<
1961 LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
1962 void runOnOperation() override {
1963 Operation *op = getOperation();
1964 MLIRContext *context = op->getContext();
1965 RewritePatternSet patterns(context);
1966
1967 // Add folding with reshape by expansion patterns.
1968 ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
1969 Operation *producer = fusedOperand->get().getDefiningOp();
1970 return producer && producer->hasOneUse();
1971 };
1972
1973 // Add elementwise op fusion patterns.
1974 populateElementwiseOpsFusionPatterns(patterns, controlElementwiseOpsFusion: defaultControlFn);
1975 populateFoldReshapeOpsByExpansionPatterns(patterns, controlFoldingReshapes: defaultControlFn);
1976
1977 // General canonicalization patterns.
1978 affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
1979 GenericOp::getCanonicalizationPatterns(patterns, context);
1980 tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
1981 tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
1982 context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
1983 patterns);
1984
1985 // Add constant folding patterns.
1986 populateConstantFoldLinalgOperations(patterns, controlFn: defaultControlFn);
1987
1988    // Use top-down traversal for compile-time reasons.
1989 GreedyRewriteConfig grc;
1990 grc.useTopDownTraversal = true;
1991 (void)applyPatternsAndFoldGreedily(op, std::move(patterns), grc);
1992 }
1993};
1994
1995} // namespace
1996

source code of mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp