1//===- TestLinalgRankReduceContractionOps.cpp -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements a pass for testing the rank-reducing patterns for named
// contraction ops with unit dims.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Affine/IR/AffineOps.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17#include "mlir/Pass/Pass.h"
18#include "mlir/Pass/PassManager.h"
19#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20
21using namespace mlir;
22
23namespace {
24
25struct TestLinalgRankReduceContractionOps
26 : public PassWrapper<TestLinalgRankReduceContractionOps,
27 OperationPass<func::FuncOp>> {
28 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
29 TestLinalgRankReduceContractionOps)
30
31 TestLinalgRankReduceContractionOps() = default;
32 TestLinalgRankReduceContractionOps(
33 const TestLinalgRankReduceContractionOps &pass)
34 : PassWrapper(pass) {}
35 void getDependentDialects(DialectRegistry &registry) const override {
36 registry.insert<affine::AffineDialect, linalg::LinalgDialect,
37 memref::MemRefDialect, tensor::TensorDialect>();
38 }
39 StringRef getArgument() const final {
40 return "test-linalg-rank-reduce-contraction-ops";
41 }
42 StringRef getDescription() const final {
43 return "Test Linalg rank reduce contraction ops with unit dims";
44 }
45
46 void runOnOperation() override {
47 MLIRContext *context = &this->getContext();
48 func::FuncOp funcOp = this->getOperation();
49
50 RewritePatternSet patterns(context);
51 linalg::populateContractionOpRankReducingPatterns(patterns);
52 if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
53 return signalPassFailure();
54 }
55};
56
57} // namespace
58
59namespace mlir {
60namespace test {
61void registerTestLinalgRankReduceContractionOps() {
62 PassRegistration<TestLinalgRankReduceContractionOps>();
63}
64} // namespace test
65} // namespace mlir
66

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp