1//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns/pass to remove usage of unit-extent dimensions
10// to specify broadcasting in favor of more canonical representation of the
11// computation
12//
13//===----------------------------------------------------------------------===//
14
15#include "mlir/Dialect/Linalg/Passes.h"
16
17#include "mlir/Dialect/Affine/IR/AffineOps.h"
18#include "mlir/Dialect/Arith/IR/Arith.h"
19#include "mlir/Dialect/Linalg/IR/Linalg.h"
20#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
21#include "mlir/Dialect/Linalg/Utils/Utils.h"
22#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
23#include "mlir/Dialect/Tensor/IR/Tensor.h"
24#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
25#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
26#include "mlir/IR/AffineExpr.h"
27#include "mlir/IR/AffineMap.h"
28#include "mlir/IR/BuiltinTypes.h"
29#include "mlir/Transforms/FoldUtils.h"
30#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31#include "llvm/Support/Debug.h"
32
33namespace mlir {
34#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS
35#include "mlir/Dialect/Linalg/Passes.h.inc"
36} // namespace mlir
37
38#define DEBUG_TYPE "linalg-drop-unit-dims"
39
40using namespace mlir;
41using namespace mlir::linalg;
42
43namespace {
44/// Pattern to move init operands to ins when all the loops are parallel and
45/// blockArgument corresponding to init is used in the region. This is a fix-up
46/// when unit reduction dimensions are all folded away. In this context, it
/// becomes an elementwise generic op. E.g., it converts
48///
49/// %0 = tensor.empty() : tensor<1x1xf32>
50/// %1 = linalg.fill
51/// ins(%cst : f32)
52/// outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
53/// %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
54/// affine_map<(d0) -> (0, d0)>],
55/// iterator_types = ["parallel"]}
56/// ins(%arg0 : tensor<1x?x1x1xf32>)
57/// outs(%1 : tensor<1x1xf32>) {
58/// ^bb0(%in: f32, %out: f32):
59/// %3 = arith.addf %in, %out : f32
60/// linalg.yield %3 : f32
61/// } -> tensor<1x1xf32>
62///
63/// into
64///
65/// %0 = tensor.empty() : tensor<1x1xf32>
66/// %1 = linalg.fill
67/// ins(%cst : f32)
68/// outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
69/// %2 = tensor.empty() : tensor<1x1xf32>
70/// %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
71/// affine_map<(d0) -> (0, d0)>,
72/// affine_map<(d0) -> (0, d0)>],
73/// iterator_types = ["parallel"]}
74/// ins(%arg0, %1 : tensor<1x?x1x1xf32>, tensor<1x1xf32>)
75/// outs(%2 : tensor<1x1xf32>) {
76/// ^bb0(%in: f32, %in_0: f32, %out: f32):
77/// %4 = arith.addf %in, %in_0 : f32
78/// linalg.yield %4 : f32
79/// } -> tensor<1x1xf32>
struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Only applies to tensor-semantics ops where every loop is parallel, i.e.
    // the op is effectively elementwise (unit reduction dims already folded).
    if (!genericOp.hasPureTensorSemantics())
      return failure();
    if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
      return failure();

    // Collect init operands whose matching block argument is actually read in
    // the region: those carry a value and must be moved to the input side.
    auto outputOperands = genericOp.getDpsInitsMutable();
    SetVector<OpOperand *> candidates;
    for (OpOperand &op : outputOperands) {
      if (genericOp.getMatchingBlockArgument(opOperand: &op).use_empty())
        continue;
      candidates.insert(X: &op);
    }

    if (candidates.empty())
      return failure();

    // Compute the modified indexing maps: original input maps, then the maps
    // of the moved inits (appended as new inputs), then the init maps.
    int64_t origNumInput = genericOp.getNumDpsInputs();
    SmallVector<Value> newInputOperands = genericOp.getDpsInputs();
    SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
    SmallVector<AffineMap> newIndexingMaps;
    newIndexingMaps.append(in_start: indexingMaps.begin(),
                           in_end: std::next(x: indexingMaps.begin(), n: origNumInput));
    for (OpOperand *op : candidates) {
      newInputOperands.push_back(Elt: op->get());
      newIndexingMaps.push_back(Elt: genericOp.getMatchingIndexingMap(opOperand: op));
    }
    newIndexingMaps.append(in_start: std::next(x: indexingMaps.begin(), n: origNumInput),
                           in_end: indexingMaps.end());

    // Each moved init is replaced by a fresh tensor.empty() destination of the
    // same mixed sizes, since its previous value is now consumed as an input.
    Location loc = genericOp.getLoc();
    SmallVector<Value> newOutputOperands =
        llvm::to_vector(Range: genericOp.getDpsInits());
    for (OpOperand *op : candidates) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointAfterValue(op->get());
      auto elemType = cast<ShapedType>(Val: op->get().getType()).getElementType();
      auto empty = rewriter.create<tensor::EmptyOp>(
          location: loc, args: tensor::getMixedSizes(builder&: rewriter, loc, value: op->get()), args&: elemType);

      unsigned start = genericOp.getDpsInits().getBeginOperandIndex();
      newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
    }

    auto newOp = rewriter.create<GenericOp>(
        location: loc, args: genericOp.getResultTypes(), args&: newInputOperands, args&: newOutputOperands,
        args&: newIndexingMaps, args: genericOp.getIteratorTypesArray(),
        /*bodyBuild=*/args: nullptr, args: linalg::getPrunedAttributeList(op: genericOp));

    // Rebuild the body block. New bbarg order is: original inputs, moved
    // inits (as inputs), then outputs. Moved-init bbargs are remapped to
    // their new input bbargs; their output bbargs become unused placeholders.
    OpBuilder::InsertionGuard guard(rewriter);
    Region &region = newOp.getRegion();
    Block *block = rewriter.createBlock(parent: &region);
    IRMapping mapper;
    for (auto bbarg : genericOp.getRegionInputArgs())
      mapper.map(from: bbarg, to: block->addArgument(type: bbarg.getType(), loc));

    for (OpOperand *op : candidates) {
      BlockArgument bbarg = genericOp.getMatchingBlockArgument(opOperand: op);
      mapper.map(from: bbarg, to: block->addArgument(type: bbarg.getType(), loc));
    }

    for (OpOperand &op : outputOperands) {
      BlockArgument bbarg = genericOp.getMatchingBlockArgument(opOperand: &op);
      if (candidates.count(key: &op))
        block->addArgument(type: bbarg.getType(), loc);
      else
        mapper.map(from: bbarg, to: block->addArgument(type: bbarg.getType(), loc));
    }

    // Clone the original payload into the new block through the mapping.
    for (auto &op : genericOp.getBody()->getOperations()) {
      rewriter.clone(op, mapper);
    }
    rewriter.replaceOp(op: genericOp, newValues: newOp.getResults());

    return success();
  }
};
161} // namespace
162
163//===---------------------------------------------------------------------===//
164// Drop loops that are unit-extents within Linalg operations.
165//===---------------------------------------------------------------------===//
166
167/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
168/// broadcasting. For example,
169///
170/// ```mlir
171/// #accesses = [
172/// affine_map<(d0, d1) -> (0, d1)>,
173/// affine_map<(d0, d1) -> (d0, 0)>,
174/// affine_map<(d0, d1) -> (d0, d1)>
175/// ]
176///
177/// #trait = {
178/// indexing_maps = #accesses,
179/// iterator_types = ["parallel", "parallel"],
180/// library_call = "some_external_fn"
181/// }
182///
183/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
184/// tensor<5x5xf32>
185/// {
186/// %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
187/// tensor<5xf32> into tensor<1x5xf32>
188/// %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
189/// tensor<5xf32> into tensor<5x1xf32>
190/// %2 = linalg.generic #trait %0, %1 {
191/// ^bb0(%arg2: f32, %arg3: f32):
192/// %3 = arith.addf %arg2, %arg3 : f32
193/// linalg.yield %3 : f32
194/// } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
195/// return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
199///
200/// ```mlir
201/// #accesses = [
202/// affine_map<(d0, d1) -> (d1)>,
203/// affine_map<(d0, d1) -> (d0)>,
204/// affine_map<(d0, d1) -> (d0, d1)>
205/// ]
206///
207/// #trait = {
208/// indexing_maps = #accesses,
209/// iterator_types = ["parallel", "parallel"],
210/// library_call = "some_external_fn"
211/// }
212///
213/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
214/// tensor<5x5xf32>
215/// {
216/// %0 = linalg.generic #trait %arg0, %arg1 {
217/// ^bb0(%arg2: f32, %arg3: f32):
218/// %3 = arith.addf %arg2, %arg3 : f32
219/// linalg.yield %3 : f32
220/// } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
221/// return %0 : tensor<5x5xf32>
/// }
/// ```
223
224/// Update the index accesses of linalg operations having index semantics.
225static void
226replaceUnitDimIndexOps(GenericOp genericOp,
227 const llvm::SmallDenseSet<unsigned> &unitDims,
228 RewriterBase &rewriter) {
229 for (IndexOp indexOp :
230 llvm::make_early_inc_range(Range: genericOp.getBody()->getOps<IndexOp>())) {
231 OpBuilder::InsertionGuard guard(rewriter);
232 rewriter.setInsertionPoint(indexOp);
233 if (unitDims.count(V: indexOp.getDim()) != 0) {
234 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op: indexOp, args: 0);
235 } else {
236 // Update the dimension of the index operation if needed.
237 unsigned droppedDims = llvm::count_if(
238 Range: unitDims, P: [&](unsigned dim) { return dim < indexOp.getDim(); });
239 if (droppedDims != 0)
240 rewriter.replaceOpWithNewOp<IndexOp>(op: indexOp,
241 args: indexOp.getDim() - droppedDims);
242 }
243 }
244}
245
246/// Expand the given `value` so that the type matches the type of `origDest`.
247/// The `reassociation` is used when `rankReductionStrategy` is set to
248/// `RankReductionStrategy::ReassociativeReshape`.
249static Value
250expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
251 ArrayRef<ReassociationIndices> reassociation,
252 ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
253 // There are no results for memref outputs.
254 auto origResultType = cast<RankedTensorType>(Val: origDest.getType());
255 if (rankReductionStrategy ==
256 ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
257 unsigned rank = origResultType.getRank();
258 SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(value: 0));
259 SmallVector<OpFoldResult> sizes =
260 tensor::getMixedSizes(builder&: rewriter, loc, value: origDest);
261 SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(value: 1));
262 return rewriter.createOrFold<tensor::InsertSliceOp>(
263 location: loc, args&: result, args&: origDest, args&: offsets, args&: sizes, args&: strides);
264 }
265
266 assert(rankReductionStrategy ==
267 ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
268 "unknown rank reduction strategy");
269 return rewriter
270 .create<tensor::ExpandShapeOp>(location: loc, args&: origResultType, args&: result, args&: reassociation)
271 .getResult();
272}
273
274/// Collapse the given `value` so that the type matches the type of
275/// `origOutput`. The `reassociation` is used when `rankReductionStrategy` is
276/// set to `RankReductionStrategy::ReassociativeReshape`.
277static Value collapseValue(
278 RewriterBase &rewriter, Location loc, Value operand,
279 ArrayRef<int64_t> targetShape, ArrayRef<ReassociationIndices> reassociation,
280 ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
281 if (auto memrefType = dyn_cast<MemRefType>(Val: operand.getType())) {
282 if (rankReductionStrategy ==
283 ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
284 FailureOr<Value> rankReducingExtract =
285 memref::SubViewOp::rankReduceIfNeeded(b&: rewriter, loc, value: operand,
286 desiredShape: targetShape);
287 assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
288 return *rankReducingExtract;
289 }
290
291 assert(
292 rankReductionStrategy ==
293 ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
294 "unknown rank reduction strategy");
295 MemRefLayoutAttrInterface layout;
296 auto targetType = MemRefType::get(shape: targetShape, elementType: memrefType.getElementType(),
297 layout, memorySpace: memrefType.getMemorySpace());
298 return rewriter.create<memref::CollapseShapeOp>(location: loc, args&: targetType, args&: operand,
299 args&: reassociation);
300 }
301 if (auto tensorType = dyn_cast<RankedTensorType>(Val: operand.getType())) {
302 if (rankReductionStrategy ==
303 ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
304 FailureOr<Value> rankReducingExtract =
305 tensor::ExtractSliceOp::rankReduceIfNeeded(b&: rewriter, loc, value: operand,
306 desiredShape: targetShape);
307 assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
308 return *rankReducingExtract;
309 }
310
311 assert(
312 rankReductionStrategy ==
313 ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
314 "unknown rank reduction strategy");
315 auto targetType =
316 RankedTensorType::get(shape: targetShape, elementType: tensorType.getElementType());
317 return rewriter.create<tensor::CollapseShapeOp>(location: loc, args&: targetType, args&: operand,
318 args&: reassociation);
319 }
320 llvm_unreachable("unsupported operand type");
321}
322
/// Modified metadata computed for an operand of an operation whose unit
/// dimensions are being dropped: the new indexing map to use, the shape of
/// the operand in the replacement op, and the `reassociation` to use to go
/// from the original operand shape to the modified operand shape.
struct UnitExtentReplacementInfo {
  /// Indexing map of the operand in the replacement (rank-reduced) op.
  AffineMap indexMap;
  /// Reassociation groups mapping original dims onto the kept dims.
  SmallVector<ReassociationIndices> reassociation;
  /// Operand shape after the unit dimensions have been dropped.
  SmallVector<int64_t> targetShape;
};
/// Compute the replacement metadata for `opOperand` of `genericOp` when the
/// loops not present in `oldDimsToNewDimsMap` are dropped: the rank-reduced
/// indexing map (dims substituted via `dimReplacements`), the target shape,
/// and the reassociation grouping each kept dim with its trailing unit dims.
static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
    MLIRContext *context, GenericOp genericOp, OpOperand *opOperand,
    llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
    ArrayRef<AffineExpr> dimReplacements) {
  UnitExtentReplacementInfo info;
  ReassociationIndices reassociationGroup;
  SmallVector<AffineExpr> newIndexExprs;
  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
  ArrayRef<int64_t> operandShape = genericOp.getShape(opOperand);
  ArrayRef<AffineExpr> exprs = indexingMap.getResults();

  // A dimension is droppable if it has extent 1 and is accessed either by a
  // loop dimension that is itself being dropped, or by the constant 0.
  auto isUnitDim = [&](unsigned dim) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(Val: exprs[dim])) {
      unsigned oldPosition = dimExpr.getPosition();
      return !oldDimsToNewDimsMap.count(Val: oldPosition) &&
             (operandShape[dim] == 1);
    }
    // Handle the other case where the shape is 1, and is accessed using a
    // constant 0.
    if (operandShape[dim] == 1) {
      auto constAffineExpr = dyn_cast<AffineConstantExpr>(Val: exprs[dim]);
      return constAffineExpr && constAffineExpr.getValue() == 0;
    }
    return false;
  };

  // Leading unit dims are folded into the first kept dim's group.
  unsigned dim = 0;
  while (dim < operandShape.size() && isUnitDim(dim))
    reassociationGroup.push_back(Elt: dim++);
  while (dim < operandShape.size()) {
    assert(!isUnitDim(dim) && "expected non unit-extent");
    reassociationGroup.push_back(Elt: dim);
    AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
    newIndexExprs.push_back(Elt: newExpr);
    info.targetShape.push_back(Elt: operandShape[dim]);
    ++dim;
    // Fold all following dimensions that are unit-extent.
    while (dim < operandShape.size() && isUnitDim(dim)) {
      reassociationGroup.push_back(Elt: dim++);
    }
    info.reassociation.push_back(Elt: reassociationGroup);
    reassociationGroup.clear();
  }
  // Note: the new map is expressed in terms of the surviving loop dims only.
  info.indexMap =
      AffineMap::get(dimCount: oldDimsToNewDimsMap.size(), symbolCount: indexingMap.getNumSymbols(),
                     results: newIndexExprs, context);
  return info;
}
381
/// Drop unit-trip-count loop dimensions (and the matching unit dimensions of
/// the operands) from `genericOp`, restricted to the dimensions permitted by
/// `options.controlFn`. Returns the replacement op together with the values
/// that should replace the original results, or failure if nothing changed.
FailureOr<DropUnitDimsResult>
linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
                     const ControlDropUnitDims &options) {
  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
  if (indexingMaps.empty())
    return failure();

  // 1. Check if any of the iteration dimensions are unit-trip count. They will
  // end up being unit-trip count if they are used to index into a unit-dim
  // tensor/memref.
  AffineMap invertedMap =
      inversePermutation(map: concatAffineMaps(maps: indexingMaps, context: rewriter.getContext()));
  if (!invertedMap) {
    return rewriter.notifyMatchFailure(arg&: genericOp,
                                       msg: "invalid indexing maps for operation");
  }

  // Flattened list of all operand dimension sizes, in operand order; the
  // inverted map's results index into this list.
  SmallVector<int64_t> allShapesSizes;
  for (OpOperand &opOperand : genericOp->getOpOperands())
    llvm::append_range(C&: allShapesSizes, R: genericOp.getShape(opOperand: &opOperand));

  // 1a. Get the allowed list of dimensions to drop from the `options`.
  SmallVector<unsigned> allowedUnitDims = options.controlFn(genericOp);
  if (allowedUnitDims.empty()) {
    return rewriter.notifyMatchFailure(
        arg&: genericOp, msg: "control function returns no allowed unit dims to prune");
  }
  llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
                                               allowedUnitDims.end());
  llvm::SmallDenseSet<unsigned> unitDims;
  for (const auto &expr : enumerate(First: invertedMap.getResults())) {
    if (AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(Val: expr.value())) {
      if (allShapesSizes[dimExpr.getPosition()] == 1 &&
          unitDimsFilter.count(V: expr.index()))
        unitDims.insert(V: expr.index());
    }
  }

  // 2. Compute the iterator types of the modified op by dropping the one-trip
  // count loops. Dropped dims are replaced by constant 0 in indexing maps;
  // kept dims are renumbered densely via `oldDimToNewDimMap`.
  SmallVector<utils::IteratorType> newIteratorTypes;
  llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
  SmallVector<AffineExpr> dimReplacements;
  unsigned newDims = 0;
  for (auto [index, attr] :
       llvm::enumerate(First: genericOp.getIteratorTypesArray())) {
    if (unitDims.count(V: index)) {
      dimReplacements.push_back(
          Elt: getAffineConstantExpr(constant: 0, context: rewriter.getContext()));
    } else {
      newIteratorTypes.push_back(Elt: attr);
      oldDimToNewDimMap[index] = newDims;
      dimReplacements.push_back(
          Elt: getAffineDimExpr(position: newDims, context: rewriter.getContext()));
      newDims++;
    }
  }

  // 3. For each of the operands, find the
  // - modified affine map to use.
  // - shape of the operands after the unit-dims are dropped.
  // - the reassociation indices used to convert from the original
  //   operand type to modified operand (needed only when using reshapes
  //   for rank reduction strategy)
  // Note that the indexing maps might need changing even if there are no
  // unit dimensions that are dropped to handle cases where `0` is used to
  // access a unit-extent tensor. Consider moving this out of this specific
  // transformation as a stand-alone transformation. Kept here right now due
  // to legacy.
  SmallVector<AffineMap> newIndexingMaps;
  SmallVector<SmallVector<ReassociationIndices>> reassociations;
  SmallVector<SmallVector<int64_t>> targetShapes;
  SmallVector<bool> collapsed;
  // Operands with non-identity memref layouts or tensor encodings cannot be
  // reshaped/rank-reduced safely; their shapes are kept as-is.
  auto hasCollapsibleType = [](OpOperand &operand) {
    Type operandType = operand.get().getType();
    if (auto memrefOperandType = dyn_cast_or_null<MemRefType>(Val&: operandType)) {
      return memrefOperandType.getLayout().isIdentity();
    }
    if (auto tensorOperandType = dyn_cast<RankedTensorType>(Val&: operandType)) {
      return tensorOperandType.getEncoding() == nullptr;
    }
    return false;
  };
  for (OpOperand &opOperand : genericOp->getOpOperands()) {
    auto indexingMap = genericOp.getMatchingIndexingMap(opOperand: &opOperand);
    ArrayRef<int64_t> shape = genericOp.getShape(opOperand: &opOperand);
    if (!hasCollapsibleType(opOperand)) {
      // Only renumber the loop dims in the map; the operand keeps its shape.
      AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
          dimReplacements, symReplacements: ArrayRef<AffineExpr>{}, numResultDims: oldDimToNewDimMap.size(), numResultSyms: 0);
      newIndexingMaps.push_back(Elt: newIndexingMap);
      targetShapes.push_back(Elt: llvm::to_vector(Range&: shape));
      collapsed.push_back(Elt: false);
      reassociations.push_back(Elt: {});
      continue;
    }
    auto replacementInfo = dropUnitExtentFromOperandMetadata(
        context: rewriter.getContext(), genericOp, opOperand: &opOperand, oldDimsToNewDimsMap&: oldDimToNewDimMap,
        dimReplacements);
    reassociations.push_back(Elt: replacementInfo.reassociation);
    newIndexingMaps.push_back(Elt: replacementInfo.indexMap);
    targetShapes.push_back(Elt: replacementInfo.targetShape);
    collapsed.push_back(Elt: !(replacementInfo.indexMap.getNumResults() ==
                           indexingMap.getNumResults()));
  }

  // Abort if the indexing maps of the result operation are not invertible
  // (i.e. not legal) or if no dimension was reduced.
  if (newIndexingMaps == indexingMaps ||
      !inversePermutation(
          map: concatAffineMaps(maps: newIndexingMaps, context: rewriter.getContext())))
    return failure();

  Location loc = genericOp.getLoc();
  // 4. For each of the operands, collapse the operand to convert
  // from original shape to shape in the modified operation if needed,
  // either through use of reshapes or rank-reducing slices as
  // specified in `options`.
  SmallVector<Value> newOperands;
  for (OpOperand &opOperand : genericOp->getOpOperands()) {
    int64_t idx = opOperand.getOperandNumber();
    if (!collapsed[idx]) {
      newOperands.push_back(Elt: opOperand.get());
      continue;
    }
    newOperands.push_back(Elt: collapseValue(rewriter, loc, operand: opOperand.get(),
                                        targetShape: targetShapes[idx], reassociation: reassociations[idx],
                                        rankReductionStrategy: options.rankReductionStrategy));
  }

  // 5. Create the `linalg.generic` operation with the new operands,
  // indexing maps, iterator types and result types.
  ArrayRef<Value> newInputs =
      ArrayRef<Value>(newOperands).take_front(N: genericOp.getNumDpsInputs());
  ArrayRef<Value> newOutputs =
      ArrayRef<Value>(newOperands).take_back(N: genericOp.getNumDpsInits());
  SmallVector<Type> resultTypes;
  resultTypes.reserve(N: genericOp.getNumResults());
  for (unsigned i : llvm::seq<unsigned>(Begin: 0, End: genericOp.getNumResults()))
    resultTypes.push_back(Elt: newOutputs[i].getType());
  GenericOp replacementOp =
      rewriter.create<GenericOp>(location: loc, args&: resultTypes, args&: newInputs, args&: newOutputs,
                                 args&: newIndexingMaps, args&: newIteratorTypes);
  // The body is reused as-is; only index ops inside it need fixing up.
  rewriter.inlineRegionBefore(region&: genericOp.getRegion(), parent&: replacementOp.getRegion(),
                              before: replacementOp.getRegion().begin());
  // 5a. Replace `linalg.index` operations that refer to the dropped unit
  // dimensions.
  replaceUnitDimIndexOps(genericOp: replacementOp, unitDims, rewriter);

  // 6. If any result type changes, insert a reshape/slice to convert from the
  // original type to the new type.
  SmallVector<Value> resultReplacements;
  for (auto [index, result] : llvm::enumerate(First: replacementOp.getResults())) {
    unsigned opOperandIndex = index + replacementOp.getNumDpsInputs();
    Value origDest = genericOp.getDpsInitOperand(i: index)->get();
    if (!collapsed[opOperandIndex]) {
      resultReplacements.push_back(Elt: result);
      continue;
    }
    Value expandedValue = expandValue(rewriter, loc, result, origDest,
                                      reassociation: reassociations[opOperandIndex],
                                      rankReductionStrategy: options.rankReductionStrategy);
    resultReplacements.push_back(Elt: expandedValue);
  }

  return DropUnitDimsResult{.resultOp: replacementOp, .replacements: resultReplacements};
}
548
549namespace {
550struct DropUnitDims : public OpRewritePattern<GenericOp> {
551 DropUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
552 PatternBenefit benefit = 1)
553 : OpRewritePattern(context, benefit), options(std::move(options)) {}
554
555 LogicalResult matchAndRewrite(GenericOp genericOp,
556 PatternRewriter &rewriter) const override {
557 FailureOr<DropUnitDimsResult> result =
558 dropUnitDims(rewriter, genericOp, options);
559 if (failed(Result: result)) {
560 return failure();
561 }
562 rewriter.replaceOp(op: genericOp, newValues: result->replacements);
563 return success();
564 }
565
566private:
567 ControlDropUnitDims options;
568};
569} // namespace
570
571//===---------------------------------------------------------------------===//
572// Drop dimensions that are unit-extents within tensor operations.
573//===---------------------------------------------------------------------===//
574
575namespace {
/// Drop unit-extent, zero-padded dimensions from a `tensor.pad`: collapse the
/// source, pad the collapsed value, and expand the result back to the
/// original type.
struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
  DropPadUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
                  PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), options(std::move(options)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    // 1a. Get the allowed list of dimensions to drop from the `options`.
    SmallVector<unsigned> allowedUnitDims = options.controlFn(padOp);
    if (allowedUnitDims.empty()) {
      return rewriter.notifyMatchFailure(
          arg&: padOp, msg: "control function returns no allowed unit dims to prune");
    }

    if (padOp.getSourceType().getEncoding()) {
      return rewriter.notifyMatchFailure(
          arg&: padOp, msg: "cannot collapse dims of tensor with encoding");
    }

    // Fail for non-constant padding values. The body of the pad could
    // depend on the padding indices and/or properties of the padded
    // tensor so for now we fail.
    // TODO: Support non-constant padding values.
    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal) {
      return rewriter.notifyMatchFailure(
          arg&: padOp, msg: "unimplemented: non-constant padding value");
    }

    ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape();
    int64_t padRank = sourceShape.size();

    auto isStaticZero = [](OpFoldResult f) {
      return getConstantIntValue(ofr: f) == 0;
    };

    // A dimension is droppable only when it is allowed by the control
    // function, has extent 1, and carries no low/high padding.
    llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
                                                 allowedUnitDims.end());
    llvm::SmallDenseSet<unsigned> unitDims;
    SmallVector<int64_t> newShape;
    SmallVector<OpFoldResult> newLowPad;
    SmallVector<OpFoldResult> newHighPad;
    for (const auto [dim, size, low, high] :
         zip_equal(t: llvm::seq(Begin: static_cast<int64_t>(0), End: padRank), u&: sourceShape,
                   args: padOp.getMixedLowPad(), args: padOp.getMixedHighPad())) {
      if (unitDimsFilter.contains(V: dim) && size == 1 && isStaticZero(low) &&
          isStaticZero(high)) {
        unitDims.insert(V: dim);
      } else {
        newShape.push_back(Elt: size);
        newLowPad.push_back(Elt: low);
        newHighPad.push_back(Elt: high);
      }
    }

    if (unitDims.empty()) {
      return rewriter.notifyMatchFailure(arg&: padOp, msg: "no unit dims to collapse");
    }

    // Build the reassociation: each kept dim absorbs the unit dims that
    // immediately follow it (leading unit dims join the first kept dim).
    ReassociationIndices reassociationGroup;
    SmallVector<ReassociationIndices> reassociationMap;
    int64_t dim = 0;
    while (dim < padRank && unitDims.contains(V: dim))
      reassociationGroup.push_back(Elt: dim++);
    while (dim < padRank) {
      assert(!unitDims.contains(dim) && "expected non unit-extent");
      reassociationGroup.push_back(Elt: dim);
      dim++;
      // Fold all following dimensions that are unit-extent.
      while (dim < padRank && unitDims.contains(V: dim))
        reassociationGroup.push_back(Elt: dim++);
      reassociationMap.push_back(Elt: reassociationGroup);
      reassociationGroup.clear();
    }

    // Collapse the source, then pad the rank-reduced value.
    Value collapsedSource =
        collapseValue(rewriter, loc: padOp.getLoc(), operand: padOp.getSource(), targetShape: newShape,
                      reassociation: reassociationMap, rankReductionStrategy: options.rankReductionStrategy);

    auto newPadOp = rewriter.create<tensor::PadOp>(
        location: padOp.getLoc(), /*result=*/args: Type(), args&: collapsedSource, args&: newLowPad,
        args&: newHighPad, args&: paddingVal, args: padOp.getNofold());

    Value dest = padOp.getResult();
    if (options.rankReductionStrategy ==
        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
      // With the slice strategy, expandValue inserts into a destination; build
      // a tensor.empty of the original (expanded) sizes: unit dims are static
      // 1, the rest come from the new pad op's corresponding dims.
      SmallVector<OpFoldResult> expandedSizes;
      int64_t numUnitDims = 0;
      for (auto dim : llvm::seq(Begin: static_cast<int64_t>(0), End: padRank)) {
        if (unitDims.contains(V: dim)) {
          expandedSizes.push_back(Elt: rewriter.getIndexAttr(value: 1));
          numUnitDims++;
          continue;
        }
        expandedSizes.push_back(Elt: tensor::getMixedSize(
            builder&: rewriter, loc: padOp.getLoc(), value: newPadOp, dim: dim - numUnitDims));
      }
      dest = rewriter.create<tensor::EmptyOp>(
          location: padOp.getLoc(), args&: expandedSizes,
          args: padOp.getResultType().getElementType());
    }

    // Expand the padded value back to the original result type.
    Value expandedValue =
        expandValue(rewriter, loc: padOp.getLoc(), result: newPadOp.getResult(), origDest: dest,
                    reassociation: reassociationMap, rankReductionStrategy: options.rankReductionStrategy);
    rewriter.replaceOp(op: padOp, newValues: expandedValue);
    return success();
  }

private:
  ControlDropUnitDims options;
};
688} // namespace
689
690namespace {
691/// Convert `extract_slice` operations to rank-reduced versions.
692struct RankReducedExtractSliceOp
693 : public OpRewritePattern<tensor::ExtractSliceOp> {
694 using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
695
696 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
697 PatternRewriter &rewriter) const override {
698 RankedTensorType resultType = sliceOp.getType();
699 SmallVector<OpFoldResult> targetShape;
700 for (auto size : resultType.getShape())
701 targetShape.push_back(Elt: rewriter.getIndexAttr(value: size));
702 auto reassociation = getReassociationMapForFoldingUnitDims(mixedSizes: targetShape);
703 if (!reassociation ||
704 reassociation->size() == static_cast<size_t>(resultType.getRank()))
705 return failure();
706
707 SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
708 SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
709 SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
710 auto rankReducedType = cast<RankedTensorType>(
711 Val: tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
712 resultRank: reassociation->size(), sourceRankedTensorType: sliceOp.getSourceType(), staticOffsets: offsets, staticSizes: sizes,
713 staticStrides: strides));
714
715 Location loc = sliceOp.getLoc();
716 Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
717 location: loc, args&: rankReducedType, args: sliceOp.getSource(), args&: offsets, args&: sizes, args&: strides);
718 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
719 op: sliceOp, args&: resultType, args&: newSlice, args&: *reassociation);
720 return success();
721 }
722};
723
/// Convert `insert_slice` operations to rank-reduced versions.
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp:
/// the source is collapsed to drop its unit dimensions and re-inserted with
/// the original offsets/sizes/strides (the insert is rank-reducing).
template <typename InsertOpTy>
struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertSliceOp.getSourceType();
    SmallVector<OpFoldResult> targetShape;
    for (auto size : sourceType.getShape())
      targetShape.push_back(Elt: rewriter.getIndexAttr(value: size));
    // Bail out when there is no unit dimension to fold away (as many
    // reassociation groups as source dims) or folding is not computable.
    auto reassociation = getReassociationMapForFoldingUnitDims(mixedSizes: targetShape);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
      return failure();

    Location loc = insertSliceOp.getLoc();
    tensor::CollapseShapeOp reshapedSource;
    {
      OpBuilder::InsertionGuard g(rewriter);
      // The only difference between InsertSliceOp and ParallelInsertSliceOp
      // is the insertion point is just before the ParallelCombiningOp in the
      // parallel case.
      if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      reshapedSource = rewriter.create<tensor::CollapseShapeOp>(
          loc, insertSliceOp.getSource(), *reassociation);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, reshapedSource, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
760} // namespace
761
762/// Patterns that are used to canonicalize the use of unit-extent dims for
763/// broadcasting.
764static void
765populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns,
766 ControlDropUnitDims &options) {
767 auto *context = patterns.getContext();
768 patterns.add<DropUnitDims>(arg&: context, args&: options);
769 patterns.add<DropPadUnitDims>(arg&: context, args&: options);
770 // TODO: Patterns unrelated to unit dim folding should be factored out.
771 patterns.add<RankReducedExtractSliceOp,
772 RankReducedInsertSliceOp<tensor::InsertSliceOp>,
773 RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
774 arg&: context);
775 linalg::FillOp::getCanonicalizationPatterns(results&: patterns, context);
776 tensor::CollapseShapeOp::getCanonicalizationPatterns(results&: patterns, context);
777 tensor::EmptyOp::getCanonicalizationPatterns(results&: patterns, context);
778 tensor::ExpandShapeOp::getCanonicalizationPatterns(results&: patterns, context);
779 tensor::populateFoldTensorEmptyPatterns(patterns);
780 memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
781 memref::populateResolveShapedTypeResultDimsPatterns(patterns);
782}
783
784static void
785populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns,
786 ControlDropUnitDims &options) {
787 auto *context = patterns.getContext();
788 patterns.add<DropUnitDims>(arg&: context, args&: options);
789 patterns.add<DropPadUnitDims>(arg&: context, args&: options);
790 // TODO: Patterns unrelated to unit dim folding should be factored out.
791 linalg::FillOp::getCanonicalizationPatterns(results&: patterns, context);
792 tensor::EmptyOp::getCanonicalizationPatterns(results&: patterns, context);
793 tensor::populateFoldTensorEmptyPatterns(patterns);
794 memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
795 memref::populateResolveShapedTypeResultDimsPatterns(patterns);
796}
797
798void mlir::linalg::populateFoldUnitExtentDimsPatterns(
799 RewritePatternSet &patterns, linalg::ControlDropUnitDims &options) {
800 if (options.rankReductionStrategy ==
801 linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
802 populateFoldUnitExtentDimsViaSlicesPatterns(patterns, options);
803 } else if (options.rankReductionStrategy ==
804 linalg::ControlDropUnitDims::RankReductionStrategy::
805 ReassociativeReshape) {
806 populateFoldUnitExtentDimsViaReshapesPatterns(patterns, options);
807 }
808}
809
810void mlir::linalg::populateMoveInitOperandsToInputPattern(
811 RewritePatternSet &patterns) {
812 patterns.add<MoveInitOperandsToInput>(arg: patterns.getContext());
813}
814
815namespace {
816/// Pass that removes unit-extent dims within generic ops.
817struct LinalgFoldUnitExtentDimsPass
818 : public impl::LinalgFoldUnitExtentDimsPassBase<
819 LinalgFoldUnitExtentDimsPass> {
820 using impl::LinalgFoldUnitExtentDimsPassBase<
821 LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase;
822 void runOnOperation() override {
823 Operation *op = getOperation();
824 MLIRContext *context = op->getContext();
825 RewritePatternSet patterns(context);
826 ControlDropUnitDims options;
827 if (useRankReducingSlices) {
828 options.rankReductionStrategy = linalg::ControlDropUnitDims::
829 RankReductionStrategy::ExtractInsertSlice;
830 }
831 linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
832 populateMoveInitOperandsToInputPattern(patterns);
833 (void)applyPatternsGreedily(op, patterns: std::move(patterns));
834 }
835};
836
837} // namespace
838
839namespace {
840
841/// Returns reassociation indices for collapsing/expanding a
842/// tensor of rank `rank` at position `pos`.
843static SmallVector<ReassociationIndices>
844getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
845 SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
846 bool lastDim = pos == rank - 1;
847 if (rank > 2) {
848 for (int64_t i = 0; i < rank - 1; i++) {
849 if (i == pos || (lastDim && i == pos - 1))
850 reassociation[i] = ReassociationIndices{i, i + 1};
851 else if (i < pos)
852 reassociation[i] = ReassociationIndices{i};
853 else
854 reassociation[i] = ReassociationIndices{i + 1};
855 }
856 }
857 return reassociation;
858}
859
860/// Returns a collapsed `val` where the collapsing occurs at dim `pos`.
861/// If `pos < 0`, then don't collapse.
862static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
863 int64_t pos) {
864 if (pos < 0)
865 return val;
866 auto valType = cast<ShapedType>(Val: val.getType());
867 SmallVector<int64_t> collapsedShape(valType.getShape());
868 collapsedShape.erase(CI: collapsedShape.begin() + pos);
869 return collapseValue(
870 rewriter, loc: val.getLoc(), operand: val, targetShape: collapsedShape,
871 reassociation: getReassociationForReshapeAtDim(rank: valType.getRank(), pos),
872 rankReductionStrategy: ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
873}
874
/// Base class for all rank reduction patterns for contraction ops
/// with unit dimensions. All patterns should convert one named op
/// to another named op. Intended to reduce only one iteration space dim
/// at a time.
/// Reducing multiple dims will happen with recursive application of
/// pattern rewrites.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
  using OpRewritePattern<FromOpTy>::OpRewritePattern;

  /// Collapse all collapsable operands.
  /// `operandCollapseDims` gives, per operand (lhs, rhs, init), the dim to
  /// drop; a negative entry leaves that operand untouched.
  SmallVector<Value>
  collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
                   ArrayRef<int64_t> operandCollapseDims) const {
    assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
           "expected 3 operands and dims");
    return llvm::map_to_vector(
        llvm::zip(t&: operands, u&: operandCollapseDims), [&](auto pair) {
          return collapseSingletonDimAt(rewriter, std::get<0>(pair),
                                        std::get<1>(pair));
        });
  }

  /// Expand result tensor.
  /// Re-inserts the unit dim at `dim` so the value matches `expandedType`,
  /// the original (pre-collapse) result type.
  Value expandResult(PatternRewriter &rewriter, Value result,
                     RankedTensorType expandedType, int64_t dim) const {
    return rewriter.create<tensor::ExpandShapeOp>(
        location: result.getLoc(), args&: expandedType, args&: result,
        args: getReassociationForReshapeAtDim(rank: expandedType.getRank(), pos: dim));
  }

  LogicalResult matchAndRewrite(FromOpTy contractionOp,
                                PatternRewriter &rewriter) const override {
    // Only ops with the default (inferred) indexing maps are handled.
    if (contractionOp.hasUserDefinedMaps()) {
      return rewriter.notifyMatchFailure(
          contractionOp, "ops with user-defined maps are not supported");
    }

    auto loc = contractionOp.getLoc();
    auto inputs = contractionOp.getDpsInputs();
    auto inits = contractionOp.getDpsInits();
    if (inputs.size() != 2 || inits.size() != 1)
      return rewriter.notifyMatchFailure(contractionOp,
                                         "expected 2 inputs and 1 init");
    auto lhs = inputs[0];
    auto rhs = inputs[1];
    auto init = inits[0];
    SmallVector<Value> operands{lhs, rhs, init};

    // Ask the derived pattern which unit dim (if any) to drop per operand.
    SmallVector<int64_t> operandUnitDims;
    if (failed(getOperandUnitDims(op: contractionOp, operandUnitDims)))
      return rewriter.notifyMatchFailure(contractionOp,
                                         "no reducable dims found");

    SmallVector<Value> collapsedOperands =
        collapseOperands(rewriter, operands, operandCollapseDims: operandUnitDims);
    Value collapsedLhs = collapsedOperands[0];
    Value collapsedRhs = collapsedOperands[1];
    Value collapsedInit = collapsedOperands[2];
    // Only add a result type when the init is a ranked tensor; otherwise
    // (memref form) the new op has no results.
    SmallVector<Type, 1> collapsedResultTy;
    if (isa<RankedTensorType>(Val: collapsedInit.getType()))
      collapsedResultTy.push_back(Elt: collapsedInit.getType());
    auto collapsedOp = rewriter.create<ToOpTy>(
        loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
        ValueRange{collapsedInit});
    // Copy attributes over, skipping indexing-map attributes — those describe
    // the original (higher-rank) op and would not match the collapsed op.
    for (auto attr : contractionOp->getAttrs()) {
      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName ||
          attr.getName() == "indexing_maps")
        continue;
      collapsedOp->setAttr(attr.getName(), attr.getValue());
    }

    auto results = contractionOp.getResults();
    assert(results.size() < 2 && "expected at most one result");
    if (results.empty()) {
      // Memref form: nothing to expand, replace directly.
      rewriter.replaceOp(contractionOp, collapsedOp);
    } else {
      // Tensor form: expand the collapsed result back to the original result
      // type before replacing uses.
      rewriter.replaceOp(
          contractionOp,
          expandResult(rewriter, result: collapsedOp.getResultTensors()[0],
                       expandedType: cast<RankedTensorType>(results[0].getType()),
                       dim: operandUnitDims[2]));
    }

    return success();
  }

  /// Populate `operandUnitDims` with 3 indices indicating the unit dim
  /// for each operand that should be collapsed in this pattern. If an
  /// operand shouldn't be collapsed, the index should be negative.
  virtual LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const = 0;
};
969
970/// Patterns for unbatching batched contraction ops
971template <typename FromOpTy, typename ToOpTy>
972struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
973 using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
974
975 /// Look for unit batch dims to collapse.
976 LogicalResult
977 getOperandUnitDims(LinalgOp op,
978 SmallVectorImpl<int64_t> &operandUnitDims) const override {
979 FailureOr<ContractionDimensions> maybeContractionDims =
980 inferContractionDims(linalgOp: op);
981 if (failed(Result: maybeContractionDims)) {
982 LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
983 return failure();
984 }
985 ContractionDimensions contractionDims = maybeContractionDims.value();
986
987 if (contractionDims.batch.size() != 1)
988 return failure();
989 auto batchDim = contractionDims.batch[0];
990 SmallVector<std::pair<Value, unsigned>, 3> bOperands;
991 op.mapIterationSpaceDimToAllOperandDims(dimPos: batchDim, operandDimPairs&: bOperands);
992 if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) {
993 return cast<ShapedType>(std::get<0>(pair).getType())
994 .getShape()[std::get<1>(pair)] != 1;
995 })) {
996 LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
997 return failure();
998 }
999
1000 operandUnitDims = SmallVector<int64_t>{std::get<1>(in&: bOperands[0]),
1001 std::get<1>(in&: bOperands[1]),
1002 std::get<1>(in&: bOperands[2])};
1003 return success();
1004 }
1005};
1006
1007/// Patterns for reducing non-batch dimensions
1008template <typename FromOpTy, typename ToOpTy>
1009struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
1010 using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
1011
1012 /// Helper for determining whether the lhs/init or rhs/init are reduced.
1013 static bool constexpr reduceLeft =
1014 (std::is_same_v<FromOpTy, BatchMatmulOp> &&
1015 std::is_same_v<ToOpTy, BatchVecmatOp>) ||
1016 (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
1017 std::is_same_v<ToOpTy, BatchVecmatOp>) ||
1018 (std::is_same_v<FromOpTy, MatmulOp> &&
1019 std::is_same_v<ToOpTy, VecmatOp>) ||
1020 (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
1021 std::is_same_v<ToOpTy, VecmatOp>) ||
1022 (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
1023
1024 /// Look for non-batch spatial dims to collapse.
1025 LogicalResult
1026 getOperandUnitDims(LinalgOp op,
1027 SmallVectorImpl<int64_t> &operandUnitDims) const override {
1028 FailureOr<ContractionDimensions> maybeContractionDims =
1029 inferContractionDims(linalgOp: op);
1030 if (failed(Result: maybeContractionDims)) {
1031 LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
1032 return failure();
1033 }
1034 ContractionDimensions contractionDims = maybeContractionDims.value();
1035
1036 if constexpr (reduceLeft) {
1037 auto m = contractionDims.m[0];
1038 SmallVector<std::pair<Value, unsigned>, 2> mOperands;
1039 op.mapIterationSpaceDimToAllOperandDims(dimPos: m, operandDimPairs&: mOperands);
1040 if (mOperands.size() != 2)
1041 return failure();
1042 if (llvm::all_of(mOperands, [](auto pair) {
1043 return cast<ShapedType>(std::get<0>(pair).getType())
1044 .getShape()[std::get<1>(pair)] == 1;
1045 })) {
1046 operandUnitDims = SmallVector<int64_t>{std::get<1>(in&: mOperands[0]), -1,
1047 std::get<1>(in&: mOperands[1])};
1048 return success();
1049 }
1050 } else {
1051 auto n = contractionDims.n[0];
1052 SmallVector<std::pair<Value, unsigned>, 2> nOperands;
1053 op.mapIterationSpaceDimToAllOperandDims(dimPos: n, operandDimPairs&: nOperands);
1054 if (nOperands.size() != 2)
1055 return failure();
1056 if (llvm::all_of(nOperands, [](auto pair) {
1057 return cast<ShapedType>(std::get<0>(pair).getType())
1058 .getShape()[std::get<1>(pair)] == 1;
1059 })) {
1060 operandUnitDims = SmallVector<int64_t>{-1, std::get<1>(in&: nOperands[0]),
1061 std::get<1>(in&: nOperands[1])};
1062 return success();
1063 }
1064 }
1065 LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
1066 return failure();
1067 }
1068};
1069
1070} // namespace
1071
1072void mlir::linalg::populateContractionOpRankReducingPatterns(
1073 RewritePatternSet &patterns) {
1074 MLIRContext *context = patterns.getContext();
1075 // Unbatching patterns for unit batch size
1076 patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(arg&: context);
1077 patterns
1078 .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
1079 arg&: context);
1080 patterns
1081 .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
1082 arg&: context);
1083 patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(arg&: context);
1084 patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(arg&: context);
1085
1086 // Non-batch rank 1 reducing patterns
1087 patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(arg&: context);
1088 patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(arg&: context);
1089 patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(arg&: context);
1090 patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(arg&: context);
1091 // Batch rank 1 reducing patterns
1092 patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(arg&: context);
1093 patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(arg&: context);
1094 patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
1095 arg&: context);
1096 patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
1097 arg&: context);
1098
1099 // Non-batch rank 0 reducing patterns
1100 patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(arg&: context);
1101 patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(arg&: context);
1102}
1103

// Source file: mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp