1//===- GroupOps.cpp - MLIR SPIR-V Group Ops ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Defines the group operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
15
16#include "SPIRVOpUtils.h"
17#include "SPIRVParsingUtils.h"
18
19using namespace mlir::spirv::AttrNames;
20
21namespace mlir::spirv {
22
23template <typename OpTy>
24static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
25 OperationState &state) {
26 spirv::Scope executionScope;
27 GroupOperation groupOperation;
28 OpAsmParser::UnresolvedOperand valueInfo;
29 if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(
30 executionScope, parser, state,
31 OpTy::getExecutionScopeAttrName(state.name)) ||
32 spirv::parseEnumStrAttr<GroupOperationAttr>(
33 groupOperation, parser, state,
34 OpTy::getGroupOperationAttrName(state.name)) ||
35 parser.parseOperand(valueInfo))
36 return failure();
37
38 std::optional<OpAsmParser::UnresolvedOperand> clusterSizeInfo;
39 if (succeeded(result: parser.parseOptionalKeyword(keyword: kClusterSize))) {
40 clusterSizeInfo = OpAsmParser::UnresolvedOperand();
41 if (parser.parseLParen() || parser.parseOperand(result&: *clusterSizeInfo) ||
42 parser.parseRParen())
43 return failure();
44 }
45
46 Type resultType;
47 if (parser.parseColonType(result&: resultType))
48 return failure();
49
50 if (parser.resolveOperand(operand: valueInfo, type: resultType, result&: state.operands))
51 return failure();
52
53 if (clusterSizeInfo) {
54 Type i32Type = parser.getBuilder().getIntegerType(32);
55 if (parser.resolveOperand(operand: *clusterSizeInfo, type: i32Type, result&: state.operands))
56 return failure();
57 }
58
59 return parser.addTypeToList(type: resultType, result&: state.types);
60}
61
62template <typename GroupNonUniformArithmeticOpTy>
63static void printGroupNonUniformArithmeticOp(Operation *groupOp,
64 OpAsmPrinter &printer) {
65 printer
66 << " \""
67 << stringifyScope(
68 groupOp
69 ->getAttrOfType<spirv::ScopeAttr>(
70 GroupNonUniformArithmeticOpTy::getExecutionScopeAttrName(
71 groupOp->getName()))
72 .getValue())
73 << "\" \""
74 << stringifyGroupOperation(
75 groupOp
76 ->getAttrOfType<GroupOperationAttr>(
77 GroupNonUniformArithmeticOpTy::getGroupOperationAttrName(
78 groupOp->getName()))
79 .getValue())
80 << "\" " << groupOp->getOperand(0);
81
82 if (groupOp->getNumOperands() > 1)
83 printer << " " << kClusterSize << '(' << groupOp->getOperand(idx: 1) << ')';
84 printer << " : " << groupOp->getResult(idx: 0).getType();
85}
86
87template <typename OpTy>
88static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
89 spirv::Scope scope =
90 groupOp
91 ->getAttrOfType<spirv::ScopeAttr>(
92 OpTy::getExecutionScopeAttrName(groupOp->getName()))
93 .getValue();
94 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
95 return groupOp->emitOpError(
96 message: "execution scope must be 'Workgroup' or 'Subgroup'");
97
98 GroupOperation operation =
99 groupOp
100 ->getAttrOfType<GroupOperationAttr>(
101 OpTy::getGroupOperationAttrName(groupOp->getName()))
102 .getValue();
103 if (operation == GroupOperation::ClusteredReduce &&
104 groupOp->getNumOperands() == 1)
105 return groupOp->emitOpError(message: "cluster size operand must be provided for "
106 "'ClusteredReduce' group operation");
107 if (groupOp->getNumOperands() > 1) {
108 Operation *sizeOp = groupOp->getOperand(idx: 1).getDefiningOp();
109 int32_t clusterSize = 0;
110
111 // TODO: support specialization constant here.
112 if (failed(result: extractValueFromConstOp(op: sizeOp, value&: clusterSize)))
113 return groupOp->emitOpError(
114 message: "cluster size operand must come from a constant op");
115
116 if (!llvm::isPowerOf2_32(Value: clusterSize))
117 return groupOp->emitOpError(
118 message: "cluster size operand must be a power of two");
119 }
120 return success();
121}
122
123//===----------------------------------------------------------------------===//
124// spirv.GroupBroadcast
125//===----------------------------------------------------------------------===//
126
127LogicalResult GroupBroadcastOp::verify() {
128 spirv::Scope scope = getExecutionScope();
129 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
130 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
131
132 if (auto localIdTy = llvm::dyn_cast<VectorType>(getLocalid().getType()))
133 if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3)
134 return emitOpError("localid is a vector and can be with only "
135 " 2 or 3 components, actual number is ")
136 << localIdTy.getNumElements();
137
138 return success();
139}
140
141//===----------------------------------------------------------------------===//
142// spirv.GroupNonUniformBallotOp
143//===----------------------------------------------------------------------===//
144
145LogicalResult GroupNonUniformBallotOp::verify() {
146 spirv::Scope scope = getExecutionScope();
147 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
148 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
149
150 return success();
151}
152
153//===----------------------------------------------------------------------===//
154// spirv.GroupNonUniformBroadcast
155//===----------------------------------------------------------------------===//
156
157LogicalResult GroupNonUniformBroadcastOp::verify() {
158 spirv::Scope scope = getExecutionScope();
159 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
160 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
161
162 // SPIR-V spec: "Before version 1.5, Id must come from a
163 // constant instruction.
164 auto targetEnv = spirv::getDefaultTargetEnv(getContext());
165 if (auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>())
166 targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);
167
168 if (targetEnv.getVersion() < spirv::Version::V_1_5) {
169 auto *idOp = getId().getDefiningOp();
170 if (!idOp || !isa<spirv::ConstantOp, // for normal constant
171 spirv::ReferenceOfOp>(idOp)) // for spec constant
172 return emitOpError("id must be the result of a constant op");
173 }
174
175 return success();
176}
177
178//===----------------------------------------------------------------------===//
179// spirv.GroupNonUniformShuffle*
180//===----------------------------------------------------------------------===//
181
182template <typename OpTy>
183static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op) {
184 spirv::Scope scope = op.getExecutionScope();
185 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
186 return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
187
188 if (op.getOperands().back().getType().isSignedInteger())
189 return op.emitOpError("second operand must be a singless/unsigned integer");
190
191 return success();
192}
193
194LogicalResult GroupNonUniformShuffleOp::verify() {
195 return verifyGroupNonUniformShuffleOp(*this);
196}
197LogicalResult GroupNonUniformShuffleDownOp::verify() {
198 return verifyGroupNonUniformShuffleOp(*this);
199}
200LogicalResult GroupNonUniformShuffleUpOp::verify() {
201 return verifyGroupNonUniformShuffleOp(*this);
202}
203LogicalResult GroupNonUniformShuffleXorOp::verify() {
204 return verifyGroupNonUniformShuffleOp(*this);
205}
206
207//===----------------------------------------------------------------------===//
208// spirv.GroupNonUniformElectOp
209//===----------------------------------------------------------------------===//
210
211LogicalResult GroupNonUniformElectOp::verify() {
212 spirv::Scope scope = getExecutionScope();
213 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
214 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
215
216 return success();
217}
218
219//===----------------------------------------------------------------------===//
220// spirv.GroupNonUniformFAddOp
221//===----------------------------------------------------------------------===//
222
223LogicalResult GroupNonUniformFAddOp::verify() {
224 return verifyGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(*this);
225}
226
227ParseResult GroupNonUniformFAddOp::parse(OpAsmParser &parser,
228 OperationState &result) {
229 return parseGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(parser,
230 result);
231}
232
233void GroupNonUniformFAddOp::print(OpAsmPrinter &p) {
234 printGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(*this, p);
235}
236
237//===----------------------------------------------------------------------===//
238// spirv.GroupNonUniformFMaxOp
239//===----------------------------------------------------------------------===//
240
241LogicalResult GroupNonUniformFMaxOp::verify() {
242 return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(*this);
243}
244
245ParseResult GroupNonUniformFMaxOp::parse(OpAsmParser &parser,
246 OperationState &result) {
247 return parseGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(parser,
248 result);
249}
250
251void GroupNonUniformFMaxOp::print(OpAsmPrinter &p) {
252 printGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(*this, p);
253}
254
255//===----------------------------------------------------------------------===//
256// spirv.GroupNonUniformFMinOp
257//===----------------------------------------------------------------------===//
258
259LogicalResult GroupNonUniformFMinOp::verify() {
260 return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(*this);
261}
262
263ParseResult GroupNonUniformFMinOp::parse(OpAsmParser &parser,
264 OperationState &result) {
265 return parseGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(parser,
266 result);
267}
268
269void GroupNonUniformFMinOp::print(OpAsmPrinter &p) {
270 printGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(*this, p);
271}
272
273//===----------------------------------------------------------------------===//
274// spirv.GroupNonUniformFMulOp
275//===----------------------------------------------------------------------===//
276
277LogicalResult GroupNonUniformFMulOp::verify() {
278 return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(*this);
279}
280
281ParseResult GroupNonUniformFMulOp::parse(OpAsmParser &parser,
282 OperationState &result) {
283 return parseGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(parser,
284 result);
285}
286
287void GroupNonUniformFMulOp::print(OpAsmPrinter &p) {
288 printGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(*this, p);
289}
290
291//===----------------------------------------------------------------------===//
292// spirv.GroupNonUniformIAddOp
293//===----------------------------------------------------------------------===//
294
295LogicalResult GroupNonUniformIAddOp::verify() {
296 return verifyGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(*this);
297}
298
299ParseResult GroupNonUniformIAddOp::parse(OpAsmParser &parser,
300 OperationState &result) {
301 return parseGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(parser,
302 result);
303}
304
305void GroupNonUniformIAddOp::print(OpAsmPrinter &p) {
306 printGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(*this, p);
307}
308
309//===----------------------------------------------------------------------===//
310// spirv.GroupNonUniformIMulOp
311//===----------------------------------------------------------------------===//
312
313LogicalResult GroupNonUniformIMulOp::verify() {
314 return verifyGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(*this);
315}
316
317ParseResult GroupNonUniformIMulOp::parse(OpAsmParser &parser,
318 OperationState &result) {
319 return parseGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(parser,
320 result);
321}
322
323void GroupNonUniformIMulOp::print(OpAsmPrinter &p) {
324 printGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(*this, p);
325}
326
327//===----------------------------------------------------------------------===//
328// spirv.GroupNonUniformSMaxOp
329//===----------------------------------------------------------------------===//
330
331LogicalResult GroupNonUniformSMaxOp::verify() {
332 return verifyGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(*this);
333}
334
335ParseResult GroupNonUniformSMaxOp::parse(OpAsmParser &parser,
336 OperationState &result) {
337 return parseGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(parser,
338 result);
339}
340
341void GroupNonUniformSMaxOp::print(OpAsmPrinter &p) {
342 printGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(*this, p);
343}
344
345//===----------------------------------------------------------------------===//
346// spirv.GroupNonUniformSMinOp
347//===----------------------------------------------------------------------===//
348
349LogicalResult GroupNonUniformSMinOp::verify() {
350 return verifyGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(*this);
351}
352
353ParseResult GroupNonUniformSMinOp::parse(OpAsmParser &parser,
354 OperationState &result) {
355 return parseGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(parser,
356 result);
357}
358
359void GroupNonUniformSMinOp::print(OpAsmPrinter &p) {
360 printGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(*this, p);
361}
362
363//===----------------------------------------------------------------------===//
364// spirv.GroupNonUniformUMaxOp
365//===----------------------------------------------------------------------===//
366
367LogicalResult GroupNonUniformUMaxOp::verify() {
368 return verifyGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(*this);
369}
370
371ParseResult GroupNonUniformUMaxOp::parse(OpAsmParser &parser,
372 OperationState &result) {
373 return parseGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(parser,
374 result);
375}
376
377void GroupNonUniformUMaxOp::print(OpAsmPrinter &p) {
378 printGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(*this, p);
379}
380
381//===----------------------------------------------------------------------===//
382// spirv.GroupNonUniformUMinOp
383//===----------------------------------------------------------------------===//
384
385LogicalResult GroupNonUniformUMinOp::verify() {
386 return verifyGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(*this);
387}
388
389ParseResult GroupNonUniformUMinOp::parse(OpAsmParser &parser,
390 OperationState &result) {
391 return parseGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(parser,
392 result);
393}
394
395void GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
396 printGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(*this, p);
397}
398
399//===----------------------------------------------------------------------===//
400// spirv.GroupNonUniformBitwiseAnd
401//===----------------------------------------------------------------------===//
402
403LogicalResult GroupNonUniformBitwiseAndOp::verify() {
404 return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(*this);
405}
406
407ParseResult GroupNonUniformBitwiseAndOp::parse(OpAsmParser &parser,
408 OperationState &result) {
409 return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(parser,
410 result);
411}
412
413void GroupNonUniformBitwiseAndOp::print(OpAsmPrinter &p) {
414 printGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(*this, p);
415}
416
417//===----------------------------------------------------------------------===//
418// spirv.GroupNonUniformBitwiseOr
419//===----------------------------------------------------------------------===//
420
421LogicalResult GroupNonUniformBitwiseOrOp::verify() {
422 return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(*this);
423}
424
425ParseResult GroupNonUniformBitwiseOrOp::parse(OpAsmParser &parser,
426 OperationState &result) {
427 return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(parser,
428 result);
429}
430
431void GroupNonUniformBitwiseOrOp::print(OpAsmPrinter &p) {
432 printGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(*this, p);
433}
434
435//===----------------------------------------------------------------------===//
436// spirv.GroupNonUniformBitwiseXor
437//===----------------------------------------------------------------------===//
438
439LogicalResult GroupNonUniformBitwiseXorOp::verify() {
440 return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(*this);
441}
442
443ParseResult GroupNonUniformBitwiseXorOp::parse(OpAsmParser &parser,
444 OperationState &result) {
445 return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(parser,
446 result);
447}
448
449void GroupNonUniformBitwiseXorOp::print(OpAsmPrinter &p) {
450 printGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(*this, p);
451}
452
453//===----------------------------------------------------------------------===//
454// spirv.GroupNonUniformLogicalAnd
455//===----------------------------------------------------------------------===//
456
457LogicalResult GroupNonUniformLogicalAndOp::verify() {
458 return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(*this);
459}
460
461ParseResult GroupNonUniformLogicalAndOp::parse(OpAsmParser &parser,
462 OperationState &result) {
463 return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(parser,
464 result);
465}
466
467void GroupNonUniformLogicalAndOp::print(OpAsmPrinter &p) {
468 printGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(*this, p);
469}
470
471//===----------------------------------------------------------------------===//
472// spirv.GroupNonUniformLogicalOr
473//===----------------------------------------------------------------------===//
474
475LogicalResult GroupNonUniformLogicalOrOp::verify() {
476 return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(*this);
477}
478
479ParseResult GroupNonUniformLogicalOrOp::parse(OpAsmParser &parser,
480 OperationState &result) {
481 return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(parser,
482 result);
483}
484
485void GroupNonUniformLogicalOrOp::print(OpAsmPrinter &p) {
486 printGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(*this, p);
487}
488
489//===----------------------------------------------------------------------===//
490// spirv.GroupNonUniformLogicalXor
491//===----------------------------------------------------------------------===//
492
493LogicalResult GroupNonUniformLogicalXorOp::verify() {
494 return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(*this);
495}
496
497ParseResult GroupNonUniformLogicalXorOp::parse(OpAsmParser &parser,
498 OperationState &result) {
499 return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(parser,
500 result);
501}
502
503void GroupNonUniformLogicalXorOp::print(OpAsmPrinter &p) {
504 printGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(*this, p);
505}
506
507//===----------------------------------------------------------------------===//
508// Group op verification
509//===----------------------------------------------------------------------===//
510
511template <typename Op>
512static LogicalResult verifyGroupOp(Op op) {
513 spirv::Scope scope = op.getExecutionScope();
514 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
515 return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
516
517 return success();
518}
519
520LogicalResult GroupIAddOp::verify() { return verifyGroupOp(*this); }
521
522LogicalResult GroupFAddOp::verify() { return verifyGroupOp(*this); }
523
524LogicalResult GroupFMinOp::verify() { return verifyGroupOp(*this); }
525
526LogicalResult GroupUMinOp::verify() { return verifyGroupOp(*this); }
527
528LogicalResult GroupSMinOp::verify() { return verifyGroupOp(*this); }
529
530LogicalResult GroupFMaxOp::verify() { return verifyGroupOp(*this); }
531
532LogicalResult GroupUMaxOp::verify() { return verifyGroupOp(*this); }
533
534LogicalResult GroupSMaxOp::verify() { return verifyGroupOp(*this); }
535
536LogicalResult GroupIMulKHROp::verify() { return verifyGroupOp(*this); }
537
538LogicalResult GroupFMulKHROp::verify() { return verifyGroupOp(*this); }
539
540} // namespace mlir::spirv
541

// Source: mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp