//===- LowerABIAttributesPass.cpp - Lower SPIR-V ABI attributes -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to lower attributes that specify the shader ABI
10// for the functions in the generated SPIR-V module.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
15
16#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
19#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
20#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
21#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
22#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
23#include "mlir/IR/BuiltinAttributes.h"
24#include "mlir/Transforms/DialectConversion.h"
25#include "llvm/ADT/SetVector.h"
26
27namespace mlir {
28namespace spirv {
29#define GEN_PASS_DEF_SPIRVLOWERABIATTRIBUTESPASS
30#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
31} // namespace spirv
32} // namespace mlir
33
34using namespace mlir;
35
36/// Creates a global variable for an argument based on the ABI info.
37static spirv::GlobalVariableOp
38createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
39 unsigned argIndex,
40 spirv::InterfaceVarABIAttr abiInfo) {
41 auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
42 if (!spirvModule)
43 return nullptr;
44
45 OpBuilder::InsertionGuard moduleInsertionGuard(builder);
46 builder.setInsertionPoint(funcOp.getOperation());
47 std::string varName =
48 funcOp.getName().str() + "_arg_" + std::to_string(val: argIndex);
49
50 // Get the type of variable. If this is a scalar/vector type and has an ABI
51 // info create a variable of type !spirv.ptr<!spirv.struct<elementType>>. If
52 // not it must already be a !spirv.ptr<!spirv.struct<...>>.
53 auto varType = funcOp.getFunctionType().getInput(argIndex);
54 if (cast<spirv::SPIRVType>(varType).isScalarOrVector()) {
55 auto storageClass = abiInfo.getStorageClass();
56 if (!storageClass)
57 return nullptr;
58 varType =
59 spirv::PointerType::get(spirv::StructType::get(memberTypes: varType), *storageClass);
60 }
61 auto varPtrType = cast<spirv::PointerType>(varType);
62 auto varPointeeType = cast<spirv::StructType>(varPtrType.getPointeeType());
63
64 // Set the offset information.
65 varPointeeType =
66 cast<spirv::StructType>(VulkanLayoutUtils::decorateType(varPointeeType));
67
68 if (!varPointeeType)
69 return nullptr;
70
71 varType =
72 spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass());
73
74 return builder.create<spirv::GlobalVariableOp>(
75 funcOp.getLoc(), varType, varName, abiInfo.getDescriptorSet(),
76 abiInfo.getBinding());
77}
78
79/// Gets the global variables that need to be specified as interface variable
80/// with an spirv.EntryPointOp. Traverses the body of a entry function to do so.
81static LogicalResult
82getInterfaceVariables(spirv::FuncOp funcOp,
83 SmallVectorImpl<Attribute> &interfaceVars) {
84 auto module = funcOp->getParentOfType<spirv::ModuleOp>();
85 if (!module) {
86 return failure();
87 }
88 spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(op: funcOp);
89 spirv::TargetEnv targetEnv(targetEnvAttr);
90
91 SetVector<Operation *> interfaceVarSet;
92
93 // TODO: This should in reality traverse the entry function
94 // call graph and collect all the interfaces. For now, just traverse the
95 // instructions in this function.
96 funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
97 auto var =
98 module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.getVariable());
99 // Per SPIR-V spec: "Before version 1.4, the interface's
100 // storage classes are limited to the Input and Output storage classes.
101 // Starting with version 1.4, the interface's storage classes are all
102 // storage classes used in declaring all global variables referenced by the
103 // entry point’s call tree."
104 const spirv::StorageClass storageClass =
105 cast<spirv::PointerType>(var.getType()).getStorageClass();
106 if ((targetEnvAttr && targetEnv.getVersion() >= spirv::Version::V_1_4) ||
107 (llvm::is_contained(
108 {spirv::StorageClass::Input, spirv::StorageClass::Output},
109 storageClass))) {
110 interfaceVarSet.insert(var.getOperation());
111 }
112 });
113 for (auto &var : interfaceVarSet) {
114 interfaceVars.push_back(SymbolRefAttr::get(
115 funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).getSymName()));
116 }
117 return success();
118}
119
120/// Lowers the entry point attribute.
121static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
122 OpBuilder &builder) {
123 auto entryPointAttrName = spirv::getEntryPointABIAttrName();
124 auto entryPointAttr =
125 funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName);
126 if (!entryPointAttr) {
127 return failure();
128 }
129
130 spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(op: funcOp);
131 spirv::TargetEnv targetEnv(targetEnvAttr);
132
133 OpBuilder::InsertionGuard moduleInsertionGuard(builder);
134 auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
135 builder.setInsertionPointToEnd(spirvModule.getBody());
136
137 // Adds the spirv.EntryPointOp after collecting all the interface variables
138 // needed.
139 SmallVector<Attribute, 1> interfaceVars;
140 if (failed(getInterfaceVariables(funcOp, interfaceVars))) {
141 return failure();
142 }
143
144 FailureOr<spirv::ExecutionModel> executionModel =
145 spirv::getExecutionModel(targetEnvAttr);
146 if (failed(executionModel))
147 return funcOp.emitRemark("lower entry point failure: could not select "
148 "execution model based on 'spirv.target_env'");
149
150 builder.create<spirv::EntryPointOp>(funcOp.getLoc(), *executionModel, funcOp,
151 interfaceVars);
152
153 // Specifies the spirv.ExecutionModeOp.
154 if (DenseI32ArrayAttr workgroupSizeAttr = entryPointAttr.getWorkgroupSize()) {
155 std::optional<ArrayRef<spirv::Capability>> caps =
156 spirv::getCapabilities(spirv::ExecutionMode::LocalSize);
157 if (!caps || targetEnv.allows(*caps)) {
158 builder.create<spirv::ExecutionModeOp>(funcOp.getLoc(), funcOp,
159 spirv::ExecutionMode::LocalSize,
160 workgroupSizeAttr.asArrayRef());
161 // Erase workgroup size.
162 entryPointAttr = spirv::EntryPointABIAttr::get(
163 entryPointAttr.getContext(), DenseI32ArrayAttr(),
164 entryPointAttr.getSubgroupSize(), entryPointAttr.getTargetWidth());
165 }
166 }
167 if (std::optional<int> subgroupSize = entryPointAttr.getSubgroupSize()) {
168 std::optional<ArrayRef<spirv::Capability>> caps =
169 spirv::getCapabilities(spirv::ExecutionMode::SubgroupSize);
170 if (!caps || targetEnv.allows(*caps)) {
171 builder.create<spirv::ExecutionModeOp>(funcOp.getLoc(), funcOp,
172 spirv::ExecutionMode::SubgroupSize,
173 *subgroupSize);
174 // Erase subgroup size.
175 entryPointAttr = spirv::EntryPointABIAttr::get(
176 entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
177 std::nullopt, entryPointAttr.getTargetWidth());
178 }
179 }
180 if (std::optional<int> targetWidth = entryPointAttr.getTargetWidth()) {
181 std::optional<ArrayRef<spirv::Capability>> caps =
182 spirv::getCapabilities(spirv::ExecutionMode::SignedZeroInfNanPreserve);
183 if (!caps || targetEnv.allows(*caps)) {
184 builder.create<spirv::ExecutionModeOp>(
185 funcOp.getLoc(), funcOp,
186 spirv::ExecutionMode::SignedZeroInfNanPreserve, *targetWidth);
187 // Erase target width.
188 entryPointAttr = spirv::EntryPointABIAttr::get(
189 entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
190 entryPointAttr.getSubgroupSize(), std::nullopt);
191 }
192 }
193 if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize() ||
194 entryPointAttr.getTargetWidth())
195 funcOp->setAttr(entryPointAttrName, entryPointAttr);
196 else
197 funcOp->removeAttr(entryPointAttrName);
198 return success();
199}
200
201namespace {
202/// A pattern to convert function signature according to interface variable ABI
203/// attributes.
204///
205/// Specifically, this pattern creates global variables according to interface
206/// variable ABI attributes attached to function arguments and converts all
207/// function argument uses to those global variables. This is necessary because
208/// Vulkan requires all shader entry points to be of void(void) type.
209class ProcessInterfaceVarABI final : public OpConversionPattern<spirv::FuncOp> {
210public:
211 using OpConversionPattern<spirv::FuncOp>::OpConversionPattern;
212
213 LogicalResult
214 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
215 ConversionPatternRewriter &rewriter) const override;
216};
217
218/// Pass to implement the ABI information specified as attributes.
219class LowerABIAttributesPass final
220 : public spirv::impl::SPIRVLowerABIAttributesPassBase<
221 LowerABIAttributesPass> {
222 void runOnOperation() override;
223};
224} // namespace
225
226LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
227 spirv::FuncOp funcOp, OpAdaptor adaptor,
228 ConversionPatternRewriter &rewriter) const {
229 if (!funcOp->getAttrOfType<spirv::EntryPointABIAttr>(
230 spirv::getEntryPointABIAttrName())) {
231 // TODO: Non-entry point functions are not handled.
232 return failure();
233 }
234 TypeConverter::SignatureConversion signatureConverter(
235 funcOp.getFunctionType().getNumInputs());
236
237 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
238 auto indexType = typeConverter.getIndexType();
239
240 auto attrName = spirv::getInterfaceVarABIAttrName();
241
242 OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
243 rewriter.setInsertionPointToStart(&funcOp.front());
244
245 for (const auto &argType :
246 llvm::enumerate(funcOp.getFunctionType().getInputs())) {
247 auto abiInfo = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
248 argType.index(), attrName);
249 if (!abiInfo) {
250 // TODO: For non-entry point functions, it should be legal
251 // to pass around scalar/vector values and return a scalar/vector. For now
252 // non-entry point functions are not handled in this ABI lowering and will
253 // produce an error.
254 return failure();
255 }
256 spirv::GlobalVariableOp var = createGlobalVarForEntryPointArgument(
257 rewriter, funcOp, argType.index(), abiInfo);
258 if (!var)
259 return failure();
260
261 // Insert spirv::AddressOf and spirv::AccessChain operations.
262 Value replacement =
263 rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
264 // Check if the arg is a scalar or vector type. In that case, the value
265 // needs to be loaded into registers.
266 // TODO: This is loading value of the scalar into registers
267 // at the start of the function. It is probably better to do the load just
268 // before the use. There might be multiple loads and currently there is no
269 // easy way to replace all uses with a sequence of operations.
270 if (cast<spirv::SPIRVType>(argType.value()).isScalarOrVector()) {
271 auto zero =
272 spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
273 auto loadPtr = rewriter.create<spirv::AccessChainOp>(
274 funcOp.getLoc(), replacement, zero.getConstant());
275 replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr);
276 }
277 signatureConverter.remapInput(argType.index(), replacement);
278 }
279 if (failed(rewriter.convertRegionTypes(region: &funcOp.getBody(), converter: *getTypeConverter(),
280 entryConversion: &signatureConverter)))
281 return failure();
282
283 // Creates a new function with the update signature.
284 rewriter.modifyOpInPlace(funcOp, [&] {
285 funcOp.setType(rewriter.getFunctionType(
286 signatureConverter.getConvertedTypes(), std::nullopt));
287 });
288 return success();
289}
290
291void LowerABIAttributesPass::runOnOperation() {
292 // Uses the signature conversion methodology of the dialect conversion
293 // framework to implement the conversion.
294 spirv::ModuleOp module = getOperation();
295 MLIRContext *context = &getContext();
296
297 spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(op: module);
298 if (!targetEnvAttr) {
299 module->emitOpError("missing SPIR-V target env attribute");
300 return signalPassFailure();
301 }
302 spirv::TargetEnv targetEnv(targetEnvAttr);
303
304 SPIRVTypeConverter typeConverter(targetEnv);
305
306 // Insert a bitcast in the case of a pointer type change.
307 typeConverter.addSourceMaterialization(callback: [](OpBuilder &builder,
308 spirv::PointerType type,
309 ValueRange inputs, Location loc) {
310 if (inputs.size() != 1 || !isa<spirv::PointerType>(Val: inputs[0].getType()))
311 return Value();
312 return builder.create<spirv::BitcastOp>(loc, type, inputs[0]).getResult();
313 });
314
315 RewritePatternSet patterns(context);
316 patterns.add<ProcessInterfaceVarABI>(arg&: typeConverter, args&: context);
317
318 ConversionTarget target(*context);
319 // "Legal" function ops should have no interface variable ABI attributes.
320 target.addDynamicallyLegalOp<spirv::FuncOp>([&](spirv::FuncOp op) {
321 StringRef attrName = spirv::getInterfaceVarABIAttrName();
322 for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
323 if (op.getArgAttr(i, attrName))
324 return false;
325 return true;
326 });
327 // All other SPIR-V ops are legal.
328 target.markUnknownOpDynamicallyLegal([](Operation *op) {
329 return op->getDialect()->getNamespace() ==
330 spirv::SPIRVDialect::getDialectNamespace();
331 });
332 if (failed(applyPartialConversion(module, target, std::move(patterns))))
333 return signalPassFailure();
334
335 // Walks over all the FuncOps in spirv::ModuleOp to lower the entry point
336 // attributes.
337 OpBuilder builder(context);
338 SmallVector<spirv::FuncOp, 1> entryPointFns;
339 auto entryPointAttrName = spirv::getEntryPointABIAttrName();
340 module.walk([&](spirv::FuncOp funcOp) {
341 if (funcOp->getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName)) {
342 entryPointFns.push_back(funcOp);
343 }
344 });
345 for (auto fn : entryPointFns) {
346 if (failed(lowerEntryPointABIAttr(fn, builder))) {
347 return signalPassFailure();
348 }
349 }
350}
351

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp