1//===- TestMatchReduction.cpp - Test the match reduction utility ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains a test pass for the match reduction utility.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Analysis/SliceAnalysis.h"
14#include "mlir/Interfaces/FunctionInterfaces.h"
15#include "mlir/Pass/Pass.h"
16
17using namespace mlir;
18
19namespace {
20
21void printReductionResult(Operation *redRegionOp, unsigned numOutput,
22 Value reducedValue,
23 ArrayRef<Operation *> combinerOps) {
24 if (reducedValue) {
25 redRegionOp->emitRemark(message: "Reduction found in output #") << numOutput << "!";
26 redRegionOp->emitRemark(message: "Reduced Value: ") << reducedValue;
27 for (Operation *combOp : combinerOps)
28 redRegionOp->emitRemark(message: "Combiner Op: ") << *combOp;
29
30 return;
31 }
32
33 redRegionOp->emitRemark(message: "Reduction NOT found in output #")
34 << numOutput << "!";
35}
36
37struct TestMatchReductionPass
38 : public PassWrapper<TestMatchReductionPass,
39 InterfacePass<FunctionOpInterface>> {
40 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMatchReductionPass)
41
42 StringRef getArgument() const final { return "test-match-reduction"; }
43 StringRef getDescription() const final {
44 return "Test the match reduction utility.";
45 }
46
47 void runOnOperation() override {
48 FunctionOpInterface func = getOperation();
49 func->emitRemark("Testing function");
50
51 func.walk<WalkOrder::PreOrder>([](Operation *op) {
52 if (isa<FunctionOpInterface>(Val: op))
53 return;
54
55 // Limit testing to ops with only one region.
56 if (op->getNumRegions() != 1)
57 return;
58
59 Region &region = op->getRegion(index: 0);
60 if (!region.hasOneBlock())
61 return;
62
63 // We expect all the tested region ops to have 1 input by default. The
64 // remaining arguments are assumed to be outputs/reductions and there must
65 // be at least one.
66 // TODO: Extend it to support more generic cases.
67 Block &regionEntry = region.front();
68 auto args = regionEntry.getArguments();
69 if (args.size() < 2)
70 return;
71
72 auto outputs = args.drop_front();
73 for (int i = 0, size = outputs.size(); i < size; ++i) {
74 SmallVector<Operation *, 4> combinerOps;
75 Value reducedValue = matchReduction(iterCarriedArgs: outputs, redPos: i, combinerOps);
76 printReductionResult(redRegionOp: op, numOutput: i, reducedValue, combinerOps);
77 }
78 });
79 }
80};
81
82} // namespace
83
84namespace mlir {
85namespace test {
86void registerTestMatchReductionPass() {
87 PassRegistration<TestMatchReductionPass>();
88}
89} // namespace test
90} // namespace mlir
91

source code of mlir/test/lib/Analysis/TestMatchReduction.cpp