//===- UnifyAliasedResourcePass.cpp - Pass to Unify Aliased Resources -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass that unifies access of multiple aliased resources
// into access of one single resource.
//
//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
15
16#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/BuiltinAttributes.h"
22#include "mlir/IR/BuiltinTypes.h"
23#include "mlir/IR/SymbolTable.h"
24#include "mlir/Transforms/DialectConversion.h"
25#include "llvm/ADT/DenseMap.h"
26#include "llvm/ADT/STLExtras.h"
27#include <iterator>
28
29namespace mlir {
30namespace spirv {
31#define GEN_PASS_DEF_SPIRVUNIFYALIASEDRESOURCEPASS
32#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
33} // namespace spirv
34} // namespace mlir
35
36using namespace mlir;
37
38//===----------------------------------------------------------------------===//
39// Utility functions
40//===----------------------------------------------------------------------===//
41
// A (descriptor set #, binding #) pair uniquely identifying one resource slot.
using Descriptor = std::pair<uint32_t, uint32_t>; // (set #, binding #)
// Maps a descriptor to every `aliased`-marked spirv.GlobalVariable bound
// to that slot.
using AliasedResourceMap =
    DenseMap<Descriptor, SmallVector<spirv::GlobalVariableOp>>;
45
46/// Collects all aliased resources in the given SPIR-V `moduleOp`.
47static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) {
48 AliasedResourceMap aliasedResources;
49 moduleOp->walk(callback: [&aliasedResources](spirv::GlobalVariableOp varOp) {
50 if (varOp->getAttrOfType<UnitAttr>(name: "aliased")) {
51 std::optional<uint32_t> set = varOp.getDescriptorSet();
52 std::optional<uint32_t> binding = varOp.getBinding();
53 if (set && binding)
54 aliasedResources[{*set, *binding}].push_back(Elt: varOp);
55 }
56 });
57 return aliasedResources;
58}
59
60/// Returns the element type if the given `type` is a runtime array resource:
61/// `!spirv.ptr<!spirv.struct<!spirv.rtarray<...>>>`. Returns null type
62/// otherwise.
63static Type getRuntimeArrayElementType(Type type) {
64 auto ptrType = dyn_cast<spirv::PointerType>(Val&: type);
65 if (!ptrType)
66 return {};
67
68 auto structType = dyn_cast<spirv::StructType>(Val: ptrType.getPointeeType());
69 if (!structType || structType.getNumElements() != 1)
70 return {};
71
72 auto rtArrayType =
73 dyn_cast<spirv::RuntimeArrayType>(Val: structType.getElementType(0));
74 if (!rtArrayType)
75 return {};
76
77 return rtArrayType.getElementType();
78}
79
80/// Given a list of resource element `types`, returns the index of the canonical
81/// resource that all resources should be unified into. Returns std::nullopt if
82/// unable to unify.
83static std::optional<int>
84deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) {
85 // scalarNumBits: contains all resources' scalar types' bit counts.
86 // vectorNumBits: only contains resources whose element types are vectors.
87 // vectorIndices: each vector's original index in `types`.
88 SmallVector<int> scalarNumBits, vectorNumBits, vectorIndices;
89 scalarNumBits.reserve(N: types.size());
90 vectorNumBits.reserve(N: types.size());
91 vectorIndices.reserve(N: types.size());
92
93 for (const auto &indexedTypes : llvm::enumerate(First&: types)) {
94 spirv::SPIRVType type = indexedTypes.value();
95 assert(type.isScalarOrVector());
96 if (auto vectorType = dyn_cast<VectorType>(Val&: type)) {
97 if (vectorType.getNumElements() % 2 != 0)
98 return std::nullopt; // Odd-sized vector has special layout
99 // requirements.
100
101 std::optional<int64_t> numBytes = type.getSizeInBytes();
102 if (!numBytes)
103 return std::nullopt;
104
105 scalarNumBits.push_back(
106 Elt: vectorType.getElementType().getIntOrFloatBitWidth());
107 vectorNumBits.push_back(Elt: *numBytes * 8);
108 vectorIndices.push_back(Elt: indexedTypes.index());
109 } else {
110 scalarNumBits.push_back(Elt: type.getIntOrFloatBitWidth());
111 }
112 }
113
114 if (!vectorNumBits.empty()) {
115 // Choose the *vector* with the smallest bitwidth as the canonical resource,
116 // so that we can still keep vectorized load/store and avoid partial updates
117 // to large vectors.
118 auto *minVal = llvm::min_element(Range&: vectorNumBits);
119 // Make sure that the canonical resource's bitwidth is divisible by others.
120 // With out this, we cannot properly adjust the index later.
121 if (llvm::any_of(Range&: vectorNumBits,
122 P: [&](int bits) { return bits % *minVal != 0; }))
123 return std::nullopt;
124
125 // Require all scalar type bit counts to be a multiple of the chosen
126 // vector's primitive type to avoid reading/writing subcomponents.
127 int index = vectorIndices[std::distance(first: vectorNumBits.begin(), last: minVal)];
128 int baseNumBits = scalarNumBits[index];
129 if (llvm::any_of(Range&: scalarNumBits,
130 P: [&](int bits) { return bits % baseNumBits != 0; }))
131 return std::nullopt;
132
133 return index;
134 }
135
136 // All element types are scalars. Then choose the smallest bitwidth as the
137 // cannonical resource to avoid subcomponent load/store.
138 auto *minVal = llvm::min_element(Range&: scalarNumBits);
139 if (llvm::any_of(Range&: scalarNumBits,
140 P: [minVal](int64_t bit) { return bit % *minVal != 0; }))
141 return std::nullopt;
142 return std::distance(first: scalarNumBits.begin(), last: minVal);
143}
144
145static bool areSameBitwidthScalarType(Type a, Type b) {
146 return a.isIntOrFloat() && b.isIntOrFloat() &&
147 a.getIntOrFloatBitWidth() == b.getIntOrFloatBitWidth();
148}
149
150//===----------------------------------------------------------------------===//
151// Analysis
152//===----------------------------------------------------------------------===//
153
namespace {
/// A class for analyzing aliased resources.
///
/// Resources are expected to be spirv.GlobalVariable ops that have a
/// descriptor set and binding number. Such resources are of the type
/// `!spirv.ptr<!spirv.struct<...>>` per Vulkan requirements.
///
/// Right now, we only support the case that there is a single runtime array
/// inside the struct.
class ResourceAliasAnalysis {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ResourceAliasAnalysis)

  /// Runs the analysis over `root`, which is expected to be a spirv.module op.
  explicit ResourceAliasAnalysis(Operation *);

  /// Returns true if the given `op` can be rewritten to use a canonical
  /// resource.
  bool shouldUnify(Operation *op) const;

  /// Returns all descriptors and their corresponding aliased resources.
  const AliasedResourceMap &getResourceMap() const { return resourceMap; }

  /// Returns the canonical resource for the given descriptor/variable, or a
  /// null op if that descriptor/variable was not recorded as unifiable.
  spirv::GlobalVariableOp
  getCanonicalResource(const Descriptor &descriptor) const;
  spirv::GlobalVariableOp
  getCanonicalResource(spirv::GlobalVariableOp varOp) const;

  /// Returns the element type for the given variable, or a null type if the
  /// variable was not recorded.
  spirv::SPIRVType getElementType(spirv::GlobalVariableOp varOp) const;

private:
  /// Given the descriptor and aliased resources bound to it, analyze whether we
  /// can unify them and record if so.
  void recordIfUnifiable(const Descriptor &descriptor,
                         ArrayRef<spirv::GlobalVariableOp> resources);

  /// Mapping from a descriptor to all aliased resources bound to it.
  AliasedResourceMap resourceMap;

  /// Mapping from a descriptor to the chosen canonical resource.
  DenseMap<Descriptor, spirv::GlobalVariableOp> canonicalResourceMap;

  /// Mapping from an aliased resource to its descriptor.
  DenseMap<spirv::GlobalVariableOp, Descriptor> descriptorMap;

  /// Mapping from an aliased resource to its element (scalar/vector) type.
  DenseMap<spirv::GlobalVariableOp, spirv::SPIRVType> elementTypeMap;
};
} // namespace
204
205ResourceAliasAnalysis::ResourceAliasAnalysis(Operation *root) {
206 // Collect all aliased resources first and put them into different sets
207 // according to the descriptor.
208 AliasedResourceMap aliasedResources =
209 collectAliasedResources(moduleOp: cast<spirv::ModuleOp>(Val: root));
210
211 // For each resource set, analyze whether we can unify; if so, try to identify
212 // a canonical resource, whose element type has the largest bitwidth.
213 for (const auto &descriptorResource : aliasedResources) {
214 recordIfUnifiable(descriptor: descriptorResource.first, resources: descriptorResource.second);
215 }
216}
217
218bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
219 if (!op)
220 return false;
221
222 if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(Val: op)) {
223 auto canonicalOp = getCanonicalResource(varOp);
224 return canonicalOp && varOp != canonicalOp;
225 }
226 if (auto addressOp = dyn_cast<spirv::AddressOfOp>(Val: op)) {
227 auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
228 auto *varOp =
229 SymbolTable::lookupSymbolIn(op: moduleOp, symbol: addressOp.getVariable());
230 return shouldUnify(op: varOp);
231 }
232
233 if (auto acOp = dyn_cast<spirv::AccessChainOp>(Val: op))
234 return shouldUnify(op: acOp.getBasePtr().getDefiningOp());
235 if (auto loadOp = dyn_cast<spirv::LoadOp>(Val: op))
236 return shouldUnify(op: loadOp.getPtr().getDefiningOp());
237 if (auto storeOp = dyn_cast<spirv::StoreOp>(Val: op))
238 return shouldUnify(op: storeOp.getPtr().getDefiningOp());
239
240 return false;
241}
242
243spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
244 const Descriptor &descriptor) const {
245 auto varIt = canonicalResourceMap.find(Val: descriptor);
246 if (varIt == canonicalResourceMap.end())
247 return {};
248 return varIt->second;
249}
250
251spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
252 spirv::GlobalVariableOp varOp) const {
253 auto descriptorIt = descriptorMap.find(Val: varOp);
254 if (descriptorIt == descriptorMap.end())
255 return {};
256 return getCanonicalResource(descriptor: descriptorIt->second);
257}
258
259spirv::SPIRVType
260ResourceAliasAnalysis::getElementType(spirv::GlobalVariableOp varOp) const {
261 auto it = elementTypeMap.find(Val: varOp);
262 if (it == elementTypeMap.end())
263 return {};
264 return it->second;
265}
266
267void ResourceAliasAnalysis::recordIfUnifiable(
268 const Descriptor &descriptor, ArrayRef<spirv::GlobalVariableOp> resources) {
269 // Collect the element types for all resources in the current set.
270 SmallVector<spirv::SPIRVType> elementTypes;
271 for (spirv::GlobalVariableOp resource : resources) {
272 Type elementType = getRuntimeArrayElementType(type: resource.getType());
273 if (!elementType)
274 return; // Unexpected resource variable type.
275
276 auto type = cast<spirv::SPIRVType>(Val&: elementType);
277 if (!type.isScalarOrVector())
278 return; // Unexpected resource element type.
279
280 elementTypes.push_back(Elt: type);
281 }
282
283 std::optional<int> index = deduceCanonicalResource(types: elementTypes);
284 if (!index)
285 return;
286
287 // Update internal data structures for later use.
288 resourceMap[descriptor].assign(in_start: resources.begin(), in_end: resources.end());
289 canonicalResourceMap[descriptor] = resources[*index];
290 for (const auto &resource : llvm::enumerate(First&: resources)) {
291 descriptorMap[resource.value()] = descriptor;
292 elementTypeMap[resource.value()] = elementTypes[resource.index()];
293 }
294}
295
296//===----------------------------------------------------------------------===//
297// Patterns
298//===----------------------------------------------------------------------===//
299
/// Base class for patterns rewriting ops that touch aliased resources; stores
/// a reference to the shared ResourceAliasAnalysis for use by subclasses.
template <typename OpTy>
class ConvertAliasResource : public OpConversionPattern<OpTy> {
public:
  ConvertAliasResource(const ResourceAliasAnalysis &analysis,
                       MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern<OpTy>(context, benefit), analysis(analysis) {}

protected:
  /// Analysis results; owned by the caller and must outlive the pattern.
  const ResourceAliasAnalysis &analysis;
};
310
311struct ConvertVariable : public ConvertAliasResource<spirv::GlobalVariableOp> {
312 using ConvertAliasResource::ConvertAliasResource;
313
314 LogicalResult
315 matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor adaptor,
316 ConversionPatternRewriter &rewriter) const override {
317 // Just remove the aliased resource. Users will be rewritten to use the
318 // canonical one.
319 rewriter.eraseOp(op: varOp);
320 return success();
321 }
322};
323
324struct ConvertAddressOf : public ConvertAliasResource<spirv::AddressOfOp> {
325 using ConvertAliasResource::ConvertAliasResource;
326
327 LogicalResult
328 matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor adaptor,
329 ConversionPatternRewriter &rewriter) const override {
330 // Rewrite the AddressOf op to get the address of the canoncical resource.
331 auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
332 auto srcVarOp = cast<spirv::GlobalVariableOp>(
333 Val: SymbolTable::lookupSymbolIn(op: moduleOp, symbol: addressOp.getVariable()));
334 auto dstVarOp = analysis.getCanonicalResource(varOp: srcVarOp);
335 rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(op: addressOp, args&: dstVarOp);
336 return success();
337 }
338};
339
340struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
341 using ConvertAliasResource::ConvertAliasResource;
342
343 LogicalResult
344 matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor,
345 ConversionPatternRewriter &rewriter) const override {
346 auto addressOp = acOp.getBasePtr().getDefiningOp<spirv::AddressOfOp>();
347 if (!addressOp)
348 return rewriter.notifyMatchFailure(arg&: acOp, msg: "base ptr not addressof op");
349
350 auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>();
351 auto srcVarOp = cast<spirv::GlobalVariableOp>(
352 Val: SymbolTable::lookupSymbolIn(op: moduleOp, symbol: addressOp.getVariable()));
353 auto dstVarOp = analysis.getCanonicalResource(varOp: srcVarOp);
354
355 spirv::SPIRVType srcElemType = analysis.getElementType(varOp: srcVarOp);
356 spirv::SPIRVType dstElemType = analysis.getElementType(varOp: dstVarOp);
357
358 if (srcElemType == dstElemType ||
359 areSameBitwidthScalarType(a: srcElemType, b: dstElemType)) {
360 // We have the same bitwidth for source and destination element types.
361 // Thie indices keep the same.
362 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
363 op: acOp, args: adaptor.getBasePtr(), args: adaptor.getIndices());
364 return success();
365 }
366
367 Location loc = acOp.getLoc();
368
369 if (srcElemType.isIntOrFloat() && isa<VectorType>(Val: dstElemType)) {
370 // The source indices are for a buffer with scalar element types. Rewrite
371 // them into a buffer with vector element types. We need to scale the last
372 // index for the vector as a whole, then add one level of index for inside
373 // the vector.
374 int srcNumBytes = *srcElemType.getSizeInBytes();
375 int dstNumBytes = *dstElemType.getSizeInBytes();
376 assert(dstNumBytes >= srcNumBytes && dstNumBytes % srcNumBytes == 0);
377
378 auto indices = llvm::to_vector<4>(Range: acOp.getIndices());
379 Value oldIndex = indices.back();
380 Type indexType = oldIndex.getType();
381
382 int ratio = dstNumBytes / srcNumBytes;
383 auto ratioValue = rewriter.create<spirv::ConstantOp>(
384 location: loc, args&: indexType, args: rewriter.getIntegerAttr(type: indexType, value: ratio));
385
386 indices.back() =
387 rewriter.create<spirv::SDivOp>(location: loc, args&: indexType, args&: oldIndex, args&: ratioValue);
388 indices.push_back(
389 Elt: rewriter.create<spirv::SModOp>(location: loc, args&: indexType, args&: oldIndex, args&: ratioValue));
390
391 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
392 op: acOp, args: adaptor.getBasePtr(), args&: indices);
393 return success();
394 }
395
396 if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
397 (isa<VectorType>(Val: srcElemType) && isa<VectorType>(Val: dstElemType))) {
398 // The source indices are for a buffer with larger bitwidth scalar/vector
399 // element types. Rewrite them into a buffer with smaller bitwidth element
400 // types. We only need to scale the last index.
401 int srcNumBytes = *srcElemType.getSizeInBytes();
402 int dstNumBytes = *dstElemType.getSizeInBytes();
403 assert(srcNumBytes >= dstNumBytes && srcNumBytes % dstNumBytes == 0);
404
405 auto indices = llvm::to_vector<4>(Range: acOp.getIndices());
406 Value oldIndex = indices.back();
407 Type indexType = oldIndex.getType();
408
409 int ratio = srcNumBytes / dstNumBytes;
410 auto ratioValue = rewriter.create<spirv::ConstantOp>(
411 location: loc, args&: indexType, args: rewriter.getIntegerAttr(type: indexType, value: ratio));
412
413 indices.back() =
414 rewriter.create<spirv::IMulOp>(location: loc, args&: indexType, args&: oldIndex, args&: ratioValue);
415
416 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
417 op: acOp, args: adaptor.getBasePtr(), args&: indices);
418 return success();
419 }
420
421 return rewriter.notifyMatchFailure(
422 arg&: acOp, msg: "unsupported src/dst types for spirv.AccessChain");
423 }
424};
425
/// Rewrites spirv.Load ops whose pointer chain references an aliased resource
/// so they load from the canonical resource instead, emitting extra loads and
/// bitcasts when the two element types differ in bitwidth.
struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // `loadOp` still points at the aliased resource; `adaptor` carries the
    // already-converted pointer into the canonical resource.
    auto srcPtrType = cast<spirv::PointerType>(loadOp.getPtr().getType());
    auto srcElemType = cast<spirv::SPIRVType>(srcPtrType.getPointeeType());
    auto dstPtrType = cast<spirv::PointerType>(adaptor.getPtr().getType());
    auto dstElemType = cast<spirv::SPIRVType>(dstPtrType.getPointeeType());

    Location loc = loadOp.getLoc();
    auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.getPtr());
    // Same element type: the new load's result can be used directly.
    if (srcElemType == dstElemType) {
      rewriter.replaceOp(loadOp, newLoadOp->getResults());
      return success();
    }

    // Same bitwidth, different scalar type (e.g. i32 vs f32): one bitcast
    // back to the source element type suffices.
    if (areSameBitwidthScalarType(srcElemType, dstElemType)) {
      auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
                                                      newLoadOp.getValue());
      rewriter.replaceOp(loadOp, castOp->getResults());

      return success();
    }

    if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
        (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) {
      // The source and destination have scalar types of different bitwidths, or
      // vector types of different component counts. For such cases, we load
      // multiple smaller bitwidth values and construct a larger bitwidth one.

      int srcNumBytes = *srcElemType.getSizeInBytes();
      int dstNumBytes = *dstElemType.getSizeInBytes();
      assert(srcNumBytes > dstNumBytes && srcNumBytes % dstNumBytes == 0);
      int ratio = srcNumBytes / dstNumBytes;
      // SPIR-V vectors have at most 4 components; larger ratios cannot be
      // reassembled via CompositeConstruct below.
      if (ratio > 4)
        return rewriter.notifyMatchFailure(loadOp, "more than 4 components");

      SmallVector<Value> components;
      components.reserve(ratio);
      components.push_back(newLoadOp);

      // The remaining components are loaded through fresh access chains that
      // step the last index forward one element at a time.
      auto acOp = adaptor.getPtr().getDefiningOp<spirv::AccessChainOp>();
      if (!acOp)
        return rewriter.notifyMatchFailure(loadOp, "ptr not spirv.AccessChain");

      auto i32Type = rewriter.getI32Type();
      Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter);
      auto indices = llvm::to_vector<4>(acOp.getIndices());
      for (int i = 1; i < ratio; ++i) {
        // Load all subsequent components belonging to this element.
        indices.back() = rewriter.create<spirv::IAddOp>(
            loc, i32Type, indices.back(), oneValue);
        auto componentAcOp = rewriter.create<spirv::AccessChainOp>(
            loc, acOp.getBasePtr(), indices);
        // Assuming little endian, this reads lower-ordered bits of the number
        // to lower-numbered components of the vector.
        components.push_back(
            rewriter.create<spirv::LoadOp>(loc, componentAcOp));
      }

      // Create a vector of the components and then cast back to the larger
      // bitwidth element type. For spirv.bitcast, the lower-numbered components
      // of the vector map to lower-ordered bits of the larger bitwidth element
      // type.

      Type vectorType = srcElemType;
      if (!isa<VectorType>(srcElemType))
        vectorType = VectorType::get({ratio}, dstElemType);

      // If both the source and destination are vector types, we need to make
      // sure the scalar type is the same for composite construction later.
      if (auto srcElemVecType = dyn_cast<VectorType>(srcElemType))
        if (auto dstElemVecType = dyn_cast<VectorType>(dstElemType)) {
          if (srcElemVecType.getElementType() !=
              dstElemVecType.getElementType()) {
            int64_t count =
                dstNumBytes / (srcElemVecType.getElementTypeBitWidth() / 8);

            // Make sure not to create 1-element vectors, which are illegal in
            // SPIR-V.
            Type castType = srcElemVecType.getElementType();
            if (count > 1)
              castType = VectorType::get({count}, castType);

            for (Value &c : components)
              c = rewriter.create<spirv::BitcastOp>(loc, castType, c);
          }
        }
      Value vectorValue = rewriter.create<spirv::CompositeConstructOp>(
          loc, vectorType, components);

      // Scalar sources need one final cast from the assembled vector back to
      // the scalar type.
      if (!isa<VectorType>(srcElemType))
        vectorValue =
            rewriter.create<spirv::BitcastOp>(loc, srcElemType, vectorValue);
      rewriter.replaceOp(loadOp, vectorValue);
      return success();
    }

    return rewriter.notifyMatchFailure(
        loadOp, "unsupported src/dst types for spirv.Load");
  }
};
530
531struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> {
532 using ConvertAliasResource::ConvertAliasResource;
533
534 LogicalResult
535 matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
536 ConversionPatternRewriter &rewriter) const override {
537 auto srcElemType =
538 cast<spirv::PointerType>(Val: storeOp.getPtr().getType()).getPointeeType();
539 auto dstElemType =
540 cast<spirv::PointerType>(Val: adaptor.getPtr().getType()).getPointeeType();
541 if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
542 return rewriter.notifyMatchFailure(arg&: storeOp, msg: "not scalar type");
543 if (!areSameBitwidthScalarType(a: srcElemType, b: dstElemType))
544 return rewriter.notifyMatchFailure(arg&: storeOp, msg: "different bitwidth");
545
546 Location loc = storeOp.getLoc();
547 Value value = adaptor.getValue();
548 if (srcElemType != dstElemType)
549 value = rewriter.create<spirv::BitcastOp>(location: loc, args&: dstElemType, args&: value);
550 rewriter.replaceOpWithNewOp<spirv::StoreOp>(op: storeOp, args: adaptor.getPtr(),
551 args&: value, args: storeOp->getAttrs());
552 return success();
553 }
554};
555
556//===----------------------------------------------------------------------===//
557// Pass
558//===----------------------------------------------------------------------===//
559
560namespace {
/// The pass driver: runs ResourceAliasAnalysis on a spirv.module and applies
/// the conversion patterns to rewrite aliased resources into canonical ones.
class UnifyAliasedResourcePass final
    : public spirv::impl::SPIRVUnifyAliasedResourcePassBase<
          UnifyAliasedResourcePass> {
public:
  /// `getTargetEnv` may be null; when set, runOnOperation() queries it to
  /// decide whether the target client API needs this transformation.
  explicit UnifyAliasedResourcePass(spirv::GetTargetEnvFn getTargetEnv)
      : getTargetEnvFn(std::move(getTargetEnv)) {}

  void runOnOperation() override;

private:
  /// Optional callback producing the target environment for the module.
  spirv::GetTargetEnvFn getTargetEnvFn;
};
573
void UnifyAliasedResourcePass::runOnOperation() {
  spirv::ModuleOp moduleOp = getOperation();
  MLIRContext *context = &getContext();

  if (getTargetEnvFn) {
    // This pass is only needed for targeting WebGPU, Metal, or layering
    // Vulkan on Metal via MoltenVK, where we need to translate SPIR-V into
    // WGSL or MSL. The translation has limitations. For any other client
    // API, skip the pass entirely.
    spirv::TargetEnvAttr targetEnv = getTargetEnvFn(moduleOp);
    spirv::ClientAPI clientAPI = targetEnv.getClientAPI();
    bool isVulkanOnAppleDevices =
        clientAPI == spirv::ClientAPI::Vulkan &&
        targetEnv.getVendorID() == spirv::Vendor::Apple;
    if (clientAPI != spirv::ClientAPI::WebGPU &&
        clientAPI != spirv::ClientAPI::Metal && !isVulkanOnAppleDevices)
      return;
  }

  // Analyze aliased resources first.
  ResourceAliasAnalysis &analysis = getAnalysis<ResourceAliasAnalysis>();

  // Ops that should be unified per the analysis are marked illegal so the
  // conversion driver rewrites them; all other SPIR-V ops stay legal.
  ConversionTarget target(*context);
  target.addDynamicallyLegalOp<spirv::GlobalVariableOp, spirv::AddressOfOp,
                               spirv::AccessChainOp, spirv::LoadOp,
                               spirv::StoreOp>(
      [&analysis](Operation *op) { return !analysis.shouldUnify(op); });
  target.addLegalDialect<spirv::SPIRVDialect>();

  // Run patterns to rewrite usages of non-canonical resources.
  RewritePatternSet patterns(context);
  patterns.add<ConvertVariable, ConvertAddressOf, ConvertAccessChain,
               ConvertLoad, ConvertStore>(analysis, context);
  if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
    return signalPassFailure();

  // Drop the `aliased` attribute if only one bound resource remains for a
  // descriptor. We must re-collect the map here: the conversion above is
  // best-effort, so certain descriptor sets may not have been converted.
  AliasedResourceMap resourceMap =
      collectAliasedResources(cast<spirv::ModuleOp>(moduleOp));
  for (const auto &dr : resourceMap) {
    const auto &resources = dr.second;
    if (resources.size() == 1)
      resources.front()->removeAttr("aliased");
  }
}
620} // namespace
621
/// Public factory for the pass; `getTargetEnv` (possibly null) is forwarded
/// so the pass can decide at run time whether the target needs unification.
std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
spirv::createUnifyAliasedResourcePass(spirv::GetTargetEnvFn getTargetEnv) {
  return std::make_unique<UnifyAliasedResourcePass>(std::move(getTargetEnv));
}
626

// mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp