1//===- TestVectorReductionToSPIRVDotProd.cpp - Test reduction to dot prod -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
10#include "mlir/Dialect/Arith/IR/Arith.h"
11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
13#include "mlir/Dialect/Vector/IR/VectorOps.h"
14#include "mlir/Pass/Pass.h"
15#include "mlir/Pass/PassManager.h"
16#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17
18namespace mlir {
19namespace {
20
21struct TestVectorReductionToSPIRVDotProd
22 : PassWrapper<TestVectorReductionToSPIRVDotProd,
23 OperationPass<func::FuncOp>> {
24 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
25 TestVectorReductionToSPIRVDotProd)
26
27 StringRef getArgument() const final {
28 return "test-vector-reduction-to-spirv-dot-prod";
29 }
30
31 StringRef getDescription() const final {
32 return "Test lowering patterns that converts vector.reduction to SPIR-V "
33 "integer dot product ops";
34 }
35
36 void getDependentDialects(DialectRegistry &registry) const override {
37 registry.insert<arith::ArithDialect, func::FuncDialect, spirv::SPIRVDialect,
38 vector::VectorDialect>();
39 }
40
41 void runOnOperation() override {
42 RewritePatternSet patterns(&getContext());
43 populateVectorReductionToSPIRVDotProductPatterns(patterns);
44 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
45 }
46};
47
48} // namespace
49
50namespace test {
51void registerTestVectorReductionToSPIRVDotProd() {
52 PassRegistration<TestVectorReductionToSPIRVDotProd>();
53}
54} // namespace test
55} // namespace mlir
56

// Source: mlir/test/lib/Conversion/VectorToSPIRV/TestVectorReductionToSPIRVDotProd.cpp