1//===- ControlFlowOps.cpp - MLIR SPIR-V Control Flow Ops -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Defines the control flow operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
14#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
15#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
16#include "mlir/Interfaces/CallInterfaces.h"
17
18#include "llvm/Support/InterleavedRange.h"
19
20#include "SPIRVOpUtils.h"
21#include "SPIRVParsingUtils.h"
22
23using namespace mlir::spirv::AttrNames;
24
25namespace mlir::spirv {
26
27/// Parses Function, Selection and Loop control attributes. If no control is
28/// specified, "None" is used as a default.
29template <typename EnumAttrClass, typename EnumClass>
30static ParseResult
31parseControlAttribute(OpAsmParser &parser, OperationState &state,
32 StringRef attrName = spirv::attributeName<EnumClass>()) {
33 if (succeeded(Result: parser.parseOptionalKeyword(keyword: kControl))) {
34 EnumClass control;
35 if (parser.parseLParen() ||
36 spirv::parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
37 parser.parseRParen())
38 return failure();
39 return success();
40 }
41 // Set control to "None" otherwise.
42 Builder builder = parser.getBuilder();
43 state.addAttribute(attrName,
44 builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0)));
45 return success();
46}
47
48//===----------------------------------------------------------------------===//
49// spirv.BranchOp
50//===----------------------------------------------------------------------===//
51
52SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
53 assert(index == 0 && "invalid successor index");
54 return SuccessorOperands(0, getTargetOperandsMutable());
55}
56
57//===----------------------------------------------------------------------===//
58// spirv.BranchConditionalOp
59//===----------------------------------------------------------------------===//
60
61SuccessorOperands BranchConditionalOp::getSuccessorOperands(unsigned index) {
62 assert(index < 2 && "invalid successor index");
63 return SuccessorOperands(index == kTrueIndex
64 ? getTrueTargetOperandsMutable()
65 : getFalseTargetOperandsMutable());
66}
67
68ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
69 OperationState &result) {
70 auto &builder = parser.getBuilder();
71 OpAsmParser::UnresolvedOperand condInfo;
72 Block *dest;
73
74 // Parse the condition.
75 Type boolTy = builder.getI1Type();
76 if (parser.parseOperand(condInfo) ||
77 parser.resolveOperand(condInfo, boolTy, result.operands))
78 return failure();
79
80 // Parse the optional branch weights.
81 if (succeeded(parser.parseOptionalLSquare())) {
82 IntegerAttr trueWeight, falseWeight;
83 NamedAttrList weights;
84
85 auto i32Type = builder.getIntegerType(32);
86 if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
87 parser.parseComma() ||
88 parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
89 parser.parseRSquare())
90 return failure();
91
92 StringAttr branchWeightsAttrName =
93 BranchConditionalOp::getBranchWeightsAttrName(result.name);
94 result.addAttribute(branchWeightsAttrName,
95 builder.getArrayAttr({trueWeight, falseWeight}));
96 }
97
98 // Parse the true branch.
99 SmallVector<Value, 4> trueOperands;
100 if (parser.parseComma() ||
101 parser.parseSuccessorAndUseList(dest, trueOperands))
102 return failure();
103 result.addSuccessors(dest);
104 result.addOperands(trueOperands);
105
106 // Parse the false branch.
107 SmallVector<Value, 4> falseOperands;
108 if (parser.parseComma() ||
109 parser.parseSuccessorAndUseList(dest, falseOperands))
110 return failure();
111 result.addSuccessors(dest);
112 result.addOperands(falseOperands);
113 result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
114 builder.getDenseI32ArrayAttr(
115 {1, static_cast<int32_t>(trueOperands.size()),
116 static_cast<int32_t>(falseOperands.size())}));
117
118 return success();
119}
120
121void BranchConditionalOp::print(OpAsmPrinter &printer) {
122 printer << ' ' << getCondition();
123
124 if (std::optional<ArrayAttr> weights = getBranchWeights()) {
125 printer << ' '
126 << llvm::interleaved_array(weights->getAsValueRange<IntegerAttr>());
127 }
128
129 printer << ", ";
130 printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
131 printer << ", ";
132 printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
133}
134
135LogicalResult BranchConditionalOp::verify() {
136 if (auto weights = getBranchWeights()) {
137 if (weights->getValue().size() != 2) {
138 return emitOpError("must have exactly two branch weights");
139 }
140 if (llvm::all_of(*weights, [](Attribute attr) {
141 return llvm::cast<IntegerAttr>(attr).getValue().isZero();
142 }))
143 return emitOpError("branch weights cannot both be zero");
144 }
145
146 return success();
147}
148
149//===----------------------------------------------------------------------===//
150// spirv.FunctionCall
151//===----------------------------------------------------------------------===//
152
153LogicalResult FunctionCallOp::verify() {
154 auto fnName = getCalleeAttr();
155
156 auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
157 SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName));
158 if (!funcOp) {
159 return emitOpError("callee function '")
160 << fnName.getValue() << "' not found in nearest symbol table";
161 }
162
163 auto functionType = funcOp.getFunctionType();
164
165 if (getNumResults() > 1) {
166 return emitOpError(
167 "expected callee function to have 0 or 1 result, but provided ")
168 << getNumResults();
169 }
170
171 if (functionType.getNumInputs() != getNumOperands()) {
172 return emitOpError("has incorrect number of operands for callee: expected ")
173 << functionType.getNumInputs() << ", but provided "
174 << getNumOperands();
175 }
176
177 for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
178 if (getOperand(i).getType() != functionType.getInput(i)) {
179 return emitOpError("operand type mismatch: expected operand type ")
180 << functionType.getInput(i) << ", but provided "
181 << getOperand(i).getType() << " for operand number " << i;
182 }
183 }
184
185 if (functionType.getNumResults() != getNumResults()) {
186 return emitOpError(
187 "has incorrect number of results has for callee: expected ")
188 << functionType.getNumResults() << ", but provided "
189 << getNumResults();
190 }
191
192 if (getNumResults() &&
193 (getResult(0).getType() != functionType.getResult(0))) {
194 return emitOpError("result type mismatch: expected ")
195 << functionType.getResult(0) << ", but provided "
196 << getResult(0).getType();
197 }
198
199 return success();
200}
201
202CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
203 return (*this)->getAttrOfType<SymbolRefAttr>(getCalleeAttrName());
204}
205
206void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
207 (*this)->setAttr(getCalleeAttrName(), cast<SymbolRefAttr>(callee));
208}
209
210Operation::operand_range FunctionCallOp::getArgOperands() {
211 return getArguments();
212}
213
214MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
215 return getArgumentsMutable();
216}
217
218//===----------------------------------------------------------------------===//
219// spirv.mlir.loop
220//===----------------------------------------------------------------------===//
221
222void LoopOp::build(OpBuilder &builder, OperationState &state) {
223 state.addAttribute("loop_control", builder.getAttr<spirv::LoopControlAttr>(
224 spirv::LoopControl::None));
225 state.addRegion();
226}
227
228ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
229 if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
230 result))
231 return failure();
232
233 if (succeeded(parser.parseOptionalArrow()))
234 if (parser.parseTypeList(result.types))
235 return failure();
236
237 return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
238}
239
240void LoopOp::print(OpAsmPrinter &printer) {
241 auto control = getLoopControl();
242 if (control != spirv::LoopControl::None)
243 printer << " control(" << spirv::stringifyLoopControl(control) << ")";
244 if (getNumResults() > 0) {
245 printer << " -> ";
246 printer << getResultTypes();
247 }
248 printer << ' ';
249 printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
250 /*printBlockTerminators=*/true);
251}
252
253/// Returns true if the given `srcBlock` contains only one `spirv.Branch` to the
254/// given `dstBlock`.
255static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
256 // Check that there is only one op in the `srcBlock`.
257 if (!llvm::hasSingleElement(C&: srcBlock))
258 return false;
259
260 auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
261 return branchOp && branchOp.getSuccessor() == &dstBlock;
262}
263
264/// Returns true if the given `block` only contains one `spirv.mlir.merge` op.
265static bool isMergeBlock(Block &block) {
266 return llvm::hasSingleElement(block) && isa<spirv::MergeOp>(block.front());
267}
268
269/// Returns true if a `spirv.mlir.merge` op outside the merge block.
270static bool hasOtherMerge(Region &region) {
271 return !region.empty() && llvm::any_of(Range: region.getOps(), P: [&](Operation &op) {
272 return isa<spirv::MergeOp>(op) && op.getBlock() != &region.back();
273 });
274}
275
276LogicalResult LoopOp::verifyRegions() {
277 auto *op = getOperation();
278
279 // We need to verify that the blocks follow the following layout:
280 //
281 // +-------------+
282 // | entry block |
283 // +-------------+
284 // |
285 // v
286 // +-------------+
287 // | loop header | <-----+
288 // +-------------+ |
289 // |
290 // ... |
291 // \ | / |
292 // v |
293 // +---------------+ |
294 // | loop continue | -----+
295 // +---------------+
296 //
297 // ...
298 // \ | /
299 // v
300 // +-------------+
301 // | merge block |
302 // +-------------+
303
304 auto &region = op->getRegion(0);
305 // Allow empty region as a degenerated case, which can come from
306 // optimizations.
307 if (region.empty())
308 return success();
309
310 // The last block is the merge block.
311 Block &merge = region.back();
312 if (!isMergeBlock(merge))
313 return emitOpError("last block must be the merge block with only one "
314 "'spirv.mlir.merge' op");
315 if (hasOtherMerge(region))
316 return emitOpError(
317 "should not have 'spirv.mlir.merge' op outside the merge block");
318
319 if (region.hasOneBlock())
320 return emitOpError(
321 "must have an entry block branching to the loop header block");
322 // The first block is the entry block.
323 Block &entry = region.front();
324
325 if (std::next(region.begin(), 2) == region.end())
326 return emitOpError(
327 "must have a loop header block branched from the entry block");
328 // The second block is the loop header block.
329 Block &header = *std::next(region.begin(), 1);
330
331 if (!hasOneBranchOpTo(entry, header))
332 return emitOpError(
333 "entry block must only have one 'spirv.Branch' op to the second block");
334
335 if (std::next(region.begin(), 3) == region.end())
336 return emitOpError(
337 "requires a loop continue block branching to the loop header block");
338 // The second to last block is the loop continue block.
339 Block &cont = *std::prev(region.end(), 2);
340
341 // Make sure that we have a branch from the loop continue block to the loop
342 // header block.
343 if (llvm::none_of(
344 llvm::seq<unsigned>(0, cont.getNumSuccessors()),
345 [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
346 return emitOpError("second to last block must be the loop continue "
347 "block that branches to the loop header block");
348
349 // Make sure that no other blocks (except the entry and loop continue block)
350 // branches to the loop header block.
351 for (auto &block : llvm::make_range(std::next(region.begin(), 2),
352 std::prev(region.end(), 2))) {
353 for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
354 if (block.getSuccessor(i) == &header) {
355 return emitOpError("can only have the entry and loop continue "
356 "block branching to the loop header block");
357 }
358 }
359 }
360
361 return success();
362}
363
364Block *LoopOp::getEntryBlock() {
365 assert(!getBody().empty() && "op region should not be empty!");
366 return &getBody().front();
367}
368
369Block *LoopOp::getHeaderBlock() {
370 assert(!getBody().empty() && "op region should not be empty!");
371 // The second block is the loop header block.
372 return &*std::next(getBody().begin());
373}
374
375Block *LoopOp::getContinueBlock() {
376 assert(!getBody().empty() && "op region should not be empty!");
377 // The second to last block is the loop continue block.
378 return &*std::prev(getBody().end(), 2);
379}
380
381Block *LoopOp::getMergeBlock() {
382 assert(!getBody().empty() && "op region should not be empty!");
383 // The last block is the loop merge block.
384 return &getBody().back();
385}
386
387void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) {
388 assert(getBody().empty() && "entry and merge block already exist");
389 OpBuilder::InsertionGuard g(builder);
390 builder.createBlock(&getBody());
391 builder.createBlock(&getBody());
392
393 // Add a spirv.mlir.merge op into the merge block.
394 builder.create<spirv::MergeOp>(getLoc());
395}
396
397//===----------------------------------------------------------------------===//
398// spirv.Return
399//===----------------------------------------------------------------------===//
400
401LogicalResult ReturnOp::verify() {
402 // Verification is performed in spirv.func op.
403 return success();
404}
405
406//===----------------------------------------------------------------------===//
407// spirv.ReturnValue
408//===----------------------------------------------------------------------===//
409
410LogicalResult ReturnValueOp::verify() {
411 // Verification is performed in spirv.func op.
412 return success();
413}
414
415//===----------------------------------------------------------------------===//
416// spirv.Select
417//===----------------------------------------------------------------------===//
418
419LogicalResult SelectOp::verify() {
420 if (auto conditionTy = llvm::dyn_cast<VectorType>(getCondition().getType())) {
421 auto resultVectorTy = llvm::dyn_cast<VectorType>(getResult().getType());
422 if (!resultVectorTy) {
423 return emitOpError("result expected to be of vector type when "
424 "condition is of vector type");
425 }
426 if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
427 return emitOpError("result should have the same number of elements as "
428 "the condition when condition is of vector type");
429 }
430 }
431 return success();
432}
433
434// Custom availability implementation is needed for spirv.Select given the
435// syntax changes starting v1.4.
436SmallVector<ArrayRef<spirv::Extension>, 1> SelectOp::getExtensions() {
437 return {};
438}
439SmallVector<ArrayRef<spirv::Capability>, 1> SelectOp::getCapabilities() {
440 return {};
441}
442std::optional<spirv::Version> SelectOp::getMinVersion() {
443 // Per the spec, "Before version 1.4, results are only computed per
444 // component."
445 if (isa<spirv::ScalarType>(getCondition().getType()) &&
446 isa<spirv::CompositeType>(getType()))
447 return Version::V_1_4;
448
449 return Version::V_1_0;
450}
451std::optional<spirv::Version> SelectOp::getMaxVersion() {
452 return Version::V_1_6;
453}
454
455//===----------------------------------------------------------------------===//
456// spirv.mlir.selection
457//===----------------------------------------------------------------------===//
458
459ParseResult SelectionOp::parse(OpAsmParser &parser, OperationState &result) {
460 if (parseControlAttribute<spirv::SelectionControlAttr,
461 spirv::SelectionControl>(parser, result))
462 return failure();
463
464 if (succeeded(parser.parseOptionalArrow()))
465 if (parser.parseTypeList(result.types))
466 return failure();
467
468 return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
469}
470
471void SelectionOp::print(OpAsmPrinter &printer) {
472 auto control = getSelectionControl();
473 if (control != spirv::SelectionControl::None)
474 printer << " control(" << spirv::stringifySelectionControl(control) << ")";
475 if (getNumResults() > 0) {
476 printer << " -> ";
477 printer << getResultTypes();
478 }
479 printer << ' ';
480 printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
481 /*printBlockTerminators=*/true);
482}
483
484LogicalResult SelectionOp::verifyRegions() {
485 auto *op = getOperation();
486
487 // We need to verify that the blocks follow the following layout:
488 //
489 // +--------------+
490 // | header block |
491 // +--------------+
492 // / | \
493 // ...
494 //
495 //
496 // +---------+ +---------+ +---------+
497 // | case #0 | | case #1 | | case #2 | ...
498 // +---------+ +---------+ +---------+
499 //
500 //
501 // ...
502 // \ | /
503 // v
504 // +-------------+
505 // | merge block |
506 // +-------------+
507
508 auto &region = op->getRegion(0);
509 // Allow empty region as a degenerated case, which can come from
510 // optimizations.
511 if (region.empty())
512 return success();
513
514 // The last block is the merge block.
515 if (!isMergeBlock(region.back()))
516 return emitOpError("last block must be the merge block with only one "
517 "'spirv.mlir.merge' op");
518 if (hasOtherMerge(region))
519 return emitOpError(
520 "should not have 'spirv.mlir.merge' op outside the merge block");
521
522 if (region.hasOneBlock())
523 return emitOpError("must have a selection header block");
524
525 return success();
526}
527
528Block *SelectionOp::getHeaderBlock() {
529 assert(!getBody().empty() && "op region should not be empty!");
530 // The first block is the loop header block.
531 return &getBody().front();
532}
533
534Block *SelectionOp::getMergeBlock() {
535 assert(!getBody().empty() && "op region should not be empty!");
536 // The last block is the loop merge block.
537 return &getBody().back();
538}
539
540void SelectionOp::addMergeBlock(OpBuilder &builder) {
541 assert(getBody().empty() && "entry and merge block already exist");
542 OpBuilder::InsertionGuard guard(builder);
543 builder.createBlock(&getBody());
544
545 // Add a spirv.mlir.merge op into the merge block.
546 builder.create<spirv::MergeOp>(getLoc());
547}
548
549SelectionOp
550SelectionOp::createIfThen(Location loc, Value condition,
551 function_ref<void(OpBuilder &builder)> thenBody,
552 OpBuilder &builder) {
553 auto selectionOp =
554 builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
555
556 selectionOp.addMergeBlock(builder);
557 Block *mergeBlock = selectionOp.getMergeBlock();
558 Block *thenBlock = nullptr;
559
560 // Build the "then" block.
561 {
562 OpBuilder::InsertionGuard guard(builder);
563 thenBlock = builder.createBlock(mergeBlock);
564 thenBody(builder);
565 builder.create<spirv::BranchOp>(loc, mergeBlock);
566 }
567
568 // Build the header block.
569 {
570 OpBuilder::InsertionGuard guard(builder);
571 builder.createBlock(thenBlock);
572 builder.create<spirv::BranchConditionalOp>(
573 loc, condition, thenBlock,
574 /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
575 /*falseArguments=*/ArrayRef<Value>());
576 }
577
578 return selectionOp;
579}
580
581//===----------------------------------------------------------------------===//
582// spirv.Unreachable
583//===----------------------------------------------------------------------===//
584
585LogicalResult spirv::UnreachableOp::verify() {
586 auto *block = (*this)->getBlock();
587 // Fast track: if this is in entry block, its invalid. Otherwise, if no
588 // predecessors, it's valid.
589 if (block->isEntryBlock())
590 return emitOpError("cannot be used in reachable block");
591 if (block->hasNoPredecessors())
592 return success();
593
594 // TODO: further verification needs to analyze reachability from
595 // the entry block.
596
597 return success();
598}
599
600} // namespace mlir::spirv
601

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp