//===- MapMemRefStorageClassPass.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to map numeric MemRef memory spaces to
10// symbolic ones defined in the SPIR-V specification.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
15
16#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
19#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
20#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
21#include "mlir/IR/Attributes.h"
22#include "mlir/IR/BuiltinAttributes.h"
23#include "mlir/IR/BuiltinTypes.h"
24#include "mlir/IR/Operation.h"
25#include "mlir/IR/Visitors.h"
26#include "mlir/Interfaces/FunctionInterfaces.h"
27#include "llvm/ADT/SmallVectorExtras.h"
28#include "llvm/ADT/StringExtras.h"
29#include "llvm/Support/Debug.h"
30#include <optional>
31
32namespace mlir {
33#define GEN_PASS_DEF_MAPMEMREFSTORAGECLASS
34#include "mlir/Conversion/Passes.h.inc"
35} // namespace mlir
36
37#define DEBUG_TYPE "mlir-map-memref-storage-class"
38
39using namespace mlir;
40
41//===----------------------------------------------------------------------===//
42// Mappings
43//===----------------------------------------------------------------------===//
44
45/// Mapping between SPIR-V storage classes to memref memory spaces.
46///
47/// Note: memref does not have a defined semantics for each memory space; it
48/// depends on the context where it is used. There are no particular reasons
49/// behind the number assignments; we try to follow NVVM conventions and largely
50/// give common storage classes a smaller number.
51#define VULKAN_STORAGE_SPACE_MAP_LIST(MAP_FN) \
52 MAP_FN(spirv::StorageClass::StorageBuffer, 0) \
53 MAP_FN(spirv::StorageClass::Generic, 1) \
54 MAP_FN(spirv::StorageClass::Workgroup, 3) \
55 MAP_FN(spirv::StorageClass::Uniform, 4) \
56 MAP_FN(spirv::StorageClass::Private, 5) \
57 MAP_FN(spirv::StorageClass::Function, 6) \
58 MAP_FN(spirv::StorageClass::PushConstant, 7) \
59 MAP_FN(spirv::StorageClass::UniformConstant, 8) \
60 MAP_FN(spirv::StorageClass::Input, 9) \
61 MAP_FN(spirv::StorageClass::Output, 10) \
62 MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 11) \
63 MAP_FN(spirv::StorageClass::Image, 12)
64
65std::optional<spirv::StorageClass>
66spirv::mapMemorySpaceToVulkanStorageClass(Attribute memorySpaceAttr) {
67 // Handle null memory space attribute specially.
68 if (!memorySpaceAttr)
69 return spirv::StorageClass::StorageBuffer;
70
71 // Unknown dialect custom attributes are not supported by default.
72 // Downstream callers should plug in more specialized ones.
73 auto intAttr = dyn_cast<IntegerAttr>(Val&: memorySpaceAttr);
74 if (!intAttr)
75 return std::nullopt;
76 unsigned memorySpace = intAttr.getInt();
77
78#define STORAGE_SPACE_MAP_FN(storage, space) \
79 case space: \
80 return storage;
81
82 switch (memorySpace) {
83 VULKAN_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
84 default:
85 break;
86 }
87 return std::nullopt;
88
89#undef STORAGE_SPACE_MAP_FN
90}
91
92std::optional<unsigned>
93spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass) {
94#define STORAGE_SPACE_MAP_FN(storage, space) \
95 case storage: \
96 return space;
97
98 switch (storageClass) {
99 VULKAN_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
100 default:
101 break;
102 }
103 return std::nullopt;
104
105#undef STORAGE_SPACE_MAP_FN
106}
107
108#undef VULKAN_STORAGE_SPACE_MAP_LIST
109
110#define OPENCL_STORAGE_SPACE_MAP_LIST(MAP_FN) \
111 MAP_FN(spirv::StorageClass::CrossWorkgroup, 0) \
112 MAP_FN(spirv::StorageClass::Generic, 1) \
113 MAP_FN(spirv::StorageClass::Workgroup, 3) \
114 MAP_FN(spirv::StorageClass::UniformConstant, 4) \
115 MAP_FN(spirv::StorageClass::Private, 5) \
116 MAP_FN(spirv::StorageClass::Function, 6) \
117 MAP_FN(spirv::StorageClass::Image, 7)
118
119std::optional<spirv::StorageClass>
120spirv::mapMemorySpaceToOpenCLStorageClass(Attribute memorySpaceAttr) {
121 // Handle null memory space attribute specially.
122 if (!memorySpaceAttr)
123 return spirv::StorageClass::CrossWorkgroup;
124
125 // Unknown dialect custom attributes are not supported by default.
126 // Downstream callers should plug in more specialized ones.
127 auto intAttr = dyn_cast<IntegerAttr>(Val&: memorySpaceAttr);
128 if (!intAttr)
129 return std::nullopt;
130 unsigned memorySpace = intAttr.getInt();
131
132#define STORAGE_SPACE_MAP_FN(storage, space) \
133 case space: \
134 return storage;
135
136 switch (memorySpace) {
137 OPENCL_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
138 default:
139 break;
140 }
141 return std::nullopt;
142
143#undef STORAGE_SPACE_MAP_FN
144}
145
146std::optional<unsigned>
147spirv::mapOpenCLStorageClassToMemorySpace(spirv::StorageClass storageClass) {
148#define STORAGE_SPACE_MAP_FN(storage, space) \
149 case storage: \
150 return space;
151
152 switch (storageClass) {
153 OPENCL_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
154 default:
155 break;
156 }
157 return std::nullopt;
158
159#undef STORAGE_SPACE_MAP_FN
160}
161
162#undef OPENCL_STORAGE_SPACE_MAP_LIST
163
164//===----------------------------------------------------------------------===//
165// Type Converter
166//===----------------------------------------------------------------------===//
167
168spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
169 const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
170 : memorySpaceMap(memorySpaceMap) {
171 // Pass through for all other types.
172 addConversion(callback: [](Type type) { return type; });
173
174 addConversion(callback: [this](BaseMemRefType memRefType) -> std::optional<Type> {
175 std::optional<spirv::StorageClass> storage =
176 this->memorySpaceMap(memRefType.getMemorySpace());
177 if (!storage) {
178 LLVM_DEBUG(llvm::dbgs()
179 << "cannot convert " << memRefType
180 << " due to being unable to find memory space in map\n");
181 return std::nullopt;
182 }
183
184 auto storageAttr =
185 spirv::StorageClassAttr::get(context: memRefType.getContext(), value: *storage);
186 if (auto rankedType = dyn_cast<MemRefType>(Val&: memRefType)) {
187 return MemRefType::get(shape: memRefType.getShape(), elementType: memRefType.getElementType(),
188 layout: rankedType.getLayout(), memorySpace: storageAttr);
189 }
190 return UnrankedMemRefType::get(elementType: memRefType.getElementType(), memorySpace: storageAttr);
191 });
192
193 addConversion(callback: [this](FunctionType type) {
194 auto inputs = llvm::map_to_vector(
195 C: type.getInputs(), F: [this](Type ty) { return convertType(t: ty); });
196 auto results = llvm::map_to_vector(
197 C: type.getResults(), F: [this](Type ty) { return convertType(t: ty); });
198 return FunctionType::get(context: type.getContext(), inputs, results);
199 });
200}
201
202//===----------------------------------------------------------------------===//
203// Conversion Target
204//===----------------------------------------------------------------------===//
205
206/// Returns true if the given `type` is considered as legal for SPIR-V
207/// conversion.
208static bool isLegalType(Type type) {
209 if (auto memRefType = dyn_cast<BaseMemRefType>(Val&: type)) {
210 Attribute spaceAttr = memRefType.getMemorySpace();
211 return isa_and_nonnull<spirv::StorageClassAttr>(Val: spaceAttr);
212 }
213 return true;
214}
215
216/// Returns true if the given `attr` is considered as legal for SPIR-V
217/// conversion.
218static bool isLegalAttr(Attribute attr) {
219 if (auto typeAttr = dyn_cast<TypeAttr>(Val&: attr))
220 return isLegalType(type: typeAttr.getValue());
221 return true;
222}
223
224/// Returns true if the given `op` is considered as legal for SPIR-V conversion.
225static bool isLegalOp(Operation *op) {
226 if (auto funcOp = dyn_cast<FunctionOpInterface>(Val: op)) {
227 return llvm::all_of(Range: funcOp.getArgumentTypes(), P: isLegalType) &&
228 llvm::all_of(Range: funcOp.getResultTypes(), P: isLegalType) &&
229 llvm::all_of(Range: funcOp.getFunctionBody().getArgumentTypes(),
230 P: isLegalType);
231 }
232
233 auto attrs = llvm::map_range(C: op->getAttrs(), F: [](const NamedAttribute &attr) {
234 return attr.getValue();
235 });
236
237 return llvm::all_of(Range: op->getOperandTypes(), P: isLegalType) &&
238 llvm::all_of(Range: op->getResultTypes(), P: isLegalType) &&
239 llvm::all_of(Range&: attrs, P: isLegalAttr);
240}
241
242std::unique_ptr<ConversionTarget>
243spirv::getMemorySpaceToStorageClassTarget(MLIRContext &context) {
244 auto target = std::make_unique<ConversionTarget>(args&: context);
245 target->markUnknownOpDynamicallyLegal(fn: isLegalOp);
246 return target;
247}
248
249void spirv::convertMemRefTypesAndAttrs(
250 Operation *op, MemorySpaceToStorageClassConverter &typeConverter) {
251 AttrTypeReplacer replacer;
252 replacer.addReplacement(callback: [&typeConverter](BaseMemRefType origType)
253 -> std::optional<BaseMemRefType> {
254 return typeConverter.convertType<BaseMemRefType>(t: origType);
255 });
256
257 replacer.recursivelyReplaceElementsIn(op, /*replaceAttrs=*/true,
258 /*replaceLocs=*/false,
259 /*replaceTypes=*/true);
260}
261
262//===----------------------------------------------------------------------===//
263// Conversion Pass
264//===----------------------------------------------------------------------===//
265
266namespace {
267class MapMemRefStorageClassPass final
268 : public impl::MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
269public:
270 MapMemRefStorageClassPass() = default;
271
272 explicit MapMemRefStorageClassPass(
273 const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
274 : memorySpaceMap(memorySpaceMap) {}
275
276 LogicalResult initializeOptions(
277 StringRef options,
278 function_ref<LogicalResult(const Twine &)> errorHandler) override {
279 if (failed(Result: Pass::initializeOptions(options, errorHandler)))
280 return failure();
281
282 if (clientAPI == "opencl")
283 memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
284 else if (clientAPI != "vulkan")
285 return errorHandler(llvm::Twine("Invalid clienAPI: ") + clientAPI);
286
287 return success();
288 }
289
290 void runOnOperation() override {
291 MLIRContext *context = &getContext();
292 Operation *op = getOperation();
293
294 spirv::MemorySpaceToStorageClassMap spaceToStorage = memorySpaceMap;
295 if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) {
296 spirv::TargetEnv targetEnv(attr);
297 if (targetEnv.allows(spirv::Capability::Kernel)) {
298 spaceToStorage = spirv::mapMemorySpaceToOpenCLStorageClass;
299 } else if (targetEnv.allows(spirv::Capability::Shader)) {
300 spaceToStorage = spirv::mapMemorySpaceToVulkanStorageClass;
301 }
302 }
303
304 spirv::MemorySpaceToStorageClassConverter converter(spaceToStorage);
305 // Perform the replacement.
306 spirv::convertMemRefTypesAndAttrs(op, typeConverter&: converter);
307
308 // Check if there are any illegal ops remaining.
309 std::unique_ptr<ConversionTarget> target =
310 spirv::getMemorySpaceToStorageClassTarget(context&: *context);
311 op->walk(callback: [&target, this](Operation *childOp) {
312 if (target->isIllegal(op: childOp)) {
313 childOp->emitOpError(message: "failed to legalize memory space");
314 signalPassFailure();
315 return WalkResult::interrupt();
316 }
317 return WalkResult::advance();
318 });
319 }
320
321private:
322 spirv::MemorySpaceToStorageClassMap memorySpaceMap =
323 spirv::mapMemorySpaceToVulkanStorageClass;
324};
325} // namespace
326
327std::unique_ptr<OperationPass<>> mlir::createMapMemRefStorageClassPass() {
328 return std::make_unique<MapMemRefStorageClassPass>();
329}
330

source code of mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp