1//===- GroupOps.cpp - MLIR SPIR-V Group Ops ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Defines the group operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
15
16#include "SPIRVOpUtils.h"
17#include "SPIRVParsingUtils.h"
18
19using namespace mlir::spirv::AttrNames;
20
21namespace mlir::spirv {
22
23template <typename OpTy>
24static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
25 spirv::Scope scope =
26 groupOp
27 ->getAttrOfType<spirv::ScopeAttr>(
28 OpTy::getExecutionScopeAttrName(groupOp->getName()))
29 .getValue();
30 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
31 return groupOp->emitOpError(
32 message: "execution scope must be 'Workgroup' or 'Subgroup'");
33
34 GroupOperation operation =
35 groupOp
36 ->getAttrOfType<GroupOperationAttr>(
37 OpTy::getGroupOperationAttrName(groupOp->getName()))
38 .getValue();
39 if (operation == GroupOperation::ClusteredReduce &&
40 groupOp->getNumOperands() == 1)
41 return groupOp->emitOpError(message: "cluster size operand must be provided for "
42 "'ClusteredReduce' group operation");
43 if (groupOp->getNumOperands() > 1) {
44 Operation *sizeOp = groupOp->getOperand(idx: 1).getDefiningOp();
45 int32_t clusterSize = 0;
46
47 // TODO: support specialization constant here.
48 if (failed(Result: extractValueFromConstOp(op: sizeOp, value&: clusterSize)))
49 return groupOp->emitOpError(
50 message: "cluster size operand must come from a constant op");
51
52 if (!llvm::isPowerOf2_32(Value: clusterSize))
53 return groupOp->emitOpError(
54 message: "cluster size operand must be a power of two");
55 }
56 return success();
57}
58
59//===----------------------------------------------------------------------===//
60// spirv.GroupBroadcast
61//===----------------------------------------------------------------------===//
62
63LogicalResult GroupBroadcastOp::verify() {
64 spirv::Scope scope = getExecutionScope();
65 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
66 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
67
68 if (auto localIdTy = llvm::dyn_cast<VectorType>(getLocalid().getType()))
69 if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3)
70 return emitOpError("localid is a vector and can be with only "
71 " 2 or 3 components, actual number is ")
72 << localIdTy.getNumElements();
73
74 return success();
75}
76
77//===----------------------------------------------------------------------===//
78// spirv.GroupNonUniformBallotOp
79//===----------------------------------------------------------------------===//
80
81LogicalResult GroupNonUniformBallotOp::verify() {
82 spirv::Scope scope = getExecutionScope();
83 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
84 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
85
86 return success();
87}
88
89//===----------------------------------------------------------------------===//
90// spirv.GroupNonUniformBallotFindLSBOp
91//===----------------------------------------------------------------------===//
92
93LogicalResult GroupNonUniformBallotFindLSBOp::verify() {
94 spirv::Scope scope = getExecutionScope();
95 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
96 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
97
98 return success();
99}
100
101//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformBallotFindMSBOp
103//===----------------------------------------------------------------------===//
104
105LogicalResult GroupNonUniformBallotFindMSBOp::verify() {
106 spirv::Scope scope = getExecutionScope();
107 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
108 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
109
110 return success();
111}
112
113//===----------------------------------------------------------------------===//
114// spirv.GroupNonUniformBroadcast
115//===----------------------------------------------------------------------===//
116
117LogicalResult GroupNonUniformBroadcastOp::verify() {
118 spirv::Scope scope = getExecutionScope();
119 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
120 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
121
122 // SPIR-V spec: "Before version 1.5, Id must come from a
123 // constant instruction.
124 auto targetEnv = spirv::getDefaultTargetEnv(getContext());
125 if (auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>())
126 targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);
127
128 if (targetEnv.getVersion() < spirv::Version::V_1_5) {
129 auto *idOp = getId().getDefiningOp();
130 if (!idOp || !isa<spirv::ConstantOp, // for normal constant
131 spirv::ReferenceOfOp>(idOp)) // for spec constant
132 return emitOpError("id must be the result of a constant op");
133 }
134
135 return success();
136}
137
138//===----------------------------------------------------------------------===//
139// spirv.GroupNonUniformShuffle*
140//===----------------------------------------------------------------------===//
141
142template <typename OpTy>
143static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op) {
144 spirv::Scope scope = op.getExecutionScope();
145 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
146 return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
147
148 if (op.getOperands().back().getType().isSignedInteger())
149 return op.emitOpError("second operand must be a singless/unsigned integer");
150
151 return success();
152}
153
154LogicalResult GroupNonUniformShuffleOp::verify() {
155 return verifyGroupNonUniformShuffleOp(*this);
156}
157LogicalResult GroupNonUniformShuffleDownOp::verify() {
158 return verifyGroupNonUniformShuffleOp(*this);
159}
160LogicalResult GroupNonUniformShuffleUpOp::verify() {
161 return verifyGroupNonUniformShuffleOp(*this);
162}
163LogicalResult GroupNonUniformShuffleXorOp::verify() {
164 return verifyGroupNonUniformShuffleOp(*this);
165}
166
167//===----------------------------------------------------------------------===//
168// spirv.GroupNonUniformElectOp
169//===----------------------------------------------------------------------===//
170
171LogicalResult GroupNonUniformElectOp::verify() {
172 spirv::Scope scope = getExecutionScope();
173 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
174 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
175
176 return success();
177}
178
179//===----------------------------------------------------------------------===//
180// spirv.GroupNonUniformFAddOp
181//===----------------------------------------------------------------------===//
182
183LogicalResult GroupNonUniformFAddOp::verify() {
184 return verifyGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(*this);
185}
186
187//===----------------------------------------------------------------------===//
188// spirv.GroupNonUniformFMaxOp
189//===----------------------------------------------------------------------===//
190
191LogicalResult GroupNonUniformFMaxOp::verify() {
192 return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(*this);
193}
194
195//===----------------------------------------------------------------------===//
196// spirv.GroupNonUniformFMinOp
197//===----------------------------------------------------------------------===//
198
199LogicalResult GroupNonUniformFMinOp::verify() {
200 return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(*this);
201}
202
203//===----------------------------------------------------------------------===//
204// spirv.GroupNonUniformFMulOp
205//===----------------------------------------------------------------------===//
206
207LogicalResult GroupNonUniformFMulOp::verify() {
208 return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(*this);
209}
210
211//===----------------------------------------------------------------------===//
212// spirv.GroupNonUniformIAddOp
213//===----------------------------------------------------------------------===//
214
215LogicalResult GroupNonUniformIAddOp::verify() {
216 return verifyGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(*this);
217}
218
219//===----------------------------------------------------------------------===//
220// spirv.GroupNonUniformIMulOp
221//===----------------------------------------------------------------------===//
222
223LogicalResult GroupNonUniformIMulOp::verify() {
224 return verifyGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(*this);
225}
226
227//===----------------------------------------------------------------------===//
228// spirv.GroupNonUniformSMaxOp
229//===----------------------------------------------------------------------===//
230
231LogicalResult GroupNonUniformSMaxOp::verify() {
232 return verifyGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(*this);
233}
234
235//===----------------------------------------------------------------------===//
236// spirv.GroupNonUniformSMinOp
237//===----------------------------------------------------------------------===//
238
239LogicalResult GroupNonUniformSMinOp::verify() {
240 return verifyGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(*this);
241}
242
243//===----------------------------------------------------------------------===//
244// spirv.GroupNonUniformUMaxOp
245//===----------------------------------------------------------------------===//
246
247LogicalResult GroupNonUniformUMaxOp::verify() {
248 return verifyGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(*this);
249}
250
251//===----------------------------------------------------------------------===//
252// spirv.GroupNonUniformUMinOp
253//===----------------------------------------------------------------------===//
254
255LogicalResult GroupNonUniformUMinOp::verify() {
256 return verifyGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(*this);
257}
258
259//===----------------------------------------------------------------------===//
260// spirv.GroupNonUniformBitwiseAnd
261//===----------------------------------------------------------------------===//
262
263LogicalResult GroupNonUniformBitwiseAndOp::verify() {
264 return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(*this);
265}
266
267//===----------------------------------------------------------------------===//
268// spirv.GroupNonUniformBitwiseOr
269//===----------------------------------------------------------------------===//
270
271LogicalResult GroupNonUniformBitwiseOrOp::verify() {
272 return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(*this);
273}
274
275//===----------------------------------------------------------------------===//
276// spirv.GroupNonUniformBitwiseXor
277//===----------------------------------------------------------------------===//
278
279LogicalResult GroupNonUniformBitwiseXorOp::verify() {
280 return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(*this);
281}
282
283//===----------------------------------------------------------------------===//
284// spirv.GroupNonUniformLogicalAnd
285//===----------------------------------------------------------------------===//
286
287LogicalResult GroupNonUniformLogicalAndOp::verify() {
288 return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(*this);
289}
290
291//===----------------------------------------------------------------------===//
292// spirv.GroupNonUniformLogicalOr
293//===----------------------------------------------------------------------===//
294
295LogicalResult GroupNonUniformLogicalOrOp::verify() {
296 return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(*this);
297}
298
299//===----------------------------------------------------------------------===//
300// spirv.GroupNonUniformLogicalXor
301//===----------------------------------------------------------------------===//
302
303LogicalResult GroupNonUniformLogicalXorOp::verify() {
304 return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(*this);
305}
306
307//===----------------------------------------------------------------------===//
308// spirv.GroupNonUniformRotateKHR
309//===----------------------------------------------------------------------===//
310
311LogicalResult GroupNonUniformRotateKHROp::verify() {
312 spirv::Scope scope = getExecutionScope();
313 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
314 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
315
316 if (Value clusterSizeVal = getClusterSize()) {
317 mlir::Operation *defOp = clusterSizeVal.getDefiningOp();
318 int32_t clusterSize = 0;
319
320 if (failed(extractValueFromConstOp(defOp, clusterSize)))
321 return emitOpError("cluster size operand must come from a constant op");
322
323 if (!llvm::isPowerOf2_32(clusterSize))
324 return emitOpError("cluster size operand must be a power of two");
325 }
326
327 return success();
328}
329
330//===----------------------------------------------------------------------===//
331// Group op verification
332//===----------------------------------------------------------------------===//
333
334template <typename Op>
335static LogicalResult verifyGroupOp(Op op) {
336 spirv::Scope scope = op.getExecutionScope();
337 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
338 return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
339
340 return success();
341}
342
343LogicalResult GroupIAddOp::verify() { return verifyGroupOp(*this); }
344
345LogicalResult GroupFAddOp::verify() { return verifyGroupOp(*this); }
346
347LogicalResult GroupFMinOp::verify() { return verifyGroupOp(*this); }
348
349LogicalResult GroupUMinOp::verify() { return verifyGroupOp(*this); }
350
351LogicalResult GroupSMinOp::verify() { return verifyGroupOp(*this); }
352
353LogicalResult GroupFMaxOp::verify() { return verifyGroupOp(*this); }
354
355LogicalResult GroupUMaxOp::verify() { return verifyGroupOp(*this); }
356
357LogicalResult GroupSMaxOp::verify() { return verifyGroupOp(*this); }
358
359LogicalResult GroupIMulKHROp::verify() { return verifyGroupOp(*this); }
360
361LogicalResult GroupFMulKHROp::verify() { return verifyGroupOp(*this); }
362
363} // namespace mlir::spirv
364

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp