1//===- UnifyAliasedResourcePass.cpp - Pass to Unify Aliased Resources -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass that unifies access of multiple aliased resources
10// into access of one single resource.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
15
16#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/BuiltinAttributes.h"
22#include "mlir/IR/BuiltinTypes.h"
23#include "mlir/IR/SymbolTable.h"
24#include "mlir/Transforms/DialectConversion.h"
25#include "llvm/ADT/DenseMap.h"
26#include "llvm/ADT/STLExtras.h"
27#include "llvm/Support/Debug.h"
28#include <algorithm>
29#include <iterator>
30
31namespace mlir {
32namespace spirv {
33#define GEN_PASS_DEF_SPIRVUNIFYALIASEDRESOURCEPASS
34#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
35} // namespace spirv
36} // namespace mlir
37
38#define DEBUG_TYPE "spirv-unify-aliased-resource"
39
40using namespace mlir;
41
42//===----------------------------------------------------------------------===//
43// Utility functions
44//===----------------------------------------------------------------------===//
45
/// Identifies a resource binding slot by its descriptor set number and binding
/// number.
using Descriptor = std::pair<uint32_t, uint32_t>; // (set #, binding #)
/// Maps a descriptor to the list of global variables bound to it.
using AliasedResourceMap =
    DenseMap<Descriptor, SmallVector<spirv::GlobalVariableOp>>;
49
50/// Collects all aliased resources in the given SPIR-V `moduleOp`.
51static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) {
52 AliasedResourceMap aliasedResources;
53 moduleOp->walk([&aliasedResources](spirv::GlobalVariableOp varOp) {
54 if (varOp->getAttrOfType<UnitAttr>("aliased")) {
55 std::optional<uint32_t> set = varOp.getDescriptorSet();
56 std::optional<uint32_t> binding = varOp.getBinding();
57 if (set && binding)
58 aliasedResources[{*set, *binding}].push_back(varOp);
59 }
60 });
61 return aliasedResources;
62}
63
64/// Returns the element type if the given `type` is a runtime array resource:
65/// `!spirv.ptr<!spirv.struct<!spirv.rtarray<...>>>`. Returns null type
66/// otherwise.
67static Type getRuntimeArrayElementType(Type type) {
68 auto ptrType = dyn_cast<spirv::PointerType>(Val&: type);
69 if (!ptrType)
70 return {};
71
72 auto structType = dyn_cast<spirv::StructType>(Val: ptrType.getPointeeType());
73 if (!structType || structType.getNumElements() != 1)
74 return {};
75
76 auto rtArrayType =
77 dyn_cast<spirv::RuntimeArrayType>(Val: structType.getElementType(0));
78 if (!rtArrayType)
79 return {};
80
81 return rtArrayType.getElementType();
82}
83
84/// Given a list of resource element `types`, returns the index of the canonical
85/// resource that all resources should be unified into. Returns std::nullopt if
86/// unable to unify.
87static std::optional<int>
88deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) {
89 // scalarNumBits: contains all resources' scalar types' bit counts.
90 // vectorNumBits: only contains resources whose element types are vectors.
91 // vectorIndices: each vector's original index in `types`.
92 SmallVector<int> scalarNumBits, vectorNumBits, vectorIndices;
93 scalarNumBits.reserve(N: types.size());
94 vectorNumBits.reserve(N: types.size());
95 vectorIndices.reserve(N: types.size());
96
97 for (const auto &indexedTypes : llvm::enumerate(First&: types)) {
98 spirv::SPIRVType type = indexedTypes.value();
99 assert(type.isScalarOrVector());
100 if (auto vectorType = dyn_cast<VectorType>(type)) {
101 if (vectorType.getNumElements() % 2 != 0)
102 return std::nullopt; // Odd-sized vector has special layout
103 // requirements.
104
105 std::optional<int64_t> numBytes = type.getSizeInBytes();
106 if (!numBytes)
107 return std::nullopt;
108
109 scalarNumBits.push_back(
110 Elt: vectorType.getElementType().getIntOrFloatBitWidth());
111 vectorNumBits.push_back(Elt: *numBytes * 8);
112 vectorIndices.push_back(Elt: indexedTypes.index());
113 } else {
114 scalarNumBits.push_back(Elt: type.getIntOrFloatBitWidth());
115 }
116 }
117
118 if (!vectorNumBits.empty()) {
119 // Choose the *vector* with the smallest bitwidth as the canonical resource,
120 // so that we can still keep vectorized load/store and avoid partial updates
121 // to large vectors.
122 auto *minVal = llvm::min_element(Range&: vectorNumBits);
123 // Make sure that the canonical resource's bitwidth is divisible by others.
124 // With out this, we cannot properly adjust the index later.
125 if (llvm::any_of(Range&: vectorNumBits,
126 P: [&](int bits) { return bits % *minVal != 0; }))
127 return std::nullopt;
128
129 // Require all scalar type bit counts to be a multiple of the chosen
130 // vector's primitive type to avoid reading/writing subcomponents.
131 int index = vectorIndices[std::distance(first: vectorNumBits.begin(), last: minVal)];
132 int baseNumBits = scalarNumBits[index];
133 if (llvm::any_of(Range&: scalarNumBits,
134 P: [&](int bits) { return bits % baseNumBits != 0; }))
135 return std::nullopt;
136
137 return index;
138 }
139
140 // All element types are scalars. Then choose the smallest bitwidth as the
141 // cannonical resource to avoid subcomponent load/store.
142 auto *minVal = llvm::min_element(Range&: scalarNumBits);
143 if (llvm::any_of(Range&: scalarNumBits,
144 P: [minVal](int64_t bit) { return bit % *minVal != 0; }))
145 return std::nullopt;
146 return std::distance(first: scalarNumBits.begin(), last: minVal);
147}
148
149static bool areSameBitwidthScalarType(Type a, Type b) {
150 return a.isIntOrFloat() && b.isIntOrFloat() &&
151 a.getIntOrFloatBitWidth() == b.getIntOrFloatBitWidth();
152}
153
154//===----------------------------------------------------------------------===//
155// Analysis
156//===----------------------------------------------------------------------===//
157
158namespace {
/// A class for analyzing aliased resources.
///
/// Resources are expected to be spirv.GlobalVariable that has a descriptor set
/// and binding number. Such resources are of the type
/// `!spirv.ptr<!spirv.struct<...>>` per Vulkan requirements.
///
/// Right now, we only support the case that there is a single runtime array
/// inside the struct.
class ResourceAliasAnalysis {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ResourceAliasAnalysis)

  /// Builds the analysis for the given root operation, which is cast to a
  /// spirv::ModuleOp.
  explicit ResourceAliasAnalysis(Operation *);

  /// Returns true if the given `op` can be rewritten to use a canonical
  /// resource. Returns false for a null `op`.
  bool shouldUnify(Operation *op) const;

  /// Returns all descriptors and their corresponding aliased resources.
  const AliasedResourceMap &getResourceMap() const { return resourceMap; }

  /// Returns the canonical resource for the given descriptor/variable, or a
  /// null op if none was recorded.
  spirv::GlobalVariableOp
  getCanonicalResource(const Descriptor &descriptor) const;
  spirv::GlobalVariableOp
  getCanonicalResource(spirv::GlobalVariableOp varOp) const;

  /// Returns the element type for the given variable, or a null type if the
  /// variable was not recorded.
  spirv::SPIRVType getElementType(spirv::GlobalVariableOp varOp) const;

private:
  /// Given the descriptor and aliased resources bound to it, analyze whether we
  /// can unify them and record if so.
  void recordIfUnifiable(const Descriptor &descriptor,
                         ArrayRef<spirv::GlobalVariableOp> resources);

  /// Mapping from a descriptor to all aliased resources bound to it.
  AliasedResourceMap resourceMap;

  /// Mapping from a descriptor to the chosen canonical resource.
  DenseMap<Descriptor, spirv::GlobalVariableOp> canonicalResourceMap;

  /// Mapping from an aliased resource to its descriptor.
  DenseMap<spirv::GlobalVariableOp, Descriptor> descriptorMap;

  /// Mapping from an aliased resource to its element (scalar/vector) type.
  DenseMap<spirv::GlobalVariableOp, spirv::SPIRVType> elementTypeMap;
};
207} // namespace
208
209ResourceAliasAnalysis::ResourceAliasAnalysis(Operation *root) {
210 // Collect all aliased resources first and put them into different sets
211 // according to the descriptor.
212 AliasedResourceMap aliasedResources =
213 collectAliasedResources(cast<spirv::ModuleOp>(root));
214
215 // For each resource set, analyze whether we can unify; if so, try to identify
216 // a canonical resource, whose element type has the largest bitwidth.
217 for (const auto &descriptorResource : aliasedResources) {
218 recordIfUnifiable(descriptorResource.first, descriptorResource.second);
219 }
220}
221
222bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
223 if (!op)
224 return false;
225
226 if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
227 auto canonicalOp = getCanonicalResource(varOp);
228 return canonicalOp && varOp != canonicalOp;
229 }
230 if (auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) {
231 auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
232 auto *varOp =
233 SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable());
234 return shouldUnify(op: varOp);
235 }
236
237 if (auto acOp = dyn_cast<spirv::AccessChainOp>(op))
238 return shouldUnify(op: acOp.getBasePtr().getDefiningOp());
239 if (auto loadOp = dyn_cast<spirv::LoadOp>(op))
240 return shouldUnify(op: loadOp.getPtr().getDefiningOp());
241 if (auto storeOp = dyn_cast<spirv::StoreOp>(op))
242 return shouldUnify(op: storeOp.getPtr().getDefiningOp());
243
244 return false;
245}
246
247spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
248 const Descriptor &descriptor) const {
249 auto varIt = canonicalResourceMap.find(descriptor);
250 if (varIt == canonicalResourceMap.end())
251 return {};
252 return varIt->second;
253}
254
255spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
256 spirv::GlobalVariableOp varOp) const {
257 auto descriptorIt = descriptorMap.find(varOp);
258 if (descriptorIt == descriptorMap.end())
259 return {};
260 return getCanonicalResource(descriptorIt->second);
261}
262
263spirv::SPIRVType
264ResourceAliasAnalysis::getElementType(spirv::GlobalVariableOp varOp) const {
265 auto it = elementTypeMap.find(varOp);
266 if (it == elementTypeMap.end())
267 return {};
268 return it->second;
269}
270
271void ResourceAliasAnalysis::recordIfUnifiable(
272 const Descriptor &descriptor, ArrayRef<spirv::GlobalVariableOp> resources) {
273 // Collect the element types for all resources in the current set.
274 SmallVector<spirv::SPIRVType> elementTypes;
275 for (spirv::GlobalVariableOp resource : resources) {
276 Type elementType = getRuntimeArrayElementType(resource.getType());
277 if (!elementType)
278 return; // Unexpected resource variable type.
279
280 auto type = cast<spirv::SPIRVType>(elementType);
281 if (!type.isScalarOrVector())
282 return; // Unexpected resource element type.
283
284 elementTypes.push_back(type);
285 }
286
287 std::optional<int> index = deduceCanonicalResource(types: elementTypes);
288 if (!index)
289 return;
290
291 // Update internal data structures for later use.
292 resourceMap[descriptor].assign(resources.begin(), resources.end());
293 canonicalResourceMap[descriptor] = resources[*index];
294 for (const auto &resource : llvm::enumerate(resources)) {
295 descriptorMap[resource.value()] = descriptor;
296 elementTypeMap[resource.value()] = elementTypes[resource.index()];
297 }
298}
299
300//===----------------------------------------------------------------------===//
301// Patterns
302//===----------------------------------------------------------------------===//
303
/// Common base for the conversion patterns in this file: an
/// OpConversionPattern for `OpTy` that additionally holds a reference to the
/// ResourceAliasAnalysis so derived patterns can query canonical resources.
template <typename OpTy>
class ConvertAliasResource : public OpConversionPattern<OpTy> {
public:
  /// Stores a reference to `analysis`; the referenced object must outlive the
  /// pattern.
  ConvertAliasResource(const ResourceAliasAnalysis &analysis,
                       MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern<OpTy>(context, benefit), analysis(analysis) {}

protected:
  /// Analysis results shared by all patterns.
  const ResourceAliasAnalysis &analysis;
};
314
315struct ConvertVariable : public ConvertAliasResource<spirv::GlobalVariableOp> {
316 using ConvertAliasResource::ConvertAliasResource;
317
318 LogicalResult
319 matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor adaptor,
320 ConversionPatternRewriter &rewriter) const override {
321 // Just remove the aliased resource. Users will be rewritten to use the
322 // canonical one.
323 rewriter.eraseOp(op: varOp);
324 return success();
325 }
326};
327
328struct ConvertAddressOf : public ConvertAliasResource<spirv::AddressOfOp> {
329 using ConvertAliasResource::ConvertAliasResource;
330
331 LogicalResult
332 matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor adaptor,
333 ConversionPatternRewriter &rewriter) const override {
334 // Rewrite the AddressOf op to get the address of the canoncical resource.
335 auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
336 auto srcVarOp = cast<spirv::GlobalVariableOp>(
337 SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()));
338 auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
339 rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, dstVarOp);
340 return success();
341 }
342};
343
344struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
345 using ConvertAliasResource::ConvertAliasResource;
346
347 LogicalResult
348 matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor,
349 ConversionPatternRewriter &rewriter) const override {
350 auto addressOp = acOp.getBasePtr().getDefiningOp<spirv::AddressOfOp>();
351 if (!addressOp)
352 return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op");
353
354 auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>();
355 auto srcVarOp = cast<spirv::GlobalVariableOp>(
356 SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()));
357 auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
358
359 spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp);
360 spirv::SPIRVType dstElemType = analysis.getElementType(dstVarOp);
361
362 if (srcElemType == dstElemType ||
363 areSameBitwidthScalarType(a: srcElemType, b: dstElemType)) {
364 // We have the same bitwidth for source and destination element types.
365 // Thie indices keep the same.
366 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
367 acOp, adaptor.getBasePtr(), adaptor.getIndices());
368 return success();
369 }
370
371 Location loc = acOp.getLoc();
372
373 if (srcElemType.isIntOrFloat() && isa<VectorType>(Val: dstElemType)) {
374 // The source indices are for a buffer with scalar element types. Rewrite
375 // them into a buffer with vector element types. We need to scale the last
376 // index for the vector as a whole, then add one level of index for inside
377 // the vector.
378 int srcNumBytes = *srcElemType.getSizeInBytes();
379 int dstNumBytes = *dstElemType.getSizeInBytes();
380 assert(dstNumBytes >= srcNumBytes && dstNumBytes % srcNumBytes == 0);
381
382 auto indices = llvm::to_vector<4>(acOp.getIndices());
383 Value oldIndex = indices.back();
384 Type indexType = oldIndex.getType();
385
386 int ratio = dstNumBytes / srcNumBytes;
387 auto ratioValue = rewriter.create<spirv::ConstantOp>(
388 loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
389
390 indices.back() =
391 rewriter.create<spirv::SDivOp>(loc, indexType, oldIndex, ratioValue);
392 indices.push_back(
393 rewriter.create<spirv::SModOp>(loc, indexType, oldIndex, ratioValue));
394
395 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
396 acOp, adaptor.getBasePtr(), indices);
397 return success();
398 }
399
400 if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
401 (isa<VectorType>(Val: srcElemType) && isa<VectorType>(Val: dstElemType))) {
402 // The source indices are for a buffer with larger bitwidth scalar/vector
403 // element types. Rewrite them into a buffer with smaller bitwidth element
404 // types. We only need to scale the last index.
405 int srcNumBytes = *srcElemType.getSizeInBytes();
406 int dstNumBytes = *dstElemType.getSizeInBytes();
407 assert(srcNumBytes >= dstNumBytes && srcNumBytes % dstNumBytes == 0);
408
409 auto indices = llvm::to_vector<4>(acOp.getIndices());
410 Value oldIndex = indices.back();
411 Type indexType = oldIndex.getType();
412
413 int ratio = srcNumBytes / dstNumBytes;
414 auto ratioValue = rewriter.create<spirv::ConstantOp>(
415 loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
416
417 indices.back() =
418 rewriter.create<spirv::IMulOp>(loc, indexType, oldIndex, ratioValue);
419
420 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
421 acOp, adaptor.getBasePtr(), indices);
422 return success();
423 }
424
425 return rewriter.notifyMatchFailure(
426 acOp, "unsupported src/dst types for spirv.AccessChain");
427 }
428};
429
430struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
431 using ConvertAliasResource::ConvertAliasResource;
432
433 LogicalResult
434 matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
435 ConversionPatternRewriter &rewriter) const override {
436 auto srcPtrType = cast<spirv::PointerType>(loadOp.getPtr().getType());
437 auto srcElemType = cast<spirv::SPIRVType>(srcPtrType.getPointeeType());
438 auto dstPtrType = cast<spirv::PointerType>(adaptor.getPtr().getType());
439 auto dstElemType = cast<spirv::SPIRVType>(dstPtrType.getPointeeType());
440
441 Location loc = loadOp.getLoc();
442 auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.getPtr());
443 if (srcElemType == dstElemType) {
444 rewriter.replaceOp(loadOp, newLoadOp->getResults());
445 return success();
446 }
447
448 if (areSameBitwidthScalarType(srcElemType, dstElemType)) {
449 auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
450 newLoadOp.getValue());
451 rewriter.replaceOp(loadOp, castOp->getResults());
452
453 return success();
454 }
455
456 if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
457 (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) {
458 // The source and destination have scalar types of different bitwidths, or
459 // vector types of different component counts. For such cases, we load
460 // multiple smaller bitwidth values and construct a larger bitwidth one.
461
462 int srcNumBytes = *srcElemType.getSizeInBytes();
463 int dstNumBytes = *dstElemType.getSizeInBytes();
464 assert(srcNumBytes > dstNumBytes && srcNumBytes % dstNumBytes == 0);
465 int ratio = srcNumBytes / dstNumBytes;
466 if (ratio > 4)
467 return rewriter.notifyMatchFailure(loadOp, "more than 4 components");
468
469 SmallVector<Value> components;
470 components.reserve(N: ratio);
471 components.push_back(Elt: newLoadOp);
472
473 auto acOp = adaptor.getPtr().getDefiningOp<spirv::AccessChainOp>();
474 if (!acOp)
475 return rewriter.notifyMatchFailure(loadOp, "ptr not spirv.AccessChain");
476
477 auto i32Type = rewriter.getI32Type();
478 Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter);
479 auto indices = llvm::to_vector<4>(acOp.getIndices());
480 for (int i = 1; i < ratio; ++i) {
481 // Load all subsequent components belonging to this element.
482 indices.back() = rewriter.create<spirv::IAddOp>(
483 loc, i32Type, indices.back(), oneValue);
484 auto componentAcOp = rewriter.create<spirv::AccessChainOp>(
485 loc, acOp.getBasePtr(), indices);
486 // Assuming little endian, this reads lower-ordered bits of the number
487 // to lower-numbered components of the vector.
488 components.push_back(
489 rewriter.create<spirv::LoadOp>(loc, componentAcOp));
490 }
491
492 // Create a vector of the components and then cast back to the larger
493 // bitwidth element type. For spirv.bitcast, the lower-numbered components
494 // of the vector map to lower-ordered bits of the larger bitwidth element
495 // type.
496
497 Type vectorType = srcElemType;
498 if (!isa<VectorType>(srcElemType))
499 vectorType = VectorType::get({ratio}, dstElemType);
500
501 // If both the source and destination are vector types, we need to make
502 // sure the scalar type is the same for composite construction later.
503 if (auto srcElemVecType = dyn_cast<VectorType>(srcElemType))
504 if (auto dstElemVecType = dyn_cast<VectorType>(dstElemType)) {
505 if (srcElemVecType.getElementType() !=
506 dstElemVecType.getElementType()) {
507 int64_t count =
508 dstNumBytes / (srcElemVecType.getElementTypeBitWidth() / 8);
509
510 // Make sure not to create 1-element vectors, which are illegal in
511 // SPIR-V.
512 Type castType = srcElemVecType.getElementType();
513 if (count > 1)
514 castType = VectorType::get({count}, castType);
515
516 for (Value &c : components)
517 c = rewriter.create<spirv::BitcastOp>(loc, castType, c);
518 }
519 }
520 Value vectorValue = rewriter.create<spirv::CompositeConstructOp>(
521 loc, vectorType, components);
522
523 if (!isa<VectorType>(srcElemType))
524 vectorValue =
525 rewriter.create<spirv::BitcastOp>(loc, srcElemType, vectorValue);
526 rewriter.replaceOp(loadOp, vectorValue);
527 return success();
528 }
529
530 return rewriter.notifyMatchFailure(
531 loadOp, "unsupported src/dst types for spirv.Load");
532 }
533};
534
535struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> {
536 using ConvertAliasResource::ConvertAliasResource;
537
538 LogicalResult
539 matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
540 ConversionPatternRewriter &rewriter) const override {
541 auto srcElemType =
542 cast<spirv::PointerType>(storeOp.getPtr().getType()).getPointeeType();
543 auto dstElemType =
544 cast<spirv::PointerType>(adaptor.getPtr().getType()).getPointeeType();
545 if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
546 return rewriter.notifyMatchFailure(storeOp, "not scalar type");
547 if (!areSameBitwidthScalarType(srcElemType, dstElemType))
548 return rewriter.notifyMatchFailure(storeOp, "different bitwidth");
549
550 Location loc = storeOp.getLoc();
551 Value value = adaptor.getValue();
552 if (srcElemType != dstElemType)
553 value = rewriter.create<spirv::BitcastOp>(loc, dstElemType, value);
554 rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.getPtr(),
555 value, storeOp->getAttrs());
556 return success();
557 }
558};
559
560//===----------------------------------------------------------------------===//
561// Pass
562//===----------------------------------------------------------------------===//
563
564namespace {
/// Pass that rewrites accesses to aliased resources so they go through a
/// single canonical resource per descriptor.
class UnifyAliasedResourcePass final
    : public spirv::impl::SPIRVUnifyAliasedResourcePassBase<
          UnifyAliasedResourcePass> {
public:
  /// `getTargetEnv` may be empty; when set, it is queried in runOnOperation to
  /// decide whether the pass should run for the module's target environment.
  explicit UnifyAliasedResourcePass(spirv::GetTargetEnvFn getTargetEnv)
      : getTargetEnvFn(std::move(getTargetEnv)) {}

  void runOnOperation() override;

private:
  /// Optional callback returning the target environment for a module.
  spirv::GetTargetEnvFn getTargetEnvFn;
};
577
578void UnifyAliasedResourcePass::runOnOperation() {
579 spirv::ModuleOp moduleOp = getOperation();
580 MLIRContext *context = &getContext();
581
582 if (getTargetEnvFn) {
583 // This pass is only needed for targeting WebGPU, Metal, or layering
584 // Vulkan on Metal via MoltenVK, where we need to translate SPIR-V into
585 // WGSL or MSL. The translation has limitations.
586 spirv::TargetEnvAttr targetEnv = getTargetEnvFn(moduleOp);
587 spirv::ClientAPI clientAPI = targetEnv.getClientAPI();
588 bool isVulkanOnAppleDevices =
589 clientAPI == spirv::ClientAPI::Vulkan &&
590 targetEnv.getVendorID() == spirv::Vendor::Apple;
591 if (clientAPI != spirv::ClientAPI::WebGPU &&
592 clientAPI != spirv::ClientAPI::Metal && !isVulkanOnAppleDevices)
593 return;
594 }
595
596 // Analyze aliased resources first.
597 ResourceAliasAnalysis &analysis = getAnalysis<ResourceAliasAnalysis>();
598
599 ConversionTarget target(*context);
600 target.addDynamicallyLegalOp<spirv::GlobalVariableOp, spirv::AddressOfOp,
601 spirv::AccessChainOp, spirv::LoadOp,
602 spirv::StoreOp>(
603 [&analysis](Operation *op) { return !analysis.shouldUnify(op); });
604 target.addLegalDialect<spirv::SPIRVDialect>();
605
606 // Run patterns to rewrite usages of non-canonical resources.
607 RewritePatternSet patterns(context);
608 patterns.add<ConvertVariable, ConvertAddressOf, ConvertAccessChain,
609 ConvertLoad, ConvertStore>(arg&: analysis, args&: context);
610 if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
611 return signalPassFailure();
612
613 // Drop aliased attribute if we only have one single bound resource for a
614 // descriptor. We need to re-collect the map here given in the above the
615 // conversion is best effort; certain sets may not be converted.
616 AliasedResourceMap resourceMap =
617 collectAliasedResources(cast<spirv::ModuleOp>(moduleOp));
618 for (const auto &dr : resourceMap) {
619 const auto &resources = dr.second;
620 if (resources.size() == 1)
621 resources.front()->removeAttr("aliased");
622 }
623}
624} // namespace
625
626std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
627spirv::createUnifyAliasedResourcePass(spirv::GetTargetEnvFn getTargetEnv) {
628 return std::make_unique<UnifyAliasedResourcePass>(args: std::move(getTargetEnv));
629}
630

// Source: mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp