1//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns/pass to remove usage of unit-extent dimensions
10// to specify broadcasting in favor of more canonical representation of the
11// computation
12//
13//===----------------------------------------------------------------------===//
14
15#include "mlir/Dialect/Linalg/Passes.h"
16
17#include "mlir/Dialect/Affine/IR/AffineOps.h"
18#include "mlir/Dialect/Arith/IR/Arith.h"
19#include "mlir/Dialect/Linalg/IR/Linalg.h"
20#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
21#include "mlir/Dialect/Linalg/Utils/Utils.h"
22#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
23#include "mlir/Dialect/Tensor/IR/Tensor.h"
24#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
25#include "mlir/Dialect/Tensor/Utils/Utils.h"
26#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
27#include "mlir/IR/AffineExpr.h"
28#include "mlir/IR/AffineMap.h"
29#include "mlir/IR/BuiltinTypes.h"
30#include "mlir/Transforms/FoldUtils.h"
31#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32#include "llvm/ADT/SetVector.h"
33#include "llvm/Support/CommandLine.h"
34#include "llvm/Support/Debug.h"
35
36namespace mlir {
37#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS
38#include "mlir/Dialect/Linalg/Passes.h.inc"
39} // namespace mlir
40
41#define DEBUG_TYPE "linalg-drop-unit-dims"
42
43using namespace mlir;
44using namespace mlir::linalg;
45
46namespace {
47/// Pattern to move init operands to ins when all the loops are parallel and
48/// blockArgument corresponding to init is used in the region. This is a fix-up
49/// when unit reduction dimensions are all folded away. In this context, it
50/// becomes a elementwise generic op. E.g., it converts
51///
52/// %0 = tensor.empty() : tensor<1x1xf32>
53/// %1 = linalg.fill
54/// ins(%cst : f32)
55/// outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
56/// %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
57/// affine_map<(d0) -> (0, d0)>],
58/// iterator_types = ["parallel"]}
59/// ins(%arg0 : tensor<1x?x1x1xf32>)
60/// outs(%1 : tensor<1x1xf32>) {
61/// ^bb0(%in: f32, %out: f32):
62/// %3 = arith.addf %in, %out : f32
63/// linalg.yield %3 : f32
64/// } -> tensor<1x1xf32>
65///
66/// into
67///
68/// %0 = tensor.empty() : tensor<1x1xf32>
69/// %1 = linalg.fill
70/// ins(%cst : f32)
71/// outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
72/// %2 = tensor.empty() : tensor<1x1xf32>
73/// %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
74/// affine_map<(d0) -> (0, d0)>,
75/// affine_map<(d0) -> (0, d0)>],
76/// iterator_types = ["parallel"]}
77/// ins(%arg0, %1 : tensor<1x?x1x1xf32>, tensor<1x1xf32>)
78/// outs(%2 : tensor<1x1xf32>) {
79/// ^bb0(%in: f32, %in_0: f32, %out: f32):
80/// %4 = arith.addf %in, %in_0 : f32
81/// linalg.yield %4 : f32
82/// } -> tensor<1x1xf32>
83struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
84 using OpRewritePattern<GenericOp>::OpRewritePattern;
85 LogicalResult matchAndRewrite(GenericOp genericOp,
86 PatternRewriter &rewriter) const override {
87 if (!genericOp.hasPureTensorSemantics())
88 return failure();
89 if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
90 return failure();
91
92 auto outputOperands = genericOp.getDpsInitsMutable();
93 SetVector<OpOperand *> candidates;
94 for (OpOperand &op : outputOperands) {
95 if (genericOp.getMatchingBlockArgument(&op).use_empty())
96 continue;
97 candidates.insert(&op);
98 }
99
100 if (candidates.empty())
101 return failure();
102
103 // Compute the modified indexing maps.
104 int64_t origNumInput = genericOp.getNumDpsInputs();
105 SmallVector<Value> newInputOperands = genericOp.getDpsInputs();
106 SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
107 SmallVector<AffineMap> newIndexingMaps;
108 newIndexingMaps.append(indexingMaps.begin(),
109 std::next(indexingMaps.begin(), origNumInput));
110 for (OpOperand *op : candidates) {
111 newInputOperands.push_back(op->get());
112 newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
113 }
114 newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
115 indexingMaps.end());
116
117 Location loc = genericOp.getLoc();
118 SmallVector<Value> newOutputOperands =
119 llvm::to_vector(genericOp.getDpsInits());
120 for (OpOperand *op : candidates) {
121 OpBuilder::InsertionGuard guard(rewriter);
122 rewriter.setInsertionPointAfterValue(op->get());
123 auto elemType = cast<ShapedType>(op->get().getType()).getElementType();
124 auto empty = rewriter.create<tensor::EmptyOp>(
125 loc, tensor::getMixedSizes(rewriter, loc, op->get()), elemType);
126
127 unsigned start = genericOp.getDpsInits().getBeginOperandIndex();
128 newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
129 }
130
131 auto newOp = rewriter.create<GenericOp>(
132 loc, genericOp.getResultTypes(), newInputOperands, newOutputOperands,
133 newIndexingMaps, genericOp.getIteratorTypesArray(),
134 /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
135
136 OpBuilder::InsertionGuard guard(rewriter);
137 Region &region = newOp.getRegion();
138 Block *block = rewriter.createBlock(parent: &region);
139 IRMapping mapper;
140 for (auto bbarg : genericOp.getRegionInputArgs())
141 mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
142
143 for (OpOperand *op : candidates) {
144 BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
145 mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
146 }
147
148 for (OpOperand &op : outputOperands) {
149 BlockArgument bbarg = genericOp.getMatchingBlockArgument(&op);
150 if (candidates.count(&op))
151 block->addArgument(bbarg.getType(), loc);
152 else
153 mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
154 }
155
156 for (auto &op : genericOp.getBody()->getOperations()) {
157 rewriter.clone(op, mapper);
158 }
159 rewriter.replaceOp(genericOp, newOp.getResults());
160
161 return success();
162 }
163};
164} // namespace
165
166//===---------------------------------------------------------------------===//
167// Drop loops that are unit-extents within Linalg operations.
168//===---------------------------------------------------------------------===//
169
170/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
171/// broadcasting. For example,
172///
173/// ```mlir
174/// #accesses = [
175/// affine_map<(d0, d1) -> (0, d1)>,
176/// affine_map<(d0, d1) -> (d0, 0)>,
177/// affine_map<(d0, d1) -> (d0, d1)>
178/// ]
179///
180/// #trait = {
181/// indexing_maps = #accesses,
182/// iterator_types = ["parallel", "parallel"],
183/// library_call = "some_external_fn"
184/// }
185///
186/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
187/// tensor<5x5xf32>
188/// {
189/// %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
190/// tensor<5xf32> into tensor<1x5xf32>
191/// %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
192/// tensor<5xf32> into tensor<5x1xf32>
193/// %2 = linalg.generic #trait %0, %1 {
194/// ^bb0(%arg2: f32, %arg3: f32):
195/// %3 = arith.addf %arg2, %arg3 : f32
196/// linalg.yield %3 : f32
197/// } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
198/// return %2 : tensor<5x5xf32>
/// }
/// ```
200///
201/// would canonicalize to
202///
203/// ```mlir
204/// #accesses = [
205/// affine_map<(d0, d1) -> (d1)>,
206/// affine_map<(d0, d1) -> (d0)>,
207/// affine_map<(d0, d1) -> (d0, d1)>
208/// ]
209///
210/// #trait = {
211/// indexing_maps = #accesses,
212/// iterator_types = ["parallel", "parallel"],
213/// library_call = "some_external_fn"
214/// }
215///
216/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
217/// tensor<5x5xf32>
218/// {
219/// %0 = linalg.generic #trait %arg0, %arg1 {
220/// ^bb0(%arg2: f32, %arg3: f32):
221/// %3 = arith.addf %arg2, %arg3 : f32
222/// linalg.yield %3 : f32
223/// } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
224/// return %0 : tensor<5x5xf32>
/// }
/// ```
226
227/// Update the index accesses of linalg operations having index semantics.
228static void
229replaceUnitDimIndexOps(GenericOp genericOp,
230 const llvm::SmallDenseSet<unsigned> &unitDims,
231 RewriterBase &rewriter) {
232 for (IndexOp indexOp :
233 llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
234 OpBuilder::InsertionGuard guard(rewriter);
235 rewriter.setInsertionPoint(indexOp);
236 if (unitDims.count(indexOp.getDim()) != 0) {
237 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
238 } else {
239 // Update the dimension of the index operation if needed.
240 unsigned droppedDims = llvm::count_if(
241 unitDims, [&](unsigned dim) { return dim < indexOp.getDim(); });
242 if (droppedDims != 0)
243 rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
244 indexOp.getDim() - droppedDims);
245 }
246 }
247}
248
249/// Expand the given `value` so that the type matches the type of `origDest`.
250/// The `reassociation` is used when `rankReductionStrategy` is set to
251/// `RankReductionStrategy::ReassociativeReshape`.
252static Value
253expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
254 ArrayRef<ReassociationIndices> reassociation,
255 ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
256 // There are no results for memref outputs.
257 auto origResultType = cast<RankedTensorType>(origDest.getType());
258 if (rankReductionStrategy ==
259 ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
260 unsigned rank = origResultType.getRank();
261 SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
262 SmallVector<OpFoldResult> sizes =
263 tensor::getMixedSizes(builder&: rewriter, loc, value: origDest);
264 SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
265 return rewriter.createOrFold<tensor::InsertSliceOp>(
266 loc, result, origDest, offsets, sizes, strides);
267 }
268
269 assert(rankReductionStrategy ==
270 ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
271 "unknown rank reduction strategy");
272 return rewriter
273 .create<tensor::ExpandShapeOp>(loc, origResultType, result, reassociation)
274 .getResult();
275}
276
277/// Collapse the given `value` so that the type matches the type of
278/// `origOutput`. The `reassociation` is used when `rankReductionStrategy` is
279/// set to `RankReductionStrategy::ReassociativeReshape`.
280static Value collapseValue(
281 RewriterBase &rewriter, Location loc, Value operand,
282 ArrayRef<int64_t> targetShape, ArrayRef<ReassociationIndices> reassociation,
283 ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
284 if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
285 if (rankReductionStrategy ==
286 ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
287 FailureOr<Value> rankReducingExtract =
288 memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
289 targetShape);
290 assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
291 return *rankReducingExtract;
292 }
293
294 assert(
295 rankReductionStrategy ==
296 ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
297 "unknown rank reduction strategy");
298 MemRefLayoutAttrInterface layout;
299 auto targetType = MemRefType::get(targetShape, memrefType.getElementType(),
300 layout, memrefType.getMemorySpace());
301 return rewriter.create<memref::CollapseShapeOp>(loc, targetType, operand,
302 reassociation);
303 }
304 if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) {
305 if (rankReductionStrategy ==
306 ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
307 FailureOr<Value> rankReducingExtract =
308 tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
309 targetShape);
310 assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
311 return *rankReducingExtract;
312 }
313
314 assert(
315 rankReductionStrategy ==
316 ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
317 "unknown rank reduction strategy");
318 auto targetType =
319 RankedTensorType::get(targetShape, tensorType.getElementType());
320 return rewriter.create<tensor::CollapseShapeOp>(loc, targetType, operand,
321 reassociation);
322 }
323 llvm_unreachable("unsupported operand type");
324}
325
326/// Compute the modified metadata for an operands of operation
327/// whose unit dims are being dropped. Return the new indexing map
328/// to use, the shape of the operand in the replacement op
329/// and the `reassocation` to use to go from original operand shape
330/// to modified operand shape.
331struct UnitExtentReplacementInfo {
332 AffineMap indexMap;
333 SmallVector<ReassociationIndices> reassociation;
334 SmallVector<int64_t> targetShape;
335};
336static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
337 MLIRContext *context, GenericOp genericOp, OpOperand *opOperand,
338 llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
339 ArrayRef<AffineExpr> dimReplacements) {
340 UnitExtentReplacementInfo info;
341 ReassociationIndices reassociationGroup;
342 SmallVector<AffineExpr> newIndexExprs;
343 AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
344 ArrayRef<int64_t> operandShape = genericOp.getShape(opOperand);
345 ArrayRef<AffineExpr> exprs = indexingMap.getResults();
346
347 auto isUnitDim = [&](unsigned dim) {
348 if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[dim])) {
349 unsigned oldPosition = dimExpr.getPosition();
350 return !oldDimsToNewDimsMap.count(oldPosition) &&
351 (operandShape[dim] == 1);
352 }
353 // Handle the other case where the shape is 1, and is accessed using a
354 // constant 0.
355 if (operandShape[dim] == 1) {
356 auto constAffineExpr = dyn_cast<AffineConstantExpr>(Val: exprs[dim]);
357 return constAffineExpr && constAffineExpr.getValue() == 0;
358 }
359 return false;
360 };
361
362 unsigned dim = 0;
363 while (dim < operandShape.size() && isUnitDim(dim))
364 reassociationGroup.push_back(Elt: dim++);
365 while (dim < operandShape.size()) {
366 assert(!isUnitDim(dim) && "expected non unit-extent");
367 reassociationGroup.push_back(Elt: dim);
368 AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
369 newIndexExprs.push_back(Elt: newExpr);
370 info.targetShape.push_back(Elt: operandShape[dim]);
371 ++dim;
372 // Fold all following dimensions that are unit-extent.
373 while (dim < operandShape.size() && isUnitDim(dim)) {
374 reassociationGroup.push_back(Elt: dim++);
375 }
376 info.reassociation.push_back(Elt: reassociationGroup);
377 reassociationGroup.clear();
378 }
379 info.indexMap =
380 AffineMap::get(dimCount: oldDimsToNewDimsMap.size(), symbolCount: indexingMap.getNumSymbols(),
381 results: newIndexExprs, context);
382 return info;
383}
384
385FailureOr<DropUnitDimsResult>
386linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
387 const ControlDropUnitDims &options) {
388 SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
389 if (indexingMaps.empty())
390 return failure();
391
392 // 1. Check if any of the iteration dimensions are unit-trip count. They will
393 // end up being unit-trip count if they are used to index into a unit-dim
394 // tensor/memref.
395 AffineMap invertedMap =
396 inversePermutation(map: concatAffineMaps(maps: indexingMaps, context: rewriter.getContext()));
397 if (!invertedMap) {
398 return rewriter.notifyMatchFailure(genericOp,
399 "invalid indexing maps for operation");
400 }
401 SmallVector<int64_t> dims = genericOp.getStaticShape();
402
403 // 1a. Get the allowed list of dimensions to drop from the `options`.
404 SmallVector<unsigned> allowedUnitDims = options.controlFn(genericOp);
405 if (allowedUnitDims.empty()) {
406 return rewriter.notifyMatchFailure(
407 genericOp, "control function returns no allowed unit dims to prune");
408 }
409 llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
410 allowedUnitDims.end());
411 llvm::SmallDenseSet<unsigned> unitDims;
412 for (const auto &expr : enumerate(invertedMap.getResults())) {
413 if (AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(expr.value())) {
414 if (dims[dimExpr.getPosition()] == 1 &&
415 unitDimsFilter.count(expr.index()))
416 unitDims.insert(expr.index());
417 }
418 }
419
420 // 2. Compute the iterator types of the modified op by dropping the one-trip
421 // count loops.
422 SmallVector<utils::IteratorType> newIteratorTypes;
423 llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
424 SmallVector<AffineExpr> dimReplacements;
425 unsigned newDims = 0;
426 for (auto [index, attr] :
427 llvm::enumerate(genericOp.getIteratorTypesArray())) {
428 if (unitDims.count(index)) {
429 dimReplacements.push_back(
430 getAffineConstantExpr(0, rewriter.getContext()));
431 } else {
432 newIteratorTypes.push_back(attr);
433 oldDimToNewDimMap[index] = newDims;
434 dimReplacements.push_back(
435 getAffineDimExpr(newDims, rewriter.getContext()));
436 newDims++;
437 }
438 }
439
440 // 3. For each of the operands, find the
441 // - modified affine map to use.
442 // - shape of the operands after the unit-dims are dropped.
443 // - the reassociation indices used to convert from the original
444 // operand type to modified operand (needed only when using reshapes
445 // for rank reduction strategy)
446 // Note that the indexing maps might need changing even if there are no
447 // unit dimensions that are dropped to handle cases where `0` is used to
448 // access a unit-extent tensor. Consider moving this out of this specific
449 // transformation as a stand-alone transformation. Kept here right now due
450 // to legacy.
451 SmallVector<AffineMap> newIndexingMaps;
452 SmallVector<SmallVector<ReassociationIndices>> reassociations;
453 SmallVector<SmallVector<int64_t>> targetShapes;
454 SmallVector<bool> collapsed;
455 auto hasCollapsibleType = [](OpOperand &operand) {
456 Type operandType = operand.get().getType();
457 if (auto memrefOperandType = dyn_cast_or_null<MemRefType>(operandType)) {
458 return memrefOperandType.getLayout().isIdentity();
459 }
460 if (auto tensorOperandType = dyn_cast<RankedTensorType>(operandType)) {
461 return tensorOperandType.getEncoding() == nullptr;
462 }
463 return false;
464 };
465 for (OpOperand &opOperand : genericOp->getOpOperands()) {
466 auto indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
467 ArrayRef<int64_t> shape = genericOp.getShape(&opOperand);
468 if (!hasCollapsibleType(opOperand)) {
469 AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
470 dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size(), 0);
471 newIndexingMaps.push_back(newIndexingMap);
472 targetShapes.push_back(llvm::to_vector(shape));
473 collapsed.push_back(false);
474 reassociations.push_back({});
475 continue;
476 }
477 auto replacementInfo = dropUnitExtentFromOperandMetadata(
478 rewriter.getContext(), genericOp, &opOperand, oldDimToNewDimMap,
479 dimReplacements);
480 reassociations.push_back(replacementInfo.reassociation);
481 newIndexingMaps.push_back(replacementInfo.indexMap);
482 targetShapes.push_back(replacementInfo.targetShape);
483 collapsed.push_back(!(replacementInfo.indexMap.getNumResults() ==
484 indexingMap.getNumResults()));
485 }
486
487 // Abort if the indexing maps of the result operation are not invertible
488 // (i.e. not legal) or if no dimension was reduced.
489 if (newIndexingMaps == indexingMaps ||
490 !inversePermutation(
491 map: concatAffineMaps(maps: newIndexingMaps, context: rewriter.getContext())))
492 return failure();
493
494 Location loc = genericOp.getLoc();
495 // 4. For each of the operands, collapse the operand to convert
496 // from original shape to shape in the modified operation if needed,
497 // either through use of reshapes or rank-reducing slices as
498 // specified in `options`.
499 SmallVector<Value> newOperands;
500 for (OpOperand &opOperand : genericOp->getOpOperands()) {
501 int64_t idx = opOperand.getOperandNumber();
502 if (!collapsed[idx]) {
503 newOperands.push_back(opOperand.get());
504 continue;
505 }
506 newOperands.push_back(collapseValue(rewriter, loc, opOperand.get(),
507 targetShapes[idx], reassociations[idx],
508 options.rankReductionStrategy));
509 }
510
511 // 5. Create the `linalg.generic` operation with the new operands,
512 // indexing maps, iterator types and result types.
513 ArrayRef<Value> newInputs =
514 ArrayRef<Value>(newOperands).take_front(N: genericOp.getNumDpsInputs());
515 ArrayRef<Value> newOutputs =
516 ArrayRef<Value>(newOperands).take_back(N: genericOp.getNumDpsInits());
517 SmallVector<Type> resultTypes;
518 resultTypes.reserve(N: genericOp.getNumResults());
519 for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
520 resultTypes.push_back(newOutputs[i].getType());
521 GenericOp replacementOp =
522 rewriter.create<GenericOp>(loc, resultTypes, newInputs, newOutputs,
523 newIndexingMaps, newIteratorTypes);
524 rewriter.inlineRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
525 replacementOp.getRegion().begin());
526 // 5a. Replace `linalg.index` operations that refer to the dropped unit
527 // dimensions.
528 replaceUnitDimIndexOps(replacementOp, unitDims, rewriter);
529
530 // 6. If any result type changes, insert a reshape/slice to convert from the
531 // original type to the new type.
532 SmallVector<Value> resultReplacements;
533 for (auto [index, result] : llvm::enumerate(replacementOp.getResults())) {
534 unsigned opOperandIndex = index + replacementOp.getNumDpsInputs();
535 Value origDest = genericOp.getDpsInitOperand(index)->get();
536 if (!collapsed[opOperandIndex]) {
537 resultReplacements.push_back(result);
538 continue;
539 }
540 Value expandedValue = expandValue(rewriter, loc, result, origDest,
541 reassociations[opOperandIndex],
542 options.rankReductionStrategy);
543 resultReplacements.push_back(expandedValue);
544 }
545
546 return DropUnitDimsResult{replacementOp, resultReplacements};
547}
548
549namespace {
550struct DropUnitDims : public OpRewritePattern<GenericOp> {
551 DropUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
552 PatternBenefit benefit = 1)
553 : OpRewritePattern(context, benefit), options(std::move(options)) {}
554
555 LogicalResult matchAndRewrite(GenericOp genericOp,
556 PatternRewriter &rewriter) const override {
557 FailureOr<DropUnitDimsResult> result =
558 dropUnitDims(rewriter, genericOp, options);
559 if (failed(Result: result)) {
560 return failure();
561 }
562 rewriter.replaceOp(genericOp, result->replacements);
563 return success();
564 }
565
566private:
567 ControlDropUnitDims options;
568};
569} // namespace
570
571//===---------------------------------------------------------------------===//
572// Drop dimensions that are unit-extents within tensor operations.
573//===---------------------------------------------------------------------===//
574
575namespace {
576struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
577 DropPadUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
578 PatternBenefit benefit = 1)
579 : OpRewritePattern(context, benefit), options(std::move(options)) {}
580
581 LogicalResult matchAndRewrite(tensor::PadOp padOp,
582 PatternRewriter &rewriter) const override {
583 // 1a. Get the allowed list of dimensions to drop from the `options`.
584 SmallVector<unsigned> allowedUnitDims = options.controlFn(padOp);
585 if (allowedUnitDims.empty()) {
586 return rewriter.notifyMatchFailure(
587 padOp, "control function returns no allowed unit dims to prune");
588 }
589
590 if (padOp.getSourceType().getEncoding()) {
591 return rewriter.notifyMatchFailure(
592 padOp, "cannot collapse dims of tensor with encoding");
593 }
594
595 // Fail for non-constant padding values. The body of the pad could
596 // depend on the padding indices and/or properties of the padded
597 // tensor so for now we fail.
598 // TODO: Support non-constant padding values.
599 Value paddingVal = padOp.getConstantPaddingValue();
600 if (!paddingVal) {
601 return rewriter.notifyMatchFailure(
602 padOp, "unimplemented: non-constant padding value");
603 }
604
605 ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape();
606 int64_t padRank = sourceShape.size();
607
608 auto isStaticZero = [](OpFoldResult f) {
609 std::optional<int64_t> maybeInt = getConstantIntValue(ofr: f);
610 return maybeInt && *maybeInt == 0;
611 };
612
613 llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
614 allowedUnitDims.end());
615 llvm::SmallDenseSet<unsigned> unitDims;
616 SmallVector<int64_t> newShape;
617 SmallVector<OpFoldResult> newLowPad;
618 SmallVector<OpFoldResult> newHighPad;
619 for (const auto [dim, size, low, high] :
620 zip_equal(llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
621 padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
622 if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
623 isStaticZero(high)) {
624 unitDims.insert(dim);
625 } else {
626 newShape.push_back(size);
627 newLowPad.push_back(low);
628 newHighPad.push_back(high);
629 }
630 }
631
632 if (unitDims.empty()) {
633 return rewriter.notifyMatchFailure(padOp, "no unit dims to collapse");
634 }
635
636 ReassociationIndices reassociationGroup;
637 SmallVector<ReassociationIndices> reassociationMap;
638 int64_t dim = 0;
639 while (dim < padRank && unitDims.contains(V: dim))
640 reassociationGroup.push_back(Elt: dim++);
641 while (dim < padRank) {
642 assert(!unitDims.contains(dim) && "expected non unit-extent");
643 reassociationGroup.push_back(Elt: dim);
644 dim++;
645 // Fold all following dimensions that are unit-extent.
646 while (dim < padRank && unitDims.contains(V: dim))
647 reassociationGroup.push_back(Elt: dim++);
648 reassociationMap.push_back(Elt: reassociationGroup);
649 reassociationGroup.clear();
650 }
651
652 Value collapsedSource =
653 collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
654 reassociationMap, options.rankReductionStrategy);
655
656 auto newPadOp = rewriter.create<tensor::PadOp>(
657 padOp.getLoc(), /*result=*/Type(), collapsedSource, newLowPad,
658 newHighPad, paddingVal, padOp.getNofold());
659
660 Value dest = padOp.getResult();
661 if (options.rankReductionStrategy ==
662 ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
663 SmallVector<OpFoldResult> expandedSizes;
664 int64_t numUnitDims = 0;
665 for (auto dim : llvm::seq(static_cast<int64_t>(0), padRank)) {
666 if (unitDims.contains(dim)) {
667 expandedSizes.push_back(rewriter.getIndexAttr(1));
668 numUnitDims++;
669 continue;
670 }
671 expandedSizes.push_back(tensor::getMixedSize(
672 rewriter, padOp.getLoc(), newPadOp, dim - numUnitDims));
673 }
674 dest = rewriter.create<tensor::EmptyOp>(
675 padOp.getLoc(), expandedSizes,
676 padOp.getResultType().getElementType());
677 }
678
679 Value expandedValue =
680 expandValue(rewriter, padOp.getLoc(), newPadOp.getResult(), dest,
681 reassociationMap, options.rankReductionStrategy);
682 rewriter.replaceOp(padOp, expandedValue);
683 return success();
684 }
685
686private:
687 ControlDropUnitDims options;
688};
689} // namespace
690
691namespace {
692/// Convert `extract_slice` operations to rank-reduced versions.
693struct RankReducedExtractSliceOp
694 : public OpRewritePattern<tensor::ExtractSliceOp> {
695 using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
696
697 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
698 PatternRewriter &rewriter) const override {
699 RankedTensorType resultType = sliceOp.getType();
700 SmallVector<OpFoldResult> targetShape;
701 for (auto size : resultType.getShape())
702 targetShape.push_back(rewriter.getIndexAttr(size));
703 auto reassociation = getReassociationMapForFoldingUnitDims(mixedSizes: targetShape);
704 if (!reassociation ||
705 reassociation->size() == static_cast<size_t>(resultType.getRank()))
706 return failure();
707
708 SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
709 SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
710 SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
711 auto rankReducedType = cast<RankedTensorType>(
712 tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
713 reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
714 strides));
715
716 Location loc = sliceOp.getLoc();
717 Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
718 loc, rankReducedType, sliceOp.getSource(), offsets, sizes, strides);
719 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
720 sliceOp, resultType, newSlice, *reassociation);
721 return success();
722 }
723};
724
725/// Convert `insert_slice` operations to rank-reduced versions.
726/// This patterns works with both InsertSliceOp and ParallelInsertSliceOp.
727template <typename InsertOpTy>
728struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
729 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
730
731 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
732 PatternRewriter &rewriter) const override {
733 RankedTensorType sourceType = insertSliceOp.getSourceType();
734 SmallVector<OpFoldResult> targetShape;
735 for (auto size : sourceType.getShape())
736 targetShape.push_back(rewriter.getIndexAttr(size));
737 auto reassociation = getReassociationMapForFoldingUnitDims(mixedSizes: targetShape);
738 if (!reassociation ||
739 reassociation->size() == static_cast<size_t>(sourceType.getRank()))
740 return failure();
741
742 Location loc = insertSliceOp.getLoc();
743 tensor::CollapseShapeOp reshapedSource;
744 {
745 OpBuilder::InsertionGuard g(rewriter);
746 // The only difference between InsertSliceOp and ParallelInsertSliceOp
747 // is the insertion point is just before the ParallelCombiningOp in the
748 // parallel case.
749 if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
750 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
751 reshapedSource = rewriter.create<tensor::CollapseShapeOp>(
752 loc, insertSliceOp.getSource(), *reassociation);
753 }
754 rewriter.replaceOpWithNewOp<InsertOpTy>(
755 insertSliceOp, reshapedSource, insertSliceOp.getDest(),
756 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
757 insertSliceOp.getMixedStrides());
758 return success();
759 }
760};
761} // namespace
762
763/// Patterns that are used to canonicalize the use of unit-extent dims for
764/// broadcasting.
765static void
766populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns,
767 ControlDropUnitDims &options) {
768 auto *context = patterns.getContext();
769 patterns.add<DropUnitDims>(arg&: context, args&: options);
770 patterns.add<DropPadUnitDims>(arg&: context, args&: options);
771 // TODO: Patterns unrelated to unit dim folding should be factored out.
772 patterns.add<RankReducedExtractSliceOp,
773 RankReducedInsertSliceOp<tensor::InsertSliceOp>,
774 RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
775 context);
776 linalg::FillOp::getCanonicalizationPatterns(patterns, context);
777 tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
778 tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
779 tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
780 tensor::populateFoldTensorEmptyPatterns(patterns);
781 memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
782 memref::populateResolveShapedTypeResultDimsPatterns(patterns);
783}
784
785static void
786populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns,
787 ControlDropUnitDims &options) {
788 auto *context = patterns.getContext();
789 patterns.add<DropUnitDims>(arg&: context, args&: options);
790 patterns.add<DropPadUnitDims>(arg&: context, args&: options);
791 // TODO: Patterns unrelated to unit dim folding should be factored out.
792 linalg::FillOp::getCanonicalizationPatterns(patterns, context);
793 tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
794 tensor::populateFoldTensorEmptyPatterns(patterns);
795 memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
796 memref::populateResolveShapedTypeResultDimsPatterns(patterns);
797}
798
799void mlir::linalg::populateFoldUnitExtentDimsPatterns(
800 RewritePatternSet &patterns, linalg::ControlDropUnitDims &options) {
801 if (options.rankReductionStrategy ==
802 linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
803 populateFoldUnitExtentDimsViaSlicesPatterns(patterns, options);
804 } else if (options.rankReductionStrategy ==
805 linalg::ControlDropUnitDims::RankReductionStrategy::
806 ReassociativeReshape) {
807 populateFoldUnitExtentDimsViaReshapesPatterns(patterns, options);
808 }
809}
810
811void mlir::linalg::populateMoveInitOperandsToInputPattern(
812 RewritePatternSet &patterns) {
813 patterns.add<MoveInitOperandsToInput>(arg: patterns.getContext());
814}
815
816namespace {
817/// Pass that removes unit-extent dims within generic ops.
818struct LinalgFoldUnitExtentDimsPass
819 : public impl::LinalgFoldUnitExtentDimsPassBase<
820 LinalgFoldUnitExtentDimsPass> {
821 using impl::LinalgFoldUnitExtentDimsPassBase<
822 LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase;
823 void runOnOperation() override {
824 Operation *op = getOperation();
825 MLIRContext *context = op->getContext();
826 RewritePatternSet patterns(context);
827 ControlDropUnitDims options;
828 if (useRankReducingSlices) {
829 options.rankReductionStrategy = linalg::ControlDropUnitDims::
830 RankReductionStrategy::ExtractInsertSlice;
831 }
832 linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
833 populateMoveInitOperandsToInputPattern(patterns);
834 (void)applyPatternsGreedily(op, std::move(patterns));
835 }
836};
837
838} // namespace
839
840namespace {
841
842/// Returns reassociation indices for collapsing/expanding a
843/// tensor of rank `rank` at position `pos`.
844static SmallVector<ReassociationIndices>
845getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
846 SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
847 bool lastDim = pos == rank - 1;
848 if (rank > 2) {
849 for (int64_t i = 0; i < rank - 1; i++) {
850 if (i == pos || (lastDim && i == pos - 1))
851 reassociation[i] = ReassociationIndices{i, i + 1};
852 else if (i < pos)
853 reassociation[i] = ReassociationIndices{i};
854 else
855 reassociation[i] = ReassociationIndices{i + 1};
856 }
857 }
858 return reassociation;
859}
860
861/// Returns a collapsed `val` where the collapsing occurs at dim `pos`.
862/// If `pos < 0`, then don't collapse.
863static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
864 int64_t pos) {
865 if (pos < 0)
866 return val;
867 auto valType = cast<ShapedType>(val.getType());
868 SmallVector<int64_t> collapsedShape(valType.getShape());
869 collapsedShape.erase(CI: collapsedShape.begin() + pos);
870 return collapseValue(
871 rewriter, val.getLoc(), val, collapsedShape,
872 getReassociationForReshapeAtDim(valType.getRank(), pos),
873 ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
874}
875
876/// Base class for all rank reduction patterns for contraction ops
877/// with unit dimensions. All patterns should convert one named op
878/// to another named op. Intended to reduce only one iteration space dim
879/// at a time.
880/// Reducing multiple dims will happen with recusive application of
881/// pattern rewrites.
882template <typename FromOpTy, typename ToOpTy>
883struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
884 using OpRewritePattern<FromOpTy>::OpRewritePattern;
885
886 /// Collapse all collapsable operands.
887 SmallVector<Value>
888 collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
889 ArrayRef<int64_t> operandCollapseDims) const {
890 assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
891 "expected 3 operands and dims");
892 return llvm::map_to_vector(
893 llvm::zip(t&: operands, u&: operandCollapseDims), [&](auto pair) {
894 return collapseSingletonDimAt(rewriter, std::get<0>(pair),
895 std::get<1>(pair));
896 });
897 }
898
899 /// Expand result tensor.
900 Value expandResult(PatternRewriter &rewriter, Value result,
901 RankedTensorType expandedType, int64_t dim) const {
902 return rewriter.create<tensor::ExpandShapeOp>(
903 result.getLoc(), expandedType, result,
904 getReassociationForReshapeAtDim(expandedType.getRank(), dim));
905 }
906
907 LogicalResult matchAndRewrite(FromOpTy contractionOp,
908 PatternRewriter &rewriter) const override {
909 if (contractionOp.hasUserDefinedMaps()) {
910 return rewriter.notifyMatchFailure(
911 contractionOp, "ops with user-defined maps are not supported");
912 }
913
914 auto loc = contractionOp.getLoc();
915 auto inputs = contractionOp.getDpsInputs();
916 auto inits = contractionOp.getDpsInits();
917 if (inputs.size() != 2 || inits.size() != 1)
918 return rewriter.notifyMatchFailure(contractionOp,
919 "expected 2 inputs and 1 init");
920 auto lhs = inputs[0];
921 auto rhs = inputs[1];
922 auto init = inits[0];
923 SmallVector<Value> operands{lhs, rhs, init};
924
925 SmallVector<int64_t> operandUnitDims;
926 if (failed(getOperandUnitDims(op: contractionOp, operandUnitDims)))
927 return rewriter.notifyMatchFailure(contractionOp,
928 "no reducable dims found");
929
930 SmallVector<Value> collapsedOperands =
931 collapseOperands(rewriter, operands, operandCollapseDims: operandUnitDims);
932 Value collapsedLhs = collapsedOperands[0];
933 Value collapsedRhs = collapsedOperands[1];
934 Value collapsedInit = collapsedOperands[2];
935 SmallVector<Type, 1> collapsedResultTy;
936 if (isa<RankedTensorType>(Val: collapsedInit.getType()))
937 collapsedResultTy.push_back(Elt: collapsedInit.getType());
938 auto collapsedOp = rewriter.create<ToOpTy>(
939 loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
940 ValueRange{collapsedInit});
941 for (auto attr : contractionOp->getAttrs()) {
942 if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName ||
943 attr.getName() == "indexing_maps")
944 continue;
945 collapsedOp->setAttr(attr.getName(), attr.getValue());
946 }
947
948 auto results = contractionOp.getResults();
949 assert(results.size() < 2 && "expected at most one result");
950 if (results.empty()) {
951 rewriter.replaceOp(contractionOp, collapsedOp);
952 } else {
953 rewriter.replaceOp(
954 contractionOp,
955 expandResult(rewriter, result: collapsedOp.getResultTensors()[0],
956 expandedType: cast<RankedTensorType>(results[0].getType()),
957 dim: operandUnitDims[2]));
958 }
959
960 return success();
961 }
962
963 /// Populate `operandUnitDims` with 3 indices indicating the unit dim
964 /// for each operand that should be collapsed in this pattern. If an
965 /// operand shouldn't be collapsed, the index should be negative.
966 virtual LogicalResult
967 getOperandUnitDims(LinalgOp op,
968 SmallVectorImpl<int64_t> &operandUnitDims) const = 0;
969};
970
971/// Patterns for unbatching batched contraction ops
972template <typename FromOpTy, typename ToOpTy>
973struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
974 using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
975
976 /// Look for unit batch dims to collapse.
977 LogicalResult
978 getOperandUnitDims(LinalgOp op,
979 SmallVectorImpl<int64_t> &operandUnitDims) const override {
980 FailureOr<ContractionDimensions> maybeContractionDims =
981 inferContractionDims(op);
982 if (failed(Result: maybeContractionDims)) {
983 LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
984 return failure();
985 }
986 ContractionDimensions contractionDims = maybeContractionDims.value();
987
988 if (contractionDims.batch.size() != 1)
989 return failure();
990 auto batchDim = contractionDims.batch[0];
991 SmallVector<std::pair<Value, unsigned>, 3> bOperands;
992 op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
993 if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) {
994 return cast<ShapedType>(std::get<0>(pair).getType())
995 .getShape()[std::get<1>(pair)] != 1;
996 })) {
997 LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
998 return failure();
999 }
1000
1001 operandUnitDims = SmallVector<int64_t>{std::get<1>(in&: bOperands[0]),
1002 std::get<1>(in&: bOperands[1]),
1003 std::get<1>(in&: bOperands[2])};
1004 return success();
1005 }
1006};
1007
1008/// Patterns for reducing non-batch dimensions
1009template <typename FromOpTy, typename ToOpTy>
1010struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
1011 using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
1012
1013 /// Helper for determining whether the lhs/init or rhs/init are reduced.
1014 static bool constexpr reduceLeft =
1015 (std::is_same_v<FromOpTy, BatchMatmulOp> &&
1016 std::is_same_v<ToOpTy, BatchVecmatOp>) ||
1017 (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
1018 std::is_same_v<ToOpTy, BatchVecmatOp>) ||
1019 (std::is_same_v<FromOpTy, MatmulOp> &&
1020 std::is_same_v<ToOpTy, VecmatOp>) ||
1021 (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
1022 std::is_same_v<ToOpTy, VecmatOp>) ||
1023 (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
1024
1025 /// Look for non-batch spatial dims to collapse.
1026 LogicalResult
1027 getOperandUnitDims(LinalgOp op,
1028 SmallVectorImpl<int64_t> &operandUnitDims) const override {
1029 FailureOr<ContractionDimensions> maybeContractionDims =
1030 inferContractionDims(op);
1031 if (failed(Result: maybeContractionDims)) {
1032 LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
1033 return failure();
1034 }
1035 ContractionDimensions contractionDims = maybeContractionDims.value();
1036
1037 if constexpr (reduceLeft) {
1038 auto m = contractionDims.m[0];
1039 SmallVector<std::pair<Value, unsigned>, 2> mOperands;
1040 op.mapIterationSpaceDimToAllOperandDims(m, mOperands);
1041 if (mOperands.size() != 2)
1042 return failure();
1043 if (llvm::all_of(mOperands, [](auto pair) {
1044 return cast<ShapedType>(std::get<0>(pair).getType())
1045 .getShape()[std::get<1>(pair)] == 1;
1046 })) {
1047 operandUnitDims = SmallVector<int64_t>{std::get<1>(in&: mOperands[0]), -1,
1048 std::get<1>(in&: mOperands[1])};
1049 return success();
1050 }
1051 } else {
1052 auto n = contractionDims.n[0];
1053 SmallVector<std::pair<Value, unsigned>, 2> nOperands;
1054 op.mapIterationSpaceDimToAllOperandDims(n, nOperands);
1055 if (nOperands.size() != 2)
1056 return failure();
1057 if (llvm::all_of(nOperands, [](auto pair) {
1058 return cast<ShapedType>(std::get<0>(pair).getType())
1059 .getShape()[std::get<1>(pair)] == 1;
1060 })) {
1061 operandUnitDims = SmallVector<int64_t>{-1, std::get<1>(in&: nOperands[0]),
1062 std::get<1>(in&: nOperands[1])};
1063 return success();
1064 }
1065 }
1066 LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
1067 return failure();
1068 }
1069};
1070
1071} // namespace
1072
1073void mlir::linalg::populateContractionOpRankReducingPatterns(
1074 RewritePatternSet &patterns) {
1075 MLIRContext *context = patterns.getContext();
1076 // Unbatching patterns for unit batch size
1077 patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
1078 patterns
1079 .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
1080 context);
1081 patterns
1082 .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
1083 context);
1084 patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
1085 patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
1086
1087 // Non-batch rank 1 reducing patterns
1088 patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
1089 patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
1090 patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
1091 patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
1092 // Batch rank 1 reducing patterns
1093 patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
1094 patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
1095 patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
1096 context);
1097 patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
1098 context);
1099
1100 // Non-batch rank 0 reducing patterns
1101 patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
1102 patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
1103}
1104

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp