//===- LowerABIAttributesPass.cpp - Lower SPIR-V ABI attributes -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to lower attributes that specify the shader ABI
10// for the functions in the generated SPIR-V module.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
15
16#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
19#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
20#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
21#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
22#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
23#include "mlir/IR/BuiltinAttributes.h"
24#include "mlir/Transforms/DialectConversion.h"
25
26namespace mlir {
27namespace spirv {
28#define GEN_PASS_DEF_SPIRVLOWERABIATTRIBUTESPASS
29#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
30} // namespace spirv
31} // namespace mlir
32
33using namespace mlir;
34
35/// Creates a global variable for an argument based on the ABI info.
36static spirv::GlobalVariableOp
37createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
38 unsigned argIndex,
39 spirv::InterfaceVarABIAttr abiInfo) {
40 auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
41 if (!spirvModule)
42 return nullptr;
43
44 OpBuilder::InsertionGuard moduleInsertionGuard(builder);
45 builder.setInsertionPoint(funcOp.getOperation());
46 std::string varName =
47 funcOp.getName().str() + "_arg_" + std::to_string(val: argIndex);
48
49 // Get the type of variable. If this is a scalar/vector type and has an ABI
50 // info create a variable of type !spirv.ptr<!spirv.struct<elementType>>. If
51 // not it must already be a !spirv.ptr<!spirv.struct<...>>.
52 auto varType = funcOp.getFunctionType().getInput(i: argIndex);
53 if (cast<spirv::SPIRVType>(Val&: varType).isScalarOrVector()) {
54 auto storageClass = abiInfo.getStorageClass();
55 if (!storageClass)
56 return nullptr;
57 varType =
58 spirv::PointerType::get(pointeeType: spirv::StructType::get(memberTypes: varType), storageClass: *storageClass);
59 }
60 auto varPtrType = cast<spirv::PointerType>(Val&: varType);
61 auto varPointeeType = cast<spirv::StructType>(Val: varPtrType.getPointeeType());
62
63 // Set the offset information.
64 varPointeeType =
65 cast<spirv::StructType>(Val: VulkanLayoutUtils::decorateType(structType: varPointeeType));
66
67 if (!varPointeeType)
68 return nullptr;
69
70 varType =
71 spirv::PointerType::get(pointeeType: varPointeeType, storageClass: varPtrType.getStorageClass());
72
73 return builder.create<spirv::GlobalVariableOp>(
74 location: funcOp.getLoc(), args&: varType, args&: varName, args: abiInfo.getDescriptorSet(),
75 args: abiInfo.getBinding());
76}
77
78/// Gets the global variables that need to be specified as interface variable
79/// with an spirv.EntryPointOp. Traverses the body of a entry function to do so.
80static LogicalResult
81getInterfaceVariables(spirv::FuncOp funcOp,
82 SmallVectorImpl<Attribute> &interfaceVars) {
83 auto module = funcOp->getParentOfType<spirv::ModuleOp>();
84 if (!module) {
85 return failure();
86 }
87 spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(op: funcOp);
88 spirv::TargetEnv targetEnv(targetEnvAttr);
89
90 SetVector<Operation *> interfaceVarSet;
91
92 // TODO: This should in reality traverse the entry function
93 // call graph and collect all the interfaces. For now, just traverse the
94 // instructions in this function.
95 funcOp.walk(callback: [&](spirv::AddressOfOp addressOfOp) {
96 auto var =
97 module.lookupSymbol<spirv::GlobalVariableOp>(name: addressOfOp.getVariable());
98 // Per SPIR-V spec: "Before version 1.4, the interface's
99 // storage classes are limited to the Input and Output storage classes.
100 // Starting with version 1.4, the interface's storage classes are all
101 // storage classes used in declaring all global variables referenced by the
102 // entry point’s call tree."
103 const spirv::StorageClass storageClass =
104 cast<spirv::PointerType>(Val: var.getType()).getStorageClass();
105 if ((targetEnvAttr && targetEnv.getVersion() >= spirv::Version::V_1_4) ||
106 (llvm::is_contained(
107 Set: {spirv::StorageClass::Input, spirv::StorageClass::Output},
108 Element: storageClass))) {
109 interfaceVarSet.insert(X: var.getOperation());
110 }
111 });
112 for (auto &var : interfaceVarSet) {
113 interfaceVars.push_back(Elt: SymbolRefAttr::get(
114 ctx: funcOp.getContext(), value: cast<spirv::GlobalVariableOp>(Val: var).getSymName()));
115 }
116 return success();
117}
118
119/// Lowers the entry point attribute.
120static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
121 OpBuilder &builder) {
122 auto entryPointAttrName = spirv::getEntryPointABIAttrName();
123 auto entryPointAttr =
124 funcOp->getAttrOfType<spirv::EntryPointABIAttr>(name: entryPointAttrName);
125 if (!entryPointAttr) {
126 return failure();
127 }
128
129 spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(op: funcOp);
130 spirv::TargetEnv targetEnv(targetEnvAttr);
131
132 OpBuilder::InsertionGuard moduleInsertionGuard(builder);
133 auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
134 builder.setInsertionPointToEnd(spirvModule.getBody());
135
136 // Adds the spirv.EntryPointOp after collecting all the interface variables
137 // needed.
138 SmallVector<Attribute, 1> interfaceVars;
139 if (failed(Result: getInterfaceVariables(funcOp, interfaceVars))) {
140 return failure();
141 }
142
143 FailureOr<spirv::ExecutionModel> executionModel =
144 spirv::getExecutionModel(targetAttr: targetEnvAttr);
145 if (failed(Result: executionModel))
146 return funcOp.emitRemark(message: "lower entry point failure: could not select "
147 "execution model based on 'spirv.target_env'");
148
149 builder.create<spirv::EntryPointOp>(location: funcOp.getLoc(), args&: *executionModel, args&: funcOp,
150 args&: interfaceVars);
151
152 // Specifies the spirv.ExecutionModeOp.
153 if (DenseI32ArrayAttr workgroupSizeAttr = entryPointAttr.getWorkgroupSize()) {
154 std::optional<ArrayRef<spirv::Capability>> caps =
155 spirv::getCapabilities(value: spirv::ExecutionMode::LocalSize);
156 if (!caps || targetEnv.allows(*caps)) {
157 builder.create<spirv::ExecutionModeOp>(location: funcOp.getLoc(), args&: funcOp,
158 args: spirv::ExecutionMode::LocalSize,
159 args: workgroupSizeAttr.asArrayRef());
160 // Erase workgroup size.
161 entryPointAttr = spirv::EntryPointABIAttr::get(
162 context: entryPointAttr.getContext(), workgroup_size: DenseI32ArrayAttr(),
163 subgroup_size: entryPointAttr.getSubgroupSize(), target_width: entryPointAttr.getTargetWidth());
164 }
165 }
166 if (std::optional<int> subgroupSize = entryPointAttr.getSubgroupSize()) {
167 std::optional<ArrayRef<spirv::Capability>> caps =
168 spirv::getCapabilities(value: spirv::ExecutionMode::SubgroupSize);
169 if (!caps || targetEnv.allows(*caps)) {
170 builder.create<spirv::ExecutionModeOp>(location: funcOp.getLoc(), args&: funcOp,
171 args: spirv::ExecutionMode::SubgroupSize,
172 args&: *subgroupSize);
173 // Erase subgroup size.
174 entryPointAttr = spirv::EntryPointABIAttr::get(
175 context: entryPointAttr.getContext(), workgroup_size: entryPointAttr.getWorkgroupSize(),
176 subgroup_size: std::nullopt, target_width: entryPointAttr.getTargetWidth());
177 }
178 }
179 if (std::optional<int> targetWidth = entryPointAttr.getTargetWidth()) {
180 std::optional<ArrayRef<spirv::Capability>> caps =
181 spirv::getCapabilities(value: spirv::ExecutionMode::SignedZeroInfNanPreserve);
182 if (!caps || targetEnv.allows(*caps)) {
183 builder.create<spirv::ExecutionModeOp>(
184 location: funcOp.getLoc(), args&: funcOp,
185 args: spirv::ExecutionMode::SignedZeroInfNanPreserve, args&: *targetWidth);
186 // Erase target width.
187 entryPointAttr = spirv::EntryPointABIAttr::get(
188 context: entryPointAttr.getContext(), workgroup_size: entryPointAttr.getWorkgroupSize(),
189 subgroup_size: entryPointAttr.getSubgroupSize(), target_width: std::nullopt);
190 }
191 }
192 if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize() ||
193 entryPointAttr.getTargetWidth())
194 funcOp->setAttr(name: entryPointAttrName, value: entryPointAttr);
195 else
196 funcOp->removeAttr(name: entryPointAttrName);
197 return success();
198}
199
200namespace {
201/// A pattern to convert function signature according to interface variable ABI
202/// attributes.
203///
204/// Specifically, this pattern creates global variables according to interface
205/// variable ABI attributes attached to function arguments and converts all
206/// function argument uses to those global variables. This is necessary because
207/// Vulkan requires all shader entry points to be of void(void) type.
208class ProcessInterfaceVarABI final : public OpConversionPattern<spirv::FuncOp> {
209public:
210 using OpConversionPattern<spirv::FuncOp>::OpConversionPattern;
211
212 LogicalResult
213 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
214 ConversionPatternRewriter &rewriter) const override;
215};
216
217/// Pass to implement the ABI information specified as attributes.
218class LowerABIAttributesPass final
219 : public spirv::impl::SPIRVLowerABIAttributesPassBase<
220 LowerABIAttributesPass> {
221 void runOnOperation() override;
222};
223} // namespace
224
225LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
226 spirv::FuncOp funcOp, OpAdaptor adaptor,
227 ConversionPatternRewriter &rewriter) const {
228 if (!funcOp->getAttrOfType<spirv::EntryPointABIAttr>(
229 name: spirv::getEntryPointABIAttrName())) {
230 // TODO: Non-entry point functions are not handled.
231 return failure();
232 }
233 TypeConverter::SignatureConversion signatureConverter(
234 funcOp.getFunctionType().getNumInputs());
235
236 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
237 auto indexType = typeConverter.getIndexType();
238
239 auto attrName = spirv::getInterfaceVarABIAttrName();
240
241 OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
242 rewriter.setInsertionPointToStart(&funcOp.front());
243
244 for (const auto &argType :
245 llvm::enumerate(First: funcOp.getFunctionType().getInputs())) {
246 auto abiInfo = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
247 index: argType.index(), name: attrName);
248 if (!abiInfo) {
249 // TODO: For non-entry point functions, it should be legal
250 // to pass around scalar/vector values and return a scalar/vector. For now
251 // non-entry point functions are not handled in this ABI lowering and will
252 // produce an error.
253 return failure();
254 }
255 spirv::GlobalVariableOp var = createGlobalVarForEntryPointArgument(
256 builder&: rewriter, funcOp, argIndex: argType.index(), abiInfo);
257 if (!var)
258 return failure();
259
260 // Insert spirv::AddressOf and spirv::AccessChain operations.
261 Value replacement =
262 rewriter.create<spirv::AddressOfOp>(location: funcOp.getLoc(), args&: var);
263 // Check if the arg is a scalar or vector type. In that case, the value
264 // needs to be loaded into registers.
265 // TODO: This is loading value of the scalar into registers
266 // at the start of the function. It is probably better to do the load just
267 // before the use. There might be multiple loads and currently there is no
268 // easy way to replace all uses with a sequence of operations.
269 if (cast<spirv::SPIRVType>(Val: argType.value()).isScalarOrVector()) {
270 auto zero =
271 spirv::ConstantOp::getZero(type: indexType, loc: funcOp.getLoc(), builder&: rewriter);
272 auto loadPtr = rewriter.create<spirv::AccessChainOp>(
273 location: funcOp.getLoc(), args&: replacement, args: zero.getConstant());
274 replacement = rewriter.create<spirv::LoadOp>(location: funcOp.getLoc(), args&: loadPtr);
275 }
276 signatureConverter.remapInput(origInputNo: argType.index(), replacements: replacement);
277 }
278 if (failed(Result: rewriter.convertRegionTypes(region: &funcOp.getBody(), converter: *getTypeConverter(),
279 entryConversion: &signatureConverter)))
280 return failure();
281
282 // Creates a new function with the update signature.
283 rewriter.modifyOpInPlace(root: funcOp, callable: [&] {
284 funcOp.setType(
285 rewriter.getFunctionType(inputs: signatureConverter.getConvertedTypes(), results: {}));
286 });
287 return success();
288}
289
290void LowerABIAttributesPass::runOnOperation() {
291 // Uses the signature conversion methodology of the dialect conversion
292 // framework to implement the conversion.
293 spirv::ModuleOp module = getOperation();
294 MLIRContext *context = &getContext();
295
296 spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(op: module);
297 if (!targetEnvAttr) {
298 module->emitOpError(message: "missing SPIR-V target env attribute");
299 return signalPassFailure();
300 }
301 spirv::TargetEnv targetEnv(targetEnvAttr);
302
303 SPIRVTypeConverter typeConverter(targetEnv);
304
305 // Insert a bitcast in the case of a pointer type change.
306 typeConverter.addSourceMaterialization(callback: [](OpBuilder &builder,
307 spirv::PointerType type,
308 ValueRange inputs, Location loc) {
309 if (inputs.size() != 1 || !isa<spirv::PointerType>(Val: inputs[0].getType()))
310 return Value();
311 return builder.create<spirv::BitcastOp>(location: loc, args&: type, args: inputs[0]).getResult();
312 });
313
314 RewritePatternSet patterns(context);
315 patterns.add<ProcessInterfaceVarABI>(arg&: typeConverter, args&: context);
316
317 ConversionTarget target(*context);
318 // "Legal" function ops should have no interface variable ABI attributes.
319 target.addDynamicallyLegalOp<spirv::FuncOp>(callback: [&](spirv::FuncOp op) {
320 StringRef attrName = spirv::getInterfaceVarABIAttrName();
321 for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
322 if (op.getArgAttr(index: i, name: attrName))
323 return false;
324 return true;
325 });
326 // All other SPIR-V ops are legal.
327 target.markUnknownOpDynamicallyLegal(fn: [](Operation *op) {
328 return op->getDialect()->getNamespace() ==
329 spirv::SPIRVDialect::getDialectNamespace();
330 });
331 if (failed(Result: applyPartialConversion(op: module, target, patterns: std::move(patterns))))
332 return signalPassFailure();
333
334 // Walks over all the FuncOps in spirv::ModuleOp to lower the entry point
335 // attributes.
336 OpBuilder builder(context);
337 SmallVector<spirv::FuncOp, 1> entryPointFns;
338 auto entryPointAttrName = spirv::getEntryPointABIAttrName();
339 module.walk(callback: [&](spirv::FuncOp funcOp) {
340 if (funcOp->getAttrOfType<spirv::EntryPointABIAttr>(name: entryPointAttrName)) {
341 entryPointFns.push_back(Elt: funcOp);
342 }
343 });
344 for (auto fn : entryPointFns) {
345 if (failed(Result: lowerEntryPointABIAttr(funcOp: fn, builder))) {
346 return signalPassFailure();
347 }
348 }
349}
350

source code of mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp