//===- TestConvertToSPIRVPass.cpp - Test SPIR-V conversion pass -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
10#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
11#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
12#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
13#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
14#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
15#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
16#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
17#include "mlir/Dialect/Arith/Transforms/Passes.h"
18#include "mlir/Dialect/GPU/IR/GPUDialect.h"
19#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
20#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
21#include "mlir/Dialect/Vector/IR/VectorOps.h"
22#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
23#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
24#include "mlir/IR/PatternMatch.h"
25#include "mlir/Pass/Pass.h"
26#include "mlir/Pass/PassOptions.h"
27#include "mlir/Rewrite/FrozenRewritePatternSet.h"
28#include "mlir/Transforms/DialectConversion.h"
29#include <memory>
30
31#define DEBUG_TYPE "test-convert-to-spirv"
32
33using namespace mlir;
34
35namespace {
36
37/// Map memRef memory space to SPIR-V storage class.
38void mapToMemRef(Operation *op, spirv::TargetEnvAttr &targetAttr) {
39 spirv::TargetEnv targetEnv(targetAttr);
40 bool targetEnvSupportsKernelCapability =
41 targetEnv.allows(spirv::Capability::Kernel);
42 spirv::MemorySpaceToStorageClassMap memorySpaceMap =
43 targetEnvSupportsKernelCapability
44 ? spirv::mapMemorySpaceToOpenCLStorageClass
45 : spirv::mapMemorySpaceToVulkanStorageClass;
46 spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
47 spirv::convertMemRefTypesAndAttrs(op, typeConverter&: converter);
48}
49
50/// Populate patterns for each dialect.
51void populateConvertToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
52 ScfToSPIRVContext &scfToSPIRVContext,
53 RewritePatternSet &patterns) {
54 arith::populateCeilFloorDivExpandOpsPatterns(patterns);
55 arith::populateArithToSPIRVPatterns(typeConverter, patterns);
56 populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
57 populateFuncToSPIRVPatterns(typeConverter, patterns);
58 populateGPUToSPIRVPatterns(typeConverter, patterns);
59 index::populateIndexToSPIRVPatterns(converter: typeConverter, patterns);
60 populateMemRefToSPIRVPatterns(typeConverter, patterns);
61 populateVectorToSPIRVPatterns(typeConverter, patterns);
62 populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
63 ub::populateUBToSPIRVConversionPatterns(converter: typeConverter, patterns);
64}
65
66/// A pass to perform the SPIR-V conversion.
67struct TestConvertToSPIRVPass final
68 : PassWrapper<TestConvertToSPIRVPass, OperationPass<>> {
69 Option<bool> runSignatureConversion{
70 *this, "run-signature-conversion",
71 llvm::cl::desc(
72 "Run function signature conversion to convert vector types"),
73 llvm::cl::init(Val: true)};
74 Option<bool> runVectorUnrolling{
75 *this, "run-vector-unrolling",
76 llvm::cl::desc(
77 "Run vector unrolling to convert vector types in function bodies"),
78 llvm::cl::init(Val: true)};
79 Option<bool> convertGPUModules{
80 *this, "convert-gpu-modules",
81 llvm::cl::desc("Clone and convert GPU modules"), llvm::cl::init(Val: false)};
82 Option<bool> nestInGPUModule{
83 *this, "nest-in-gpu-module",
84 llvm::cl::desc("Put converted SPIR-V module inside the gpu.module "
85 "instead of alongside it."),
86 llvm::cl::init(Val: false)};
87
88 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertToSPIRVPass)
89
90 StringRef getArgument() const final { return "test-convert-to-spirv"; }
91 StringRef getDescription() const final {
92 return "Conversion to SPIR-V pass only used for internal tests.";
93 }
94 void getDependentDialects(DialectRegistry &registry) const override {
95 registry.insert<spirv::SPIRVDialect>();
96 registry.insert<vector::VectorDialect>();
97 }
98
99 TestConvertToSPIRVPass() = default;
100 TestConvertToSPIRVPass(bool convertGPUModules, bool nestInGPUModule) {
101 this->convertGPUModules = convertGPUModules;
102 this->nestInGPUModule = nestInGPUModule;
103 };
104 TestConvertToSPIRVPass(const TestConvertToSPIRVPass &) {}
105
106 void runOnOperation() override {
107 Operation *op = getOperation();
108 MLIRContext *context = &getContext();
109
110 // Unroll vectors in function signatures to native size.
111 if (runSignatureConversion && failed(Result: spirv::unrollVectorsInSignatures(op)))
112 return signalPassFailure();
113
114 // Unroll vectors in function bodies to native size.
115 if (runVectorUnrolling && failed(Result: spirv::unrollVectorsInFuncBodies(op)))
116 return signalPassFailure();
117
118 // Generic conversion.
119 if (!convertGPUModules) {
120 spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
121 std::unique_ptr<ConversionTarget> target =
122 SPIRVConversionTarget::get(targetAttr);
123 SPIRVTypeConverter typeConverter(targetAttr);
124 RewritePatternSet patterns(context);
125 ScfToSPIRVContext scfToSPIRVContext;
126 mapToMemRef(op, targetAttr);
127 populateConvertToSPIRVPatterns(typeConverter, scfToSPIRVContext,
128 patterns);
129 if (failed(applyPartialConversion(op, *target, std::move(patterns))))
130 return signalPassFailure();
131 return;
132 }
133
134 // Clone each GPU kernel module for conversion, given that the GPU
135 // launch op still needs the original GPU kernel module.
136 SmallVector<Operation *, 1> gpuModules;
137 OpBuilder builder(context);
138 op->walk(callback: [&](gpu::GPUModuleOp gpuModule) {
139 if (nestInGPUModule)
140 builder.setInsertionPointToStart(gpuModule.getBody());
141 else
142 builder.setInsertionPoint(gpuModule);
143 gpuModules.push_back(Elt: builder.clone(*gpuModule));
144 });
145 // Run conversion for each module independently as they can have
146 // different TargetEnv attributes.
147 for (Operation *gpuModule : gpuModules) {
148 spirv::TargetEnvAttr targetAttr =
149 spirv::lookupTargetEnvOrDefault(op: gpuModule);
150 std::unique_ptr<ConversionTarget> target =
151 SPIRVConversionTarget::get(targetAttr);
152 SPIRVTypeConverter typeConverter(targetAttr);
153 RewritePatternSet patterns(context);
154 ScfToSPIRVContext scfToSPIRVContext;
155 mapToMemRef(op: gpuModule, targetAttr);
156 populateConvertToSPIRVPatterns(typeConverter, scfToSPIRVContext,
157 patterns);
158 if (failed(applyFullConversion(gpuModule, *target, std::move(patterns))))
159 return signalPassFailure();
160 }
161 }
162};
163
164} // namespace
165
166namespace mlir::test {
167void registerTestConvertToSPIRVPass() {
168 PassRegistration<TestConvertToSPIRVPass>();
169}
170std::unique_ptr<Pass> createTestConvertToSPIRVPass(bool convertGPUModules,
171 bool nestInGPUModule) {
172 return std::make_unique<TestConvertToSPIRVPass>(args&: convertGPUModules,
173 args&: nestInGPUModule);
174}
175} // namespace mlir::test
176

source code of mlir/test/lib/Pass/TestConvertToSPIRVPass.cpp