//===- TestGpuRewrite.cpp - Test GPU dialect rewrite patterns -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file contains test passes for the GPU dialect rewrite patterns,
// including the gpu.all_reduce and gpu.subgroup_reduce lowerings.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/Func/IR/FuncOps.h"
15#include "mlir/Dialect/GPU/Transforms/Passes.h"
16#include "mlir/Dialect/Index/IR/IndexDialect.h"
17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Dialect/Vector/IR/VectorOps.h"
19#include "mlir/IR/PatternMatch.h"
20#include "mlir/Pass/Pass.h"
21#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22
23using namespace mlir;
24
25namespace {
26struct TestGpuRewritePass
27 : public PassWrapper<TestGpuRewritePass, OperationPass<ModuleOp>> {
28 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGpuRewritePass)
29
30 void getDependentDialects(DialectRegistry &registry) const override {
31 registry.insert<arith::ArithDialect, func::FuncDialect, index::IndexDialect,
32 memref::MemRefDialect>();
33 }
34 StringRef getArgument() const final { return "test-gpu-rewrite"; }
35 StringRef getDescription() const final {
36 return "Applies all rewrite patterns within the GPU dialect.";
37 }
38 void runOnOperation() override {
39 RewritePatternSet patterns(&getContext());
40 populateGpuRewritePatterns(patterns);
41 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
42 }
43};
44
45struct TestGpuSubgroupReduceLoweringPass
46 : public PassWrapper<TestGpuSubgroupReduceLoweringPass,
47 OperationPass<ModuleOp>> {
48 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
49 TestGpuSubgroupReduceLoweringPass)
50
51 TestGpuSubgroupReduceLoweringPass() = default;
52 TestGpuSubgroupReduceLoweringPass(
53 const TestGpuSubgroupReduceLoweringPass &pass)
54 : PassWrapper(pass) {}
55
56 void getDependentDialects(DialectRegistry &registry) const override {
57 registry.insert<arith::ArithDialect, vector::VectorDialect>();
58 }
59
60 StringRef getArgument() const final {
61 return "test-gpu-subgroup-reduce-lowering";
62 }
63
64 StringRef getDescription() const final {
65 return "Applies gpu.subgroup_reduce lowering patterns.";
66 }
67
68 Option<bool> expandToShuffles{
69 *this, "expand-to-shuffles",
70 llvm::cl::desc("Expand subgroup_reduce ops to shuffle ops."),
71 llvm::cl::init(Val: false)};
72
73 void runOnOperation() override {
74 RewritePatternSet patterns(&getContext());
75
76 // Since both pattern sets match on the same ops, set higher benefit to
77 // perform fewer failing matches.
78 populateGpuBreakDownSubgrupReducePatterns(patterns,
79 /*maxShuffleBitwidth=*/32,
80 benefit: PatternBenefit(2));
81 if (expandToShuffles)
82 populateGpuLowerSubgroupReduceToShufflePattenrs(
83 patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
84
85 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
86 }
87};
88} // namespace
89
90namespace mlir {
91void registerTestGpuLoweringPasses() {
92 PassRegistration<TestGpuRewritePass>();
93 PassRegistration<TestGpuSubgroupReduceLoweringPass>();
94}
95} // namespace mlir
96

source code of mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp