1 | //===- UnifyAliasedResourcePass.cpp - Pass to Unify Aliased Resources -----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements a pass that unifies access of multiple aliased resources |
10 | // into access of one single resource. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/SPIRV/Transforms/Passes.h" |
15 | |
16 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
17 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
18 | #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" |
19 | #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" |
20 | #include "mlir/IR/Builders.h" |
21 | #include "mlir/IR/BuiltinAttributes.h" |
22 | #include "mlir/IR/BuiltinTypes.h" |
23 | #include "mlir/IR/SymbolTable.h" |
24 | #include "mlir/Transforms/DialectConversion.h" |
25 | #include "llvm/ADT/DenseMap.h" |
26 | #include "llvm/ADT/STLExtras.h" |
27 | #include "llvm/Support/Debug.h" |
28 | #include <algorithm> |
29 | #include <iterator> |
30 | |
31 | namespace mlir { |
32 | namespace spirv { |
33 | #define GEN_PASS_DEF_SPIRVUNIFYALIASEDRESOURCEPASS |
34 | #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc" |
35 | } // namespace spirv |
36 | } // namespace mlir |
37 | |
38 | #define DEBUG_TYPE "spirv-unify-aliased-resource" |
39 | |
40 | using namespace mlir; |
41 | |
42 | //===----------------------------------------------------------------------===// |
43 | // Utility functions |
44 | //===----------------------------------------------------------------------===// |
45 | |
46 | using Descriptor = std::pair<uint32_t, uint32_t>; // (set #, binding #) |
47 | using AliasedResourceMap = |
48 | DenseMap<Descriptor, SmallVector<spirv::GlobalVariableOp>>; |
49 | |
50 | /// Collects all aliased resources in the given SPIR-V `moduleOp`. |
51 | static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) { |
52 | AliasedResourceMap aliasedResources; |
53 | moduleOp->walk([&aliasedResources](spirv::GlobalVariableOp varOp) { |
54 | if (varOp->getAttrOfType<UnitAttr>("aliased" )) { |
55 | std::optional<uint32_t> set = varOp.getDescriptorSet(); |
56 | std::optional<uint32_t> binding = varOp.getBinding(); |
57 | if (set && binding) |
58 | aliasedResources[{*set, *binding}].push_back(varOp); |
59 | } |
60 | }); |
61 | return aliasedResources; |
62 | } |
63 | |
64 | /// Returns the element type if the given `type` is a runtime array resource: |
65 | /// `!spirv.ptr<!spirv.struct<!spirv.rtarray<...>>>`. Returns null type |
66 | /// otherwise. |
67 | static Type getRuntimeArrayElementType(Type type) { |
68 | auto ptrType = dyn_cast<spirv::PointerType>(Val&: type); |
69 | if (!ptrType) |
70 | return {}; |
71 | |
72 | auto structType = dyn_cast<spirv::StructType>(Val: ptrType.getPointeeType()); |
73 | if (!structType || structType.getNumElements() != 1) |
74 | return {}; |
75 | |
76 | auto rtArrayType = |
77 | dyn_cast<spirv::RuntimeArrayType>(Val: structType.getElementType(0)); |
78 | if (!rtArrayType) |
79 | return {}; |
80 | |
81 | return rtArrayType.getElementType(); |
82 | } |
83 | |
84 | /// Given a list of resource element `types`, returns the index of the canonical |
85 | /// resource that all resources should be unified into. Returns std::nullopt if |
86 | /// unable to unify. |
87 | static std::optional<int> |
88 | deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) { |
89 | // scalarNumBits: contains all resources' scalar types' bit counts. |
90 | // vectorNumBits: only contains resources whose element types are vectors. |
91 | // vectorIndices: each vector's original index in `types`. |
92 | SmallVector<int> scalarNumBits, vectorNumBits, vectorIndices; |
93 | scalarNumBits.reserve(N: types.size()); |
94 | vectorNumBits.reserve(N: types.size()); |
95 | vectorIndices.reserve(N: types.size()); |
96 | |
97 | for (const auto &indexedTypes : llvm::enumerate(First&: types)) { |
98 | spirv::SPIRVType type = indexedTypes.value(); |
99 | assert(type.isScalarOrVector()); |
100 | if (auto vectorType = dyn_cast<VectorType>(type)) { |
101 | if (vectorType.getNumElements() % 2 != 0) |
102 | return std::nullopt; // Odd-sized vector has special layout |
103 | // requirements. |
104 | |
105 | std::optional<int64_t> numBytes = type.getSizeInBytes(); |
106 | if (!numBytes) |
107 | return std::nullopt; |
108 | |
109 | scalarNumBits.push_back( |
110 | Elt: vectorType.getElementType().getIntOrFloatBitWidth()); |
111 | vectorNumBits.push_back(Elt: *numBytes * 8); |
112 | vectorIndices.push_back(Elt: indexedTypes.index()); |
113 | } else { |
114 | scalarNumBits.push_back(Elt: type.getIntOrFloatBitWidth()); |
115 | } |
116 | } |
117 | |
118 | if (!vectorNumBits.empty()) { |
119 | // Choose the *vector* with the smallest bitwidth as the canonical resource, |
120 | // so that we can still keep vectorized load/store and avoid partial updates |
121 | // to large vectors. |
122 | auto *minVal = llvm::min_element(Range&: vectorNumBits); |
123 | // Make sure that the canonical resource's bitwidth is divisible by others. |
124 | // With out this, we cannot properly adjust the index later. |
125 | if (llvm::any_of(Range&: vectorNumBits, |
126 | P: [&](int bits) { return bits % *minVal != 0; })) |
127 | return std::nullopt; |
128 | |
129 | // Require all scalar type bit counts to be a multiple of the chosen |
130 | // vector's primitive type to avoid reading/writing subcomponents. |
131 | int index = vectorIndices[std::distance(first: vectorNumBits.begin(), last: minVal)]; |
132 | int baseNumBits = scalarNumBits[index]; |
133 | if (llvm::any_of(Range&: scalarNumBits, |
134 | P: [&](int bits) { return bits % baseNumBits != 0; })) |
135 | return std::nullopt; |
136 | |
137 | return index; |
138 | } |
139 | |
140 | // All element types are scalars. Then choose the smallest bitwidth as the |
141 | // cannonical resource to avoid subcomponent load/store. |
142 | auto *minVal = llvm::min_element(Range&: scalarNumBits); |
143 | if (llvm::any_of(Range&: scalarNumBits, |
144 | P: [minVal](int64_t bit) { return bit % *minVal != 0; })) |
145 | return std::nullopt; |
146 | return std::distance(first: scalarNumBits.begin(), last: minVal); |
147 | } |
148 | |
149 | static bool areSameBitwidthScalarType(Type a, Type b) { |
150 | return a.isIntOrFloat() && b.isIntOrFloat() && |
151 | a.getIntOrFloatBitWidth() == b.getIntOrFloatBitWidth(); |
152 | } |
153 | |
154 | //===----------------------------------------------------------------------===// |
155 | // Analysis |
156 | //===----------------------------------------------------------------------===// |
157 | |
158 | namespace { |
159 | /// A class for analyzing aliased resources. |
160 | /// |
161 | /// Resources are expected to be spirv.GlobalVarible that has a descriptor set |
162 | /// and binding number. Such resources are of the type |
163 | /// `!spirv.ptr<!spirv.struct<...>>` per Vulkan requirements. |
164 | /// |
165 | /// Right now, we only support the case that there is a single runtime array |
166 | /// inside the struct. |
167 | class ResourceAliasAnalysis { |
168 | public: |
169 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ResourceAliasAnalysis) |
170 | |
171 | explicit ResourceAliasAnalysis(Operation *); |
172 | |
173 | /// Returns true if the given `op` can be rewritten to use a canonical |
174 | /// resource. |
175 | bool shouldUnify(Operation *op) const; |
176 | |
177 | /// Returns all descriptors and their corresponding aliased resources. |
178 | const AliasedResourceMap &getResourceMap() const { return resourceMap; } |
179 | |
180 | /// Returns the canonical resource for the given descriptor/variable. |
181 | spirv::GlobalVariableOp |
182 | getCanonicalResource(const Descriptor &descriptor) const; |
183 | spirv::GlobalVariableOp |
184 | getCanonicalResource(spirv::GlobalVariableOp varOp) const; |
185 | |
186 | /// Returns the element type for the given variable. |
187 | spirv::SPIRVType getElementType(spirv::GlobalVariableOp varOp) const; |
188 | |
189 | private: |
190 | /// Given the descriptor and aliased resources bound to it, analyze whether we |
191 | /// can unify them and record if so. |
192 | void recordIfUnifiable(const Descriptor &descriptor, |
193 | ArrayRef<spirv::GlobalVariableOp> resources); |
194 | |
195 | /// Mapping from a descriptor to all aliased resources bound to it. |
196 | AliasedResourceMap resourceMap; |
197 | |
198 | /// Mapping from a descriptor to the chosen canonical resource. |
199 | DenseMap<Descriptor, spirv::GlobalVariableOp> canonicalResourceMap; |
200 | |
201 | /// Mapping from an aliased resource to its descriptor. |
202 | DenseMap<spirv::GlobalVariableOp, Descriptor> descriptorMap; |
203 | |
204 | /// Mapping from an aliased resource to its element (scalar/vector) type. |
205 | DenseMap<spirv::GlobalVariableOp, spirv::SPIRVType> elementTypeMap; |
206 | }; |
207 | } // namespace |
208 | |
209 | ResourceAliasAnalysis::ResourceAliasAnalysis(Operation *root) { |
210 | // Collect all aliased resources first and put them into different sets |
211 | // according to the descriptor. |
212 | AliasedResourceMap aliasedResources = |
213 | collectAliasedResources(cast<spirv::ModuleOp>(root)); |
214 | |
215 | // For each resource set, analyze whether we can unify; if so, try to identify |
216 | // a canonical resource, whose element type has the largest bitwidth. |
217 | for (const auto &descriptorResource : aliasedResources) { |
218 | recordIfUnifiable(descriptorResource.first, descriptorResource.second); |
219 | } |
220 | } |
221 | |
222 | bool ResourceAliasAnalysis::shouldUnify(Operation *op) const { |
223 | if (!op) |
224 | return false; |
225 | |
226 | if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) { |
227 | auto canonicalOp = getCanonicalResource(varOp); |
228 | return canonicalOp && varOp != canonicalOp; |
229 | } |
230 | if (auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) { |
231 | auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>(); |
232 | auto *varOp = |
233 | SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()); |
234 | return shouldUnify(op: varOp); |
235 | } |
236 | |
237 | if (auto acOp = dyn_cast<spirv::AccessChainOp>(op)) |
238 | return shouldUnify(op: acOp.getBasePtr().getDefiningOp()); |
239 | if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) |
240 | return shouldUnify(op: loadOp.getPtr().getDefiningOp()); |
241 | if (auto storeOp = dyn_cast<spirv::StoreOp>(op)) |
242 | return shouldUnify(op: storeOp.getPtr().getDefiningOp()); |
243 | |
244 | return false; |
245 | } |
246 | |
247 | spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource( |
248 | const Descriptor &descriptor) const { |
249 | auto varIt = canonicalResourceMap.find(descriptor); |
250 | if (varIt == canonicalResourceMap.end()) |
251 | return {}; |
252 | return varIt->second; |
253 | } |
254 | |
255 | spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource( |
256 | spirv::GlobalVariableOp varOp) const { |
257 | auto descriptorIt = descriptorMap.find(varOp); |
258 | if (descriptorIt == descriptorMap.end()) |
259 | return {}; |
260 | return getCanonicalResource(descriptorIt->second); |
261 | } |
262 | |
263 | spirv::SPIRVType |
264 | ResourceAliasAnalysis::getElementType(spirv::GlobalVariableOp varOp) const { |
265 | auto it = elementTypeMap.find(varOp); |
266 | if (it == elementTypeMap.end()) |
267 | return {}; |
268 | return it->second; |
269 | } |
270 | |
271 | void ResourceAliasAnalysis::recordIfUnifiable( |
272 | const Descriptor &descriptor, ArrayRef<spirv::GlobalVariableOp> resources) { |
273 | // Collect the element types for all resources in the current set. |
274 | SmallVector<spirv::SPIRVType> elementTypes; |
275 | for (spirv::GlobalVariableOp resource : resources) { |
276 | Type elementType = getRuntimeArrayElementType(resource.getType()); |
277 | if (!elementType) |
278 | return; // Unexpected resource variable type. |
279 | |
280 | auto type = cast<spirv::SPIRVType>(elementType); |
281 | if (!type.isScalarOrVector()) |
282 | return; // Unexpected resource element type. |
283 | |
284 | elementTypes.push_back(type); |
285 | } |
286 | |
287 | std::optional<int> index = deduceCanonicalResource(types: elementTypes); |
288 | if (!index) |
289 | return; |
290 | |
291 | // Update internal data structures for later use. |
292 | resourceMap[descriptor].assign(resources.begin(), resources.end()); |
293 | canonicalResourceMap[descriptor] = resources[*index]; |
294 | for (const auto &resource : llvm::enumerate(resources)) { |
295 | descriptorMap[resource.value()] = descriptor; |
296 | elementTypeMap[resource.value()] = elementTypes[resource.index()]; |
297 | } |
298 | } |
299 | |
300 | //===----------------------------------------------------------------------===// |
301 | // Patterns |
302 | //===----------------------------------------------------------------------===// |
303 | |
304 | template <typename OpTy> |
305 | class ConvertAliasResource : public OpConversionPattern<OpTy> { |
306 | public: |
307 | ConvertAliasResource(const ResourceAliasAnalysis &analysis, |
308 | MLIRContext *context, PatternBenefit benefit = 1) |
309 | : OpConversionPattern<OpTy>(context, benefit), analysis(analysis) {} |
310 | |
311 | protected: |
312 | const ResourceAliasAnalysis &analysis; |
313 | }; |
314 | |
315 | struct ConvertVariable : public ConvertAliasResource<spirv::GlobalVariableOp> { |
316 | using ConvertAliasResource::ConvertAliasResource; |
317 | |
318 | LogicalResult |
319 | matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor adaptor, |
320 | ConversionPatternRewriter &rewriter) const override { |
321 | // Just remove the aliased resource. Users will be rewritten to use the |
322 | // canonical one. |
323 | rewriter.eraseOp(op: varOp); |
324 | return success(); |
325 | } |
326 | }; |
327 | |
328 | struct ConvertAddressOf : public ConvertAliasResource<spirv::AddressOfOp> { |
329 | using ConvertAliasResource::ConvertAliasResource; |
330 | |
331 | LogicalResult |
332 | matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor adaptor, |
333 | ConversionPatternRewriter &rewriter) const override { |
334 | // Rewrite the AddressOf op to get the address of the canoncical resource. |
335 | auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>(); |
336 | auto srcVarOp = cast<spirv::GlobalVariableOp>( |
337 | SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable())); |
338 | auto dstVarOp = analysis.getCanonicalResource(srcVarOp); |
339 | rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, dstVarOp); |
340 | return success(); |
341 | } |
342 | }; |
343 | |
344 | struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> { |
345 | using ConvertAliasResource::ConvertAliasResource; |
346 | |
347 | LogicalResult |
348 | matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor, |
349 | ConversionPatternRewriter &rewriter) const override { |
350 | auto addressOp = acOp.getBasePtr().getDefiningOp<spirv::AddressOfOp>(); |
351 | if (!addressOp) |
352 | return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op" ); |
353 | |
354 | auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>(); |
355 | auto srcVarOp = cast<spirv::GlobalVariableOp>( |
356 | SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable())); |
357 | auto dstVarOp = analysis.getCanonicalResource(srcVarOp); |
358 | |
359 | spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp); |
360 | spirv::SPIRVType dstElemType = analysis.getElementType(dstVarOp); |
361 | |
362 | if (srcElemType == dstElemType || |
363 | areSameBitwidthScalarType(a: srcElemType, b: dstElemType)) { |
364 | // We have the same bitwidth for source and destination element types. |
365 | // Thie indices keep the same. |
366 | rewriter.replaceOpWithNewOp<spirv::AccessChainOp>( |
367 | acOp, adaptor.getBasePtr(), adaptor.getIndices()); |
368 | return success(); |
369 | } |
370 | |
371 | Location loc = acOp.getLoc(); |
372 | |
373 | if (srcElemType.isIntOrFloat() && isa<VectorType>(Val: dstElemType)) { |
374 | // The source indices are for a buffer with scalar element types. Rewrite |
375 | // them into a buffer with vector element types. We need to scale the last |
376 | // index for the vector as a whole, then add one level of index for inside |
377 | // the vector. |
378 | int srcNumBytes = *srcElemType.getSizeInBytes(); |
379 | int dstNumBytes = *dstElemType.getSizeInBytes(); |
380 | assert(dstNumBytes >= srcNumBytes && dstNumBytes % srcNumBytes == 0); |
381 | |
382 | auto indices = llvm::to_vector<4>(acOp.getIndices()); |
383 | Value oldIndex = indices.back(); |
384 | Type indexType = oldIndex.getType(); |
385 | |
386 | int ratio = dstNumBytes / srcNumBytes; |
387 | auto ratioValue = rewriter.create<spirv::ConstantOp>( |
388 | loc, indexType, rewriter.getIntegerAttr(indexType, ratio)); |
389 | |
390 | indices.back() = |
391 | rewriter.create<spirv::SDivOp>(loc, indexType, oldIndex, ratioValue); |
392 | indices.push_back( |
393 | rewriter.create<spirv::SModOp>(loc, indexType, oldIndex, ratioValue)); |
394 | |
395 | rewriter.replaceOpWithNewOp<spirv::AccessChainOp>( |
396 | acOp, adaptor.getBasePtr(), indices); |
397 | return success(); |
398 | } |
399 | |
400 | if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) || |
401 | (isa<VectorType>(Val: srcElemType) && isa<VectorType>(Val: dstElemType))) { |
402 | // The source indices are for a buffer with larger bitwidth scalar/vector |
403 | // element types. Rewrite them into a buffer with smaller bitwidth element |
404 | // types. We only need to scale the last index. |
405 | int srcNumBytes = *srcElemType.getSizeInBytes(); |
406 | int dstNumBytes = *dstElemType.getSizeInBytes(); |
407 | assert(srcNumBytes >= dstNumBytes && srcNumBytes % dstNumBytes == 0); |
408 | |
409 | auto indices = llvm::to_vector<4>(acOp.getIndices()); |
410 | Value oldIndex = indices.back(); |
411 | Type indexType = oldIndex.getType(); |
412 | |
413 | int ratio = srcNumBytes / dstNumBytes; |
414 | auto ratioValue = rewriter.create<spirv::ConstantOp>( |
415 | loc, indexType, rewriter.getIntegerAttr(indexType, ratio)); |
416 | |
417 | indices.back() = |
418 | rewriter.create<spirv::IMulOp>(loc, indexType, oldIndex, ratioValue); |
419 | |
420 | rewriter.replaceOpWithNewOp<spirv::AccessChainOp>( |
421 | acOp, adaptor.getBasePtr(), indices); |
422 | return success(); |
423 | } |
424 | |
425 | return rewriter.notifyMatchFailure( |
426 | acOp, "unsupported src/dst types for spirv.AccessChain" ); |
427 | } |
428 | }; |
429 | |
430 | struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> { |
431 | using ConvertAliasResource::ConvertAliasResource; |
432 | |
433 | LogicalResult |
434 | matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor, |
435 | ConversionPatternRewriter &rewriter) const override { |
436 | auto srcPtrType = cast<spirv::PointerType>(loadOp.getPtr().getType()); |
437 | auto srcElemType = cast<spirv::SPIRVType>(srcPtrType.getPointeeType()); |
438 | auto dstPtrType = cast<spirv::PointerType>(adaptor.getPtr().getType()); |
439 | auto dstElemType = cast<spirv::SPIRVType>(dstPtrType.getPointeeType()); |
440 | |
441 | Location loc = loadOp.getLoc(); |
442 | auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.getPtr()); |
443 | if (srcElemType == dstElemType) { |
444 | rewriter.replaceOp(loadOp, newLoadOp->getResults()); |
445 | return success(); |
446 | } |
447 | |
448 | if (areSameBitwidthScalarType(srcElemType, dstElemType)) { |
449 | auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType, |
450 | newLoadOp.getValue()); |
451 | rewriter.replaceOp(loadOp, castOp->getResults()); |
452 | |
453 | return success(); |
454 | } |
455 | |
456 | if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) || |
457 | (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) { |
458 | // The source and destination have scalar types of different bitwidths, or |
459 | // vector types of different component counts. For such cases, we load |
460 | // multiple smaller bitwidth values and construct a larger bitwidth one. |
461 | |
462 | int srcNumBytes = *srcElemType.getSizeInBytes(); |
463 | int dstNumBytes = *dstElemType.getSizeInBytes(); |
464 | assert(srcNumBytes > dstNumBytes && srcNumBytes % dstNumBytes == 0); |
465 | int ratio = srcNumBytes / dstNumBytes; |
466 | if (ratio > 4) |
467 | return rewriter.notifyMatchFailure(loadOp, "more than 4 components" ); |
468 | |
469 | SmallVector<Value> components; |
470 | components.reserve(N: ratio); |
471 | components.push_back(Elt: newLoadOp); |
472 | |
473 | auto acOp = adaptor.getPtr().getDefiningOp<spirv::AccessChainOp>(); |
474 | if (!acOp) |
475 | return rewriter.notifyMatchFailure(loadOp, "ptr not spirv.AccessChain" ); |
476 | |
477 | auto i32Type = rewriter.getI32Type(); |
478 | Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter); |
479 | auto indices = llvm::to_vector<4>(acOp.getIndices()); |
480 | for (int i = 1; i < ratio; ++i) { |
481 | // Load all subsequent components belonging to this element. |
482 | indices.back() = rewriter.create<spirv::IAddOp>( |
483 | loc, i32Type, indices.back(), oneValue); |
484 | auto componentAcOp = rewriter.create<spirv::AccessChainOp>( |
485 | loc, acOp.getBasePtr(), indices); |
486 | // Assuming little endian, this reads lower-ordered bits of the number |
487 | // to lower-numbered components of the vector. |
488 | components.push_back( |
489 | rewriter.create<spirv::LoadOp>(loc, componentAcOp)); |
490 | } |
491 | |
492 | // Create a vector of the components and then cast back to the larger |
493 | // bitwidth element type. For spirv.bitcast, the lower-numbered components |
494 | // of the vector map to lower-ordered bits of the larger bitwidth element |
495 | // type. |
496 | |
497 | Type vectorType = srcElemType; |
498 | if (!isa<VectorType>(srcElemType)) |
499 | vectorType = VectorType::get({ratio}, dstElemType); |
500 | |
501 | // If both the source and destination are vector types, we need to make |
502 | // sure the scalar type is the same for composite construction later. |
503 | if (auto srcElemVecType = dyn_cast<VectorType>(srcElemType)) |
504 | if (auto dstElemVecType = dyn_cast<VectorType>(dstElemType)) { |
505 | if (srcElemVecType.getElementType() != |
506 | dstElemVecType.getElementType()) { |
507 | int64_t count = |
508 | dstNumBytes / (srcElemVecType.getElementTypeBitWidth() / 8); |
509 | |
510 | // Make sure not to create 1-element vectors, which are illegal in |
511 | // SPIR-V. |
512 | Type castType = srcElemVecType.getElementType(); |
513 | if (count > 1) |
514 | castType = VectorType::get({count}, castType); |
515 | |
516 | for (Value &c : components) |
517 | c = rewriter.create<spirv::BitcastOp>(loc, castType, c); |
518 | } |
519 | } |
520 | Value vectorValue = rewriter.create<spirv::CompositeConstructOp>( |
521 | loc, vectorType, components); |
522 | |
523 | if (!isa<VectorType>(srcElemType)) |
524 | vectorValue = |
525 | rewriter.create<spirv::BitcastOp>(loc, srcElemType, vectorValue); |
526 | rewriter.replaceOp(loadOp, vectorValue); |
527 | return success(); |
528 | } |
529 | |
530 | return rewriter.notifyMatchFailure( |
531 | loadOp, "unsupported src/dst types for spirv.Load" ); |
532 | } |
533 | }; |
534 | |
535 | struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> { |
536 | using ConvertAliasResource::ConvertAliasResource; |
537 | |
538 | LogicalResult |
539 | matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor, |
540 | ConversionPatternRewriter &rewriter) const override { |
541 | auto srcElemType = |
542 | cast<spirv::PointerType>(storeOp.getPtr().getType()).getPointeeType(); |
543 | auto dstElemType = |
544 | cast<spirv::PointerType>(adaptor.getPtr().getType()).getPointeeType(); |
545 | if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat()) |
546 | return rewriter.notifyMatchFailure(storeOp, "not scalar type" ); |
547 | if (!areSameBitwidthScalarType(srcElemType, dstElemType)) |
548 | return rewriter.notifyMatchFailure(storeOp, "different bitwidth" ); |
549 | |
550 | Location loc = storeOp.getLoc(); |
551 | Value value = adaptor.getValue(); |
552 | if (srcElemType != dstElemType) |
553 | value = rewriter.create<spirv::BitcastOp>(loc, dstElemType, value); |
554 | rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.getPtr(), |
555 | value, storeOp->getAttrs()); |
556 | return success(); |
557 | } |
558 | }; |
559 | |
560 | //===----------------------------------------------------------------------===// |
561 | // Pass |
562 | //===----------------------------------------------------------------------===// |
563 | |
564 | namespace { |
565 | class UnifyAliasedResourcePass final |
566 | : public spirv::impl::SPIRVUnifyAliasedResourcePassBase< |
567 | UnifyAliasedResourcePass> { |
568 | public: |
569 | explicit UnifyAliasedResourcePass(spirv::GetTargetEnvFn getTargetEnv) |
570 | : getTargetEnvFn(std::move(getTargetEnv)) {} |
571 | |
572 | void runOnOperation() override; |
573 | |
574 | private: |
575 | spirv::GetTargetEnvFn getTargetEnvFn; |
576 | }; |
577 | |
578 | void UnifyAliasedResourcePass::runOnOperation() { |
579 | spirv::ModuleOp moduleOp = getOperation(); |
580 | MLIRContext *context = &getContext(); |
581 | |
582 | if (getTargetEnvFn) { |
583 | // This pass is only needed for targeting WebGPU, Metal, or layering |
584 | // Vulkan on Metal via MoltenVK, where we need to translate SPIR-V into |
585 | // WGSL or MSL. The translation has limitations. |
586 | spirv::TargetEnvAttr targetEnv = getTargetEnvFn(moduleOp); |
587 | spirv::ClientAPI clientAPI = targetEnv.getClientAPI(); |
588 | bool isVulkanOnAppleDevices = |
589 | clientAPI == spirv::ClientAPI::Vulkan && |
590 | targetEnv.getVendorID() == spirv::Vendor::Apple; |
591 | if (clientAPI != spirv::ClientAPI::WebGPU && |
592 | clientAPI != spirv::ClientAPI::Metal && !isVulkanOnAppleDevices) |
593 | return; |
594 | } |
595 | |
596 | // Analyze aliased resources first. |
597 | ResourceAliasAnalysis &analysis = getAnalysis<ResourceAliasAnalysis>(); |
598 | |
599 | ConversionTarget target(*context); |
600 | target.addDynamicallyLegalOp<spirv::GlobalVariableOp, spirv::AddressOfOp, |
601 | spirv::AccessChainOp, spirv::LoadOp, |
602 | spirv::StoreOp>( |
603 | [&analysis](Operation *op) { return !analysis.shouldUnify(op); }); |
604 | target.addLegalDialect<spirv::SPIRVDialect>(); |
605 | |
606 | // Run patterns to rewrite usages of non-canonical resources. |
607 | RewritePatternSet patterns(context); |
608 | patterns.add<ConvertVariable, ConvertAddressOf, ConvertAccessChain, |
609 | ConvertLoad, ConvertStore>(arg&: analysis, args&: context); |
610 | if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) |
611 | return signalPassFailure(); |
612 | |
613 | // Drop aliased attribute if we only have one single bound resource for a |
614 | // descriptor. We need to re-collect the map here given in the above the |
615 | // conversion is best effort; certain sets may not be converted. |
616 | AliasedResourceMap resourceMap = |
617 | collectAliasedResources(cast<spirv::ModuleOp>(moduleOp)); |
618 | for (const auto &dr : resourceMap) { |
619 | const auto &resources = dr.second; |
620 | if (resources.size() == 1) |
621 | resources.front()->removeAttr("aliased" ); |
622 | } |
623 | } |
624 | } // namespace |
625 | |
626 | std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>> |
627 | spirv::createUnifyAliasedResourcePass(spirv::GetTargetEnvFn getTargetEnv) { |
628 | return std::make_unique<UnifyAliasedResourcePass>(args: std::move(getTargetEnv)); |
629 | } |
630 | |