1 | //===- GroupOps.cpp - MLIR SPIR-V Group Ops ------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Defines the group operations in the SPIR-V dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
14 | #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" |
15 | |
16 | #include "SPIRVOpUtils.h" |
17 | #include "SPIRVParsingUtils.h" |
18 | |
19 | using namespace mlir::spirv::AttrNames; |
20 | |
21 | namespace mlir::spirv { |
22 | |
23 | template <typename OpTy> |
24 | static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) { |
25 | spirv::Scope scope = |
26 | groupOp |
27 | ->getAttrOfType<spirv::ScopeAttr>( |
28 | OpTy::getExecutionScopeAttrName(groupOp->getName())) |
29 | .getValue(); |
30 | if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) |
31 | return groupOp->emitOpError( |
32 | message: "execution scope must be 'Workgroup' or 'Subgroup'" ); |
33 | |
34 | GroupOperation operation = |
35 | groupOp |
36 | ->getAttrOfType<GroupOperationAttr>( |
37 | OpTy::getGroupOperationAttrName(groupOp->getName())) |
38 | .getValue(); |
39 | if (operation == GroupOperation::ClusteredReduce && |
40 | groupOp->getNumOperands() == 1) |
41 | return groupOp->emitOpError(message: "cluster size operand must be provided for " |
42 | "'ClusteredReduce' group operation" ); |
43 | if (groupOp->getNumOperands() > 1) { |
44 | Operation *sizeOp = groupOp->getOperand(idx: 1).getDefiningOp(); |
45 | int32_t clusterSize = 0; |
46 | |
47 | // TODO: support specialization constant here. |
48 | if (failed(Result: extractValueFromConstOp(op: sizeOp, value&: clusterSize))) |
49 | return groupOp->emitOpError( |
50 | message: "cluster size operand must come from a constant op" ); |
51 | |
52 | if (!llvm::isPowerOf2_32(Value: clusterSize)) |
53 | return groupOp->emitOpError( |
54 | message: "cluster size operand must be a power of two" ); |
55 | } |
56 | return success(); |
57 | } |
58 | |
59 | //===----------------------------------------------------------------------===// |
60 | // spirv.GroupBroadcast |
61 | //===----------------------------------------------------------------------===// |
62 | |
63 | LogicalResult GroupBroadcastOp::verify() { |
64 | spirv::Scope scope = getExecutionScope(); |
65 | if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) |
66 | return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'" ); |
67 | |
68 | if (auto localIdTy = llvm::dyn_cast<VectorType>(getLocalid().getType())) |
69 | if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3) |
70 | return emitOpError("localid is a vector and can be with only " |
71 | " 2 or 3 components, actual number is " ) |
72 | << localIdTy.getNumElements(); |
73 | |
74 | return success(); |
75 | } |
76 | |
77 | //===----------------------------------------------------------------------===// |
78 | // spirv.GroupNonUniformBallotOp |
79 | //===----------------------------------------------------------------------===// |
80 | |
81 | LogicalResult GroupNonUniformBallotOp::verify() { |
82 | spirv::Scope scope = getExecutionScope(); |
83 | if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) |
84 | return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'" ); |
85 | |
86 | return success(); |
87 | } |
88 | |
89 | //===----------------------------------------------------------------------===// |
90 | // spirv.GroupNonUniformBallotFindLSBOp |
91 | //===----------------------------------------------------------------------===// |
92 | |
93 | LogicalResult GroupNonUniformBallotFindLSBOp::verify() { |
94 | spirv::Scope scope = getExecutionScope(); |
95 | if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) |
96 | return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'" ); |
97 | |
98 | return success(); |
99 | } |
100 | |
101 | //===----------------------------------------------------------------------===// |
// spirv.GroupNonUniformBallotFindMSBOp
103 | //===----------------------------------------------------------------------===// |
104 | |
105 | LogicalResult GroupNonUniformBallotFindMSBOp::verify() { |
106 | spirv::Scope scope = getExecutionScope(); |
107 | if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) |
108 | return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'" ); |
109 | |
110 | return success(); |
111 | } |
112 | |
113 | //===----------------------------------------------------------------------===// |
114 | // spirv.GroupNonUniformBroadcast |
115 | //===----------------------------------------------------------------------===// |
116 | |
117 | LogicalResult GroupNonUniformBroadcastOp::verify() { |
118 | spirv::Scope scope = getExecutionScope(); |
119 | if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) |
120 | return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'" ); |
121 | |
122 | // SPIR-V spec: "Before version 1.5, Id must come from a |
123 | // constant instruction. |
124 | auto targetEnv = spirv::getDefaultTargetEnv(getContext()); |
125 | if (auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>()) |
126 | targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule); |
127 | |
128 | if (targetEnv.getVersion() < spirv::Version::V_1_5) { |
129 | auto *idOp = getId().getDefiningOp(); |
130 | if (!idOp || !isa<spirv::ConstantOp, // for normal constant |
131 | spirv::ReferenceOfOp>(idOp)) // for spec constant |
132 | return emitOpError("id must be the result of a constant op" ); |
133 | } |
134 | |
135 | return success(); |
136 | } |
137 | |
138 | //===----------------------------------------------------------------------===// |
139 | // spirv.GroupNonUniformShuffle* |
140 | //===----------------------------------------------------------------------===// |
141 | |
142 | template <typename OpTy> |
143 | static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op) { |
144 | spirv::Scope scope = op.getExecutionScope(); |
145 | if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) |
146 | return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'" ); |
147 | |
148 | if (op.getOperands().back().getType().isSignedInteger()) |
149 | return op.emitOpError("second operand must be a singless/unsigned integer" ); |
150 | |
151 | return success(); |
152 | } |
153 | |
154 | LogicalResult GroupNonUniformShuffleOp::verify() { |
155 | return verifyGroupNonUniformShuffleOp(*this); |
156 | } |
157 | LogicalResult GroupNonUniformShuffleDownOp::verify() { |
158 | return verifyGroupNonUniformShuffleOp(*this); |
159 | } |
160 | LogicalResult GroupNonUniformShuffleUpOp::verify() { |
161 | return verifyGroupNonUniformShuffleOp(*this); |
162 | } |
163 | LogicalResult GroupNonUniformShuffleXorOp::verify() { |
164 | return verifyGroupNonUniformShuffleOp(*this); |
165 | } |
166 | |
167 | //===----------------------------------------------------------------------===// |
168 | // spirv.GroupNonUniformElectOp |
169 | //===----------------------------------------------------------------------===// |
170 | |
171 | LogicalResult GroupNonUniformElectOp::verify() { |
172 | spirv::Scope scope = getExecutionScope(); |
173 | if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) |
174 | return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'" ); |
175 | |
176 | return success(); |
177 | } |
178 | |
179 | //===----------------------------------------------------------------------===// |
180 | // spirv.GroupNonUniformFAddOp |
181 | //===----------------------------------------------------------------------===// |
182 | |
183 | LogicalResult GroupNonUniformFAddOp::verify() { |
184 | return verifyGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(*this); |
185 | } |
186 | |
187 | //===----------------------------------------------------------------------===// |
188 | // spirv.GroupNonUniformFMaxOp |
189 | //===----------------------------------------------------------------------===// |
190 | |
191 | LogicalResult GroupNonUniformFMaxOp::verify() { |
192 | return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(*this); |
193 | } |
194 | |
195 | //===----------------------------------------------------------------------===// |
196 | // spirv.GroupNonUniformFMinOp |
197 | //===----------------------------------------------------------------------===// |
198 | |
199 | LogicalResult GroupNonUniformFMinOp::verify() { |
200 | return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(*this); |
201 | } |
202 | |
203 | //===----------------------------------------------------------------------===// |
204 | // spirv.GroupNonUniformFMulOp |
205 | //===----------------------------------------------------------------------===// |
206 | |
207 | LogicalResult GroupNonUniformFMulOp::verify() { |
208 | return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(*this); |
209 | } |
210 | |
211 | //===----------------------------------------------------------------------===// |
212 | // spirv.GroupNonUniformIAddOp |
213 | //===----------------------------------------------------------------------===// |
214 | |
215 | LogicalResult GroupNonUniformIAddOp::verify() { |
216 | return verifyGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(*this); |
217 | } |
218 | |
219 | //===----------------------------------------------------------------------===// |
220 | // spirv.GroupNonUniformIMulOp |
221 | //===----------------------------------------------------------------------===// |
222 | |
223 | LogicalResult GroupNonUniformIMulOp::verify() { |
224 | return verifyGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(*this); |
225 | } |
226 | |
227 | //===----------------------------------------------------------------------===// |
228 | // spirv.GroupNonUniformSMaxOp |
229 | //===----------------------------------------------------------------------===// |
230 | |
231 | LogicalResult GroupNonUniformSMaxOp::verify() { |
232 | return verifyGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(*this); |
233 | } |
234 | |
235 | //===----------------------------------------------------------------------===// |
236 | // spirv.GroupNonUniformSMinOp |
237 | //===----------------------------------------------------------------------===// |
238 | |
239 | LogicalResult GroupNonUniformSMinOp::verify() { |
240 | return verifyGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(*this); |
241 | } |
242 | |
243 | //===----------------------------------------------------------------------===// |
244 | // spirv.GroupNonUniformUMaxOp |
245 | //===----------------------------------------------------------------------===// |
246 | |
247 | LogicalResult GroupNonUniformUMaxOp::verify() { |
248 | return verifyGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(*this); |
249 | } |
250 | |
251 | //===----------------------------------------------------------------------===// |
252 | // spirv.GroupNonUniformUMinOp |
253 | //===----------------------------------------------------------------------===// |
254 | |
255 | LogicalResult GroupNonUniformUMinOp::verify() { |
256 | return verifyGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(*this); |
257 | } |
258 | |
259 | //===----------------------------------------------------------------------===// |
260 | // spirv.GroupNonUniformBitwiseAnd |
261 | //===----------------------------------------------------------------------===// |
262 | |
263 | LogicalResult GroupNonUniformBitwiseAndOp::verify() { |
264 | return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(*this); |
265 | } |
266 | |
267 | //===----------------------------------------------------------------------===// |
268 | // spirv.GroupNonUniformBitwiseOr |
269 | //===----------------------------------------------------------------------===// |
270 | |
271 | LogicalResult GroupNonUniformBitwiseOrOp::verify() { |
272 | return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(*this); |
273 | } |
274 | |
275 | //===----------------------------------------------------------------------===// |
276 | // spirv.GroupNonUniformBitwiseXor |
277 | //===----------------------------------------------------------------------===// |
278 | |
279 | LogicalResult GroupNonUniformBitwiseXorOp::verify() { |
280 | return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(*this); |
281 | } |
282 | |
283 | //===----------------------------------------------------------------------===// |
284 | // spirv.GroupNonUniformLogicalAnd |
285 | //===----------------------------------------------------------------------===// |
286 | |
287 | LogicalResult GroupNonUniformLogicalAndOp::verify() { |
288 | return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(*this); |
289 | } |
290 | |
291 | //===----------------------------------------------------------------------===// |
292 | // spirv.GroupNonUniformLogicalOr |
293 | //===----------------------------------------------------------------------===// |
294 | |
295 | LogicalResult GroupNonUniformLogicalOrOp::verify() { |
296 | return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(*this); |
297 | } |
298 | |
299 | //===----------------------------------------------------------------------===// |
300 | // spirv.GroupNonUniformLogicalXor |
301 | //===----------------------------------------------------------------------===// |
302 | |
303 | LogicalResult GroupNonUniformLogicalXorOp::verify() { |
304 | return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(*this); |
305 | } |
306 | |
307 | //===----------------------------------------------------------------------===// |
308 | // spirv.GroupNonUniformRotateKHR |
309 | //===----------------------------------------------------------------------===// |
310 | |
311 | LogicalResult GroupNonUniformRotateKHROp::verify() { |
312 | spirv::Scope scope = getExecutionScope(); |
313 | if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) |
314 | return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'" ); |
315 | |
316 | if (Value clusterSizeVal = getClusterSize()) { |
317 | mlir::Operation *defOp = clusterSizeVal.getDefiningOp(); |
318 | int32_t clusterSize = 0; |
319 | |
320 | if (failed(extractValueFromConstOp(defOp, clusterSize))) |
321 | return emitOpError("cluster size operand must come from a constant op" ); |
322 | |
323 | if (!llvm::isPowerOf2_32(clusterSize)) |
324 | return emitOpError("cluster size operand must be a power of two" ); |
325 | } |
326 | |
327 | return success(); |
328 | } |
329 | |
330 | //===----------------------------------------------------------------------===// |
331 | // Group op verification |
332 | //===----------------------------------------------------------------------===// |
333 | |
334 | template <typename Op> |
335 | static LogicalResult verifyGroupOp(Op op) { |
336 | spirv::Scope scope = op.getExecutionScope(); |
337 | if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) |
338 | return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'" ); |
339 | |
340 | return success(); |
341 | } |
342 | |
343 | LogicalResult GroupIAddOp::verify() { return verifyGroupOp(*this); } |
344 | |
345 | LogicalResult GroupFAddOp::verify() { return verifyGroupOp(*this); } |
346 | |
347 | LogicalResult GroupFMinOp::verify() { return verifyGroupOp(*this); } |
348 | |
349 | LogicalResult GroupUMinOp::verify() { return verifyGroupOp(*this); } |
350 | |
351 | LogicalResult GroupSMinOp::verify() { return verifyGroupOp(*this); } |
352 | |
353 | LogicalResult GroupFMaxOp::verify() { return verifyGroupOp(*this); } |
354 | |
355 | LogicalResult GroupUMaxOp::verify() { return verifyGroupOp(*this); } |
356 | |
357 | LogicalResult GroupSMaxOp::verify() { return verifyGroupOp(*this); } |
358 | |
359 | LogicalResult GroupIMulKHROp::verify() { return verifyGroupOp(*this); } |
360 | |
361 | LogicalResult GroupFMulKHROp::verify() { return verifyGroupOp(*this); } |
362 | |
363 | } // namespace mlir::spirv |
364 | |