1//===- TestLinalgDropUnitDims.cpp - Test Linalg drop unit dims -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass for testing the transformation to drop unit
10// extent dimensions from `linalg.generic` operations.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Func/IR/FuncOps.h"
15#include "mlir/Dialect/Linalg/IR/Linalg.h"
16#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17#include "mlir/Pass/Pass.h"
18#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19
20using namespace mlir;
21
22namespace {
23
24LogicalResult dropOutermostUnitDims(RewriterBase &rewriter,
25 linalg::GenericOp genericOp) {
26 linalg::ControlDropUnitDims options;
27 options.controlFn = [](Operation *op) { return SmallVector<unsigned>{0}; };
28 FailureOr<linalg::DropUnitDimsResult> result =
29 linalg::dropUnitDims(rewriter, genericOp, options);
30 if (failed(result)) {
31 return failure();
32 }
33 rewriter.replaceOp(genericOp, result->replacements);
34 return success();
35}
36
37struct TestLinalgDropUnitDims
38 : public PassWrapper<TestLinalgDropUnitDims, OperationPass<func::FuncOp>> {
39
40 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgDropUnitDims)
41
42 TestLinalgDropUnitDims() = default;
43 TestLinalgDropUnitDims(const TestLinalgDropUnitDims &pass) = default;
44
45 void getDependentDialects(DialectRegistry &registry) const override {
46 registry.insert<linalg::LinalgDialect>();
47 }
48
49 StringRef getArgument() const final { return "test-linalg-drop-unit-dims"; }
50
51 StringRef getDescriptions() const {
52 return "Test transformation to drop unit-extent dims from Linalg "
53 "operations";
54 }
55
56 void runOnOperation() override {
57 MLIRContext *context = &this->getContext();
58 func::FuncOp funcOp = this->getOperation();
59 IRRewriter rewriter(context);
60 SmallVector<linalg::GenericOp> genericOps;
61 funcOp.walk(
62 [&](linalg::GenericOp genericOp) { genericOps.push_back(genericOp); });
63
64 for (auto genericOp : genericOps) {
65 rewriter.setInsertionPoint(genericOp);
66 (void)dropOutermostUnitDims(rewriter, genericOp);
67 }
68 }
69};
70} // namespace
71
72namespace mlir {
73namespace test {
74void registerTestLinalgDropUnitDims() {
75 PassRegistration<TestLinalgDropUnitDims>();
76}
77} // namespace test
78} // namespace mlir
79

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp