//===- TestExpandMath.cpp - Test expanding math operations ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains test passes for expanding math operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/Math/Transforms/Passes.h"
15#include "mlir/Dialect/SCF/IR/SCF.h"
16#include "mlir/Dialect/Vector/IR/VectorOps.h"
17#include "mlir/Pass/Pass.h"
18#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19
20using namespace mlir;
21
22namespace {
23struct TestExpandMathPass
24 : public PassWrapper<TestExpandMathPass, OperationPass<>> {
25 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestExpandMathPass)
26
27 void runOnOperation() override;
28 StringRef getArgument() const final { return "test-expand-math"; }
29 void getDependentDialects(DialectRegistry &registry) const override {
30 registry
31 .insert<arith::ArithDialect, scf::SCFDialect, vector::VectorDialect>();
32 }
33 StringRef getDescription() const final { return "Test expanding math"; }
34};
35} // namespace
36
37void TestExpandMathPass::runOnOperation() {
38 RewritePatternSet patterns(&getContext());
39 populateExpandCtlzPattern(patterns);
40 populateExpandExp2FPattern(patterns);
41 populateExpandTanPattern(patterns);
42 populateExpandSinhPattern(patterns);
43 populateExpandCoshPattern(patterns);
44 populateExpandTanhPattern(patterns);
45 populateExpandFmaFPattern(patterns);
46 populateExpandFloorFPattern(patterns);
47 populateExpandCeilFPattern(patterns);
48 populateExpandPowFPattern(patterns);
49 populateExpandFPowIPattern(patterns);
50 populateExpandRoundFPattern(patterns);
51 populateExpandRoundEvenPattern(patterns);
52 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
53}
54
55namespace mlir {
56namespace test {
57void registerTestExpandMathPass() { PassRegistration<TestExpandMathPass>(); }
58} // namespace test
59} // namespace mlir
60

source code of mlir/test/lib/Dialect/Math/TestExpandMath.cpp