//===- MapMemRefStorageClassPass.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to map numeric MemRef memory spaces to
10// symbolic ones defined in the SPIR-V specification.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
15
16#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
19#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
20#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
21#include "mlir/IR/Attributes.h"
22#include "mlir/IR/BuiltinAttributes.h"
23#include "mlir/IR/BuiltinTypes.h"
24#include "mlir/IR/Operation.h"
25#include "mlir/IR/Visitors.h"
26#include "mlir/Interfaces/FunctionInterfaces.h"
27#include "llvm/ADT/SmallVectorExtras.h"
28#include "llvm/ADT/StringExtras.h"
29#include "llvm/Support/Debug.h"
30#include <optional>
31
32namespace mlir {
33#define GEN_PASS_DEF_MAPMEMREFSTORAGECLASS
34#include "mlir/Conversion/Passes.h.inc"
35} // namespace mlir
36
37#define DEBUG_TYPE "mlir-map-memref-storage-class"
38
39using namespace mlir;
40
41//===----------------------------------------------------------------------===//
42// Mappings
43//===----------------------------------------------------------------------===//
44
/// X-macro table relating SPIR-V storage classes and memref memory spaces for
/// the Vulkan client API. Each entry has the form
/// `MAP_FN(<storage class>, <memory space>)`.
///
/// Note: memref does not define semantics for individual memory spaces; their
/// meaning depends on the context where they are used. The numbers are not
/// assigned for any particular reason: we try to follow NVVM conventions and
/// largely give common storage classes smaller numbers.
#define VULKAN_STORAGE_SPACE_MAP_LIST(MAP_FN)                                  \
  MAP_FN(spirv::StorageClass::StorageBuffer, 0)                                \
  MAP_FN(spirv::StorageClass::Generic, 1)                                      \
  MAP_FN(spirv::StorageClass::Workgroup, 3)                                    \
  MAP_FN(spirv::StorageClass::Uniform, 4)                                      \
  MAP_FN(spirv::StorageClass::Private, 5)                                      \
  MAP_FN(spirv::StorageClass::Function, 6)                                     \
  MAP_FN(spirv::StorageClass::PushConstant, 7)                                 \
  MAP_FN(spirv::StorageClass::UniformConstant, 8)                              \
  MAP_FN(spirv::StorageClass::Input, 9)                                        \
  MAP_FN(spirv::StorageClass::Output, 10)                                      \
  MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 11)
63
64std::optional<spirv::StorageClass>
65spirv::mapMemorySpaceToVulkanStorageClass(Attribute memorySpaceAttr) {
66 // Handle null memory space attribute specially.
67 if (!memorySpaceAttr)
68 return spirv::StorageClass::StorageBuffer;
69
70 // Unknown dialect custom attributes are not supported by default.
71 // Downstream callers should plug in more specialized ones.
72 auto intAttr = dyn_cast<IntegerAttr>(memorySpaceAttr);
73 if (!intAttr)
74 return std::nullopt;
75 unsigned memorySpace = intAttr.getInt();
76
77#define STORAGE_SPACE_MAP_FN(storage, space) \
78 case space: \
79 return storage;
80
81 switch (memorySpace) {
82 VULKAN_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
83 default:
84 break;
85 }
86 return std::nullopt;
87
88#undef STORAGE_SPACE_MAP_FN
89}
90
91std::optional<unsigned>
92spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass) {
93#define STORAGE_SPACE_MAP_FN(storage, space) \
94 case storage: \
95 return space;
96
97 switch (storageClass) {
98 VULKAN_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
99 default:
100 break;
101 }
102 return std::nullopt;
103
104#undef STORAGE_SPACE_MAP_FN
105}
106
107#undef VULKAN_STORAGE_SPACE_MAP_LIST
108
/// X-macro table relating SPIR-V storage classes and memref memory spaces for
/// the OpenCL client API. Each entry has the form
/// `MAP_FN(<storage class>, <memory space>)`.
#define OPENCL_STORAGE_SPACE_MAP_LIST(MAP_FN)                                  \
  MAP_FN(spirv::StorageClass::CrossWorkgroup, 0)                               \
  MAP_FN(spirv::StorageClass::Generic, 1)                                      \
  MAP_FN(spirv::StorageClass::Workgroup, 3)                                    \
  MAP_FN(spirv::StorageClass::UniformConstant, 4)                              \
  MAP_FN(spirv::StorageClass::Private, 5)                                      \
  MAP_FN(spirv::StorageClass::Function, 6)                                     \
  MAP_FN(spirv::StorageClass::Image, 7)
117
118std::optional<spirv::StorageClass>
119spirv::mapMemorySpaceToOpenCLStorageClass(Attribute memorySpaceAttr) {
120 // Handle null memory space attribute specially.
121 if (!memorySpaceAttr)
122 return spirv::StorageClass::CrossWorkgroup;
123
124 // Unknown dialect custom attributes are not supported by default.
125 // Downstream callers should plug in more specialized ones.
126 auto intAttr = dyn_cast<IntegerAttr>(memorySpaceAttr);
127 if (!intAttr)
128 return std::nullopt;
129 unsigned memorySpace = intAttr.getInt();
130
131#define STORAGE_SPACE_MAP_FN(storage, space) \
132 case space: \
133 return storage;
134
135 switch (memorySpace) {
136 OPENCL_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
137 default:
138 break;
139 }
140 return std::nullopt;
141
142#undef STORAGE_SPACE_MAP_FN
143}
144
145std::optional<unsigned>
146spirv::mapOpenCLStorageClassToMemorySpace(spirv::StorageClass storageClass) {
147#define STORAGE_SPACE_MAP_FN(storage, space) \
148 case storage: \
149 return space;
150
151 switch (storageClass) {
152 OPENCL_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
153 default:
154 break;
155 }
156 return std::nullopt;
157
158#undef STORAGE_SPACE_MAP_FN
159}
160
161#undef OPENCL_STORAGE_SPACE_MAP_LIST
162
163//===----------------------------------------------------------------------===//
164// Type Converter
165//===----------------------------------------------------------------------===//
166
167spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
168 const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
169 : memorySpaceMap(memorySpaceMap) {
170 // Pass through for all other types.
171 addConversion([](Type type) { return type; });
172
173 addConversion([this](BaseMemRefType memRefType) -> std::optional<Type> {
174 std::optional<spirv::StorageClass> storage =
175 this->memorySpaceMap(memRefType.getMemorySpace());
176 if (!storage) {
177 LLVM_DEBUG(llvm::dbgs()
178 << "cannot convert " << memRefType
179 << " due to being unable to find memory space in map\n");
180 return std::nullopt;
181 }
182
183 auto storageAttr =
184 spirv::StorageClassAttr::get(memRefType.getContext(), *storage);
185 if (auto rankedType = dyn_cast<MemRefType>(memRefType)) {
186 return MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
187 rankedType.getLayout(), storageAttr);
188 }
189 return UnrankedMemRefType::get(memRefType.getElementType(), storageAttr);
190 });
191
192 addConversion([this](FunctionType type) {
193 auto inputs = llvm::map_to_vector(
194 type.getInputs(), [this](Type ty) { return convertType(ty); });
195 auto results = llvm::map_to_vector(
196 type.getResults(), [this](Type ty) { return convertType(ty); });
197 return FunctionType::get(type.getContext(), inputs, results);
198 });
199}
200
201//===----------------------------------------------------------------------===//
202// Conversion Target
203//===----------------------------------------------------------------------===//
204
205/// Returns true if the given `type` is considered as legal for SPIR-V
206/// conversion.
207static bool isLegalType(Type type) {
208 if (auto memRefType = dyn_cast<BaseMemRefType>(Val&: type)) {
209 Attribute spaceAttr = memRefType.getMemorySpace();
210 return isa_and_nonnull<spirv::StorageClassAttr>(spaceAttr);
211 }
212 return true;
213}
214
215/// Returns true if the given `attr` is considered as legal for SPIR-V
216/// conversion.
217static bool isLegalAttr(Attribute attr) {
218 if (auto typeAttr = dyn_cast<TypeAttr>(attr))
219 return isLegalType(typeAttr.getValue());
220 return true;
221}
222
223/// Returns true if the given `op` is considered as legal for SPIR-V conversion.
224static bool isLegalOp(Operation *op) {
225 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
226 return llvm::all_of(funcOp.getArgumentTypes(), isLegalType) &&
227 llvm::all_of(funcOp.getResultTypes(), isLegalType) &&
228 llvm::all_of(funcOp.getFunctionBody().getArgumentTypes(),
229 isLegalType);
230 }
231
232 auto attrs = llvm::map_range(C: op->getAttrs(), F: [](const NamedAttribute &attr) {
233 return attr.getValue();
234 });
235
236 return llvm::all_of(Range: op->getOperandTypes(), P: isLegalType) &&
237 llvm::all_of(Range: op->getResultTypes(), P: isLegalType) &&
238 llvm::all_of(Range&: attrs, P: isLegalAttr);
239}
240
241std::unique_ptr<ConversionTarget>
242spirv::getMemorySpaceToStorageClassTarget(MLIRContext &context) {
243 auto target = std::make_unique<ConversionTarget>(args&: context);
244 target->markUnknownOpDynamicallyLegal(fn: isLegalOp);
245 return target;
246}
247
248void spirv::convertMemRefTypesAndAttrs(
249 Operation *op, MemorySpaceToStorageClassConverter &typeConverter) {
250 AttrTypeReplacer replacer;
251 replacer.addReplacement(callback: [&typeConverter](BaseMemRefType origType)
252 -> std::optional<BaseMemRefType> {
253 return typeConverter.convertType<BaseMemRefType>(origType);
254 });
255
256 replacer.recursivelyReplaceElementsIn(op, /*replaceAttrs=*/true,
257 /*replaceLocs=*/false,
258 /*replaceTypes=*/true);
259}
260
261//===----------------------------------------------------------------------===//
262// Conversion Pass
263//===----------------------------------------------------------------------===//
264
265namespace {
266class MapMemRefStorageClassPass final
267 : public impl::MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
268public:
269 MapMemRefStorageClassPass() = default;
270
271 explicit MapMemRefStorageClassPass(
272 const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
273 : memorySpaceMap(memorySpaceMap) {}
274
275 LogicalResult initializeOptions(
276 StringRef options,
277 function_ref<LogicalResult(const Twine &)> errorHandler) override {
278 if (failed(Pass::initializeOptions(options, errorHandler)))
279 return failure();
280
281 if (clientAPI == "opencl")
282 memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
283 else if (clientAPI != "vulkan")
284 return errorHandler(llvm::Twine("Invalid clienAPI: ") + clientAPI);
285
286 return success();
287 }
288
289 void runOnOperation() override {
290 MLIRContext *context = &getContext();
291 Operation *op = getOperation();
292
293 spirv::MemorySpaceToStorageClassMap spaceToStorage = memorySpaceMap;
294 if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) {
295 spirv::TargetEnv targetEnv(attr);
296 if (targetEnv.allows(spirv::Capability::Kernel)) {
297 spaceToStorage = spirv::mapMemorySpaceToOpenCLStorageClass;
298 } else if (targetEnv.allows(spirv::Capability::Shader)) {
299 spaceToStorage = spirv::mapMemorySpaceToVulkanStorageClass;
300 }
301 }
302
303 spirv::MemorySpaceToStorageClassConverter converter(spaceToStorage);
304 // Perform the replacement.
305 spirv::convertMemRefTypesAndAttrs(op, typeConverter&: converter);
306
307 // Check if there are any illegal ops remaining.
308 std::unique_ptr<ConversionTarget> target =
309 spirv::getMemorySpaceToStorageClassTarget(context&: *context);
310 op->walk(callback: [&target, this](Operation *childOp) {
311 if (target->isIllegal(op: childOp)) {
312 childOp->emitOpError(message: "failed to legalize memory space");
313 signalPassFailure();
314 return WalkResult::interrupt();
315 }
316 return WalkResult::advance();
317 });
318 }
319
320private:
321 spirv::MemorySpaceToStorageClassMap memorySpaceMap =
322 spirv::mapMemorySpaceToVulkanStorageClass;
323};
324} // namespace
325
326std::unique_ptr<OperationPass<>> mlir::createMapMemRefStorageClassPass() {
327 return std::make_unique<MapMemRefStorageClassPass>();
328}
329

source code of mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp