1//===- TestLinalgDropUnitDims.cpp - Test Linalg drop unit dims -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass for testing the transformation to drop unit
10// extent dimensions from `linalg.generic` operations.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Func/IR/FuncOps.h"
15#include "mlir/Dialect/Linalg/IR/Linalg.h"
16#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17#include "mlir/Pass/Pass.h"
18#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19
20using namespace mlir;
21
22namespace {
23
24LogicalResult dropOutermostUnitDims(RewriterBase &rewriter,
25 linalg::GenericOp genericOp) {
26 linalg::ControlDropUnitDims options;
27 options.controlFn = [](Operation *op) { return SmallVector<unsigned>{0}; };
28 return linalg::dropUnitDims(rewriter, genericOp, options);
29}
30
31struct TestLinalgDropUnitDims
32 : public PassWrapper<TestLinalgDropUnitDims, OperationPass<func::FuncOp>> {
33
34 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgDropUnitDims)
35
36 TestLinalgDropUnitDims() = default;
37 TestLinalgDropUnitDims(const TestLinalgDropUnitDims &pass) = default;
38
39 void getDependentDialects(DialectRegistry &registry) const override {
40 registry.insert<linalg::LinalgDialect>();
41 }
42
43 StringRef getArgument() const final { return "test-linalg-drop-unit-dims"; }
44
45 StringRef getDescriptions() const {
46 return "Test transformation to drop unit-extent dims from Linalg "
47 "operations";
48 }
49
50 void runOnOperation() override {
51 MLIRContext *context = &this->getContext();
52 func::FuncOp funcOp = this->getOperation();
53 IRRewriter rewriter(context);
54 SmallVector<linalg::GenericOp> genericOps;
55 funcOp.walk(
56 [&](linalg::GenericOp genericOp) { genericOps.push_back(genericOp); });
57
58 for (auto genericOp : genericOps) {
59 rewriter.setInsertionPoint(genericOp);
60 (void)dropOutermostUnitDims(rewriter, genericOp);
61 }
62 }
63};
64} // namespace
65
66namespace mlir {
67namespace test {
68void registerTestLinalgDropUnitDims() {
69 PassRegistration<TestLinalgDropUnitDims>();
70}
71} // namespace test
72} // namespace mlir
73

source code of mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp