1 | //===- SCF.cpp - Structured Control Flow Operations -----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/SCF/IR/SCF.h" |
10 | #include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h" |
11 | #include "mlir/Dialect/Arith/IR/Arith.h" |
12 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
13 | #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" |
14 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
15 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
16 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
17 | #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h" |
18 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
19 | #include "mlir/IR/BuiltinAttributes.h" |
20 | #include "mlir/IR/IRMapping.h" |
21 | #include "mlir/IR/Matchers.h" |
22 | #include "mlir/IR/PatternMatch.h" |
23 | #include "mlir/Interfaces/FunctionInterfaces.h" |
24 | #include "mlir/Interfaces/ValueBoundsOpInterface.h" |
25 | #include "mlir/Transforms/InliningUtils.h" |
26 | #include "llvm/ADT/MapVector.h" |
27 | #include "llvm/ADT/SmallPtrSet.h" |
28 | #include "llvm/ADT/TypeSwitch.h" |
29 | |
30 | using namespace mlir; |
31 | using namespace mlir::scf; |
32 | |
33 | #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc" |
34 | |
35 | //===----------------------------------------------------------------------===// |
36 | // SCFDialect Dialect Interfaces |
37 | //===----------------------------------------------------------------------===// |
38 | |
39 | namespace { |
40 | struct SCFInlinerInterface : public DialectInlinerInterface { |
41 | using DialectInlinerInterface::DialectInlinerInterface; |
42 | // We don't have any special restrictions on what can be inlined into |
43 | // destination regions (e.g. while/conditional bodies). Always allow it. |
44 | bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, |
45 | IRMapping &valueMapping) const final { |
46 | return true; |
47 | } |
  // Operations in the scf dialect are always legal to inline since they are
  // pure.
50 | bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { |
51 | return true; |
52 | } |
53 | // Handle the given inlined terminator by replacing it with a new operation |
54 | // as necessary. Required when the region has only one block. |
55 | void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { |
56 | auto retValOp = dyn_cast<scf::YieldOp>(op); |
57 | if (!retValOp) |
58 | return; |
59 | |
60 | for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) { |
61 | std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue)); |
62 | } |
63 | } |
64 | }; |
65 | } // namespace |
66 | |
67 | //===----------------------------------------------------------------------===// |
68 | // SCFDialect |
69 | //===----------------------------------------------------------------------===// |
70 | |
71 | void SCFDialect::initialize() { |
72 | addOperations< |
73 | #define GET_OP_LIST |
74 | #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc" |
75 | >(); |
76 | addInterfaces<SCFInlinerInterface>(); |
77 | declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>(); |
78 | declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface, |
79 | InParallelOp, ReduceReturnOp>(); |
80 | declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp, |
81 | ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp, |
82 | ForallOp, InParallelOp, WhileOp, YieldOp>(); |
83 | declarePromisedInterface<ValueBoundsOpInterface, ForOp>(); |
84 | } |
85 | |
86 | /// Default callback for IfOp builders. Inserts a yield without arguments. |
87 | void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) { |
88 | builder.create<scf::YieldOp>(loc); |
89 | } |
90 | |
91 | /// Verifies that the first block of the given `region` is terminated by a |
92 | /// TerminatorTy. Reports errors on the given operation if it is not the case. |
93 | template <typename TerminatorTy> |
94 | static TerminatorTy verifyAndGetTerminator(Operation *op, Region ®ion, |
95 | StringRef errorMessage) { |
96 | Operation *terminatorOperation = nullptr; |
97 | if (!region.empty() && !region.front().empty()) { |
98 | terminatorOperation = ®ion.front().back(); |
99 | if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation)) |
100 | return yield; |
101 | } |
  auto diag = op->emitOpError(errorMessage);
  if (terminatorOperation)
    diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
105 | return nullptr; |
106 | } |
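
// Illustrative sketch of a call site (names are hypothetical; actual callers
// are the verifiers of ops that require an scf.yield terminator):
//
//   if (!verifyAndGetTerminator<scf::YieldOp>(
//           someOp, someOp.getRegion(), "expected an scf.yield terminator"))
//     return failure();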
107 | |
108 | //===----------------------------------------------------------------------===// |
109 | // ExecuteRegionOp |
110 | //===----------------------------------------------------------------------===// |
111 | |
112 | /// Replaces the given op with the contents of the given single-block region, |
113 | /// using the operands of the block terminator to replace operation results. |
114 | static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, |
115 | Region ®ion, ValueRange blockArgs = {}) { |
  assert(llvm::hasSingleElement(region) && "expected single-block region");
117 | Block *block = ®ion.front(); |
118 | Operation *terminator = block->getTerminator(); |
119 | ValueRange results = terminator->getOperands(); |
  rewriter.inlineBlockBefore(block, op, blockArgs);
  rewriter.replaceOp(op, results);
  rewriter.eraseOp(terminator);
123 | } |
124 | |
125 | /// |
126 | /// (ssa-id `=`)? `execute_region` `->` function-result-type `{` |
127 | /// block+ |
128 | /// `}` |
129 | /// |
130 | /// Example: |
131 | /// scf.execute_region -> i32 { |
132 | /// %idx = load %rI[%i] : memref<128xi32> |
133 | /// return %idx : i32 |
134 | /// } |
135 | /// |
136 | ParseResult ExecuteRegionOp::parse(OpAsmParser &parser, |
137 | OperationState &result) { |
138 | if (parser.parseOptionalArrowTypeList(result.types)) |
139 | return failure(); |
140 | |
141 | // Introduce the body region and parse it. |
142 | Region *body = result.addRegion(); |
143 | if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) || |
144 | parser.parseOptionalAttrDict(result.attributes)) |
145 | return failure(); |
146 | |
147 | return success(); |
148 | } |
149 | |
150 | void ExecuteRegionOp::print(OpAsmPrinter &p) { |
151 | p.printOptionalArrowTypeList(getResultTypes()); |
152 | |
153 | p << ' '; |
154 | p.printRegion(getRegion(), |
155 | /*printEntryBlockArgs=*/false, |
156 | /*printBlockTerminators=*/true); |
157 | |
158 | p.printOptionalAttrDict((*this)->getAttrs()); |
159 | } |
160 | |
161 | LogicalResult ExecuteRegionOp::verify() { |
162 | if (getRegion().empty()) |
163 | return emitOpError("region needs to have at least one block"); |
164 | if (getRegion().front().getNumArguments() > 0) |
165 | return emitOpError("region cannot have any arguments"); |
166 | return success(); |
167 | } |
168 | |
169 | // Inline an ExecuteRegionOp if it only contains one block. |
170 | // "test.foo"() : () -> () |
171 | // %v = scf.execute_region -> i64 { |
172 | // %x = "test.val"() : () -> i64 |
173 | // scf.yield %x : i64 |
174 | // } |
175 | // "test.bar"(%v) : (i64) -> () |
176 | // |
177 | // becomes |
178 | // |
179 | // "test.foo"() : () -> () |
180 | // %x = "test.val"() : () -> i64 |
181 | // "test.bar"(%x) : (i64) -> () |
182 | // |
183 | struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> { |
184 | using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern; |
185 | |
186 | LogicalResult matchAndRewrite(ExecuteRegionOp op, |
187 | PatternRewriter &rewriter) const override { |
188 | if (!llvm::hasSingleElement(op.getRegion())) |
189 | return failure(); |
190 | replaceOpWithRegion(rewriter, op, op.getRegion()); |
191 | return success(); |
192 | } |
193 | }; |
194 | |
195 | // Inline an ExecuteRegionOp if its parent can contain multiple blocks. |
196 | // TODO generalize the conditions for operations which can be inlined into. |
197 | // func @func_execute_region_elim() { |
198 | // "test.foo"() : () -> () |
199 | // %v = scf.execute_region -> i64 { |
200 | // %c = "test.cmp"() : () -> i1 |
201 | // cf.cond_br %c, ^bb2, ^bb3 |
202 | // ^bb2: |
203 | // %x = "test.val1"() : () -> i64 |
204 | // cf.br ^bb4(%x : i64) |
205 | // ^bb3: |
206 | // %y = "test.val2"() : () -> i64 |
207 | // cf.br ^bb4(%y : i64) |
208 | // ^bb4(%z : i64): |
209 | // scf.yield %z : i64 |
210 | // } |
211 | // "test.bar"(%v) : (i64) -> () |
212 | // return |
213 | // } |
214 | // |
215 | // becomes |
216 | // |
217 | // func @func_execute_region_elim() { |
218 | // "test.foo"() : () -> () |
219 | // %c = "test.cmp"() : () -> i1 |
220 | // cf.cond_br %c, ^bb1, ^bb2 |
221 | // ^bb1: // pred: ^bb0 |
222 | // %x = "test.val1"() : () -> i64 |
223 | // cf.br ^bb3(%x : i64) |
224 | // ^bb2: // pred: ^bb0 |
225 | // %y = "test.val2"() : () -> i64 |
226 | // cf.br ^bb3(%y : i64) |
227 | // ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2 |
228 | // "test.bar"(%z) : (i64) -> () |
229 | // return |
230 | // } |
231 | // |
232 | struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> { |
233 | using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern; |
234 | |
235 | LogicalResult matchAndRewrite(ExecuteRegionOp op, |
236 | PatternRewriter &rewriter) const override { |
237 | if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp())) |
238 | return failure(); |
239 | |
240 | Block *prevBlock = op->getBlock(); |
    Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
242 | rewriter.setInsertionPointToEnd(prevBlock); |
243 | |
244 | rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front()); |
245 | |
246 | for (Block &blk : op.getRegion()) { |
247 | if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) { |
248 | rewriter.setInsertionPoint(yieldOp); |
249 | rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock, |
250 | yieldOp.getResults()); |
251 | rewriter.eraseOp(yieldOp); |
252 | } |
253 | } |
254 | |
255 | rewriter.inlineRegionBefore(op.getRegion(), postBlock); |
256 | SmallVector<Value> blockArgs; |
257 | |
258 | for (auto res : op.getResults()) |
259 | blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc())); |
260 | |
261 | rewriter.replaceOp(op, blockArgs); |
262 | return success(); |
263 | } |
264 | }; |
265 | |
266 | void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results, |
267 | MLIRContext *context) { |
268 | results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context); |
269 | } |
270 | |
271 | void ExecuteRegionOp::getSuccessorRegions( |
272 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
273 | // If the predecessor is the ExecuteRegionOp, branch into the body. |
274 | if (point.isParent()) { |
275 | regions.push_back(RegionSuccessor(&getRegion())); |
276 | return; |
277 | } |
278 | |
279 | // Otherwise, the region branches back to the parent operation. |
280 | regions.push_back(RegionSuccessor(getResults())); |
281 | } |
282 | |
283 | //===----------------------------------------------------------------------===// |
284 | // ConditionOp |
285 | //===----------------------------------------------------------------------===// |
286 | |
287 | MutableOperandRange |
288 | ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) { |
289 | assert((point.isParent() || point == getParentOp().getAfter()) && |
         "condition op can only exit the loop or branch to the after "
         "region");
292 | // Pass all operands except the condition to the successor region. |
293 | return getArgsMutable(); |
294 | } |
295 | |
296 | void ConditionOp::getSuccessorRegions( |
297 | ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> ®ions) { |
298 | FoldAdaptor adaptor(operands, *this); |
299 | |
300 | WhileOp whileOp = getParentOp(); |
301 | |
302 | // Condition can either lead to the after region or back to the parent op |
303 | // depending on whether the condition is true or not. |
304 | auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition()); |
305 | if (!boolAttr || boolAttr.getValue()) |
306 | regions.emplace_back(&whileOp.getAfter(), |
307 | whileOp.getAfter().getArguments()); |
308 | if (!boolAttr || !boolAttr.getValue()) |
309 | regions.emplace_back(whileOp.getResults()); |
310 | } |
311 | |
312 | //===----------------------------------------------------------------------===// |
313 | // ForOp |
314 | //===----------------------------------------------------------------------===// |
315 | |
316 | void ForOp::build(OpBuilder &builder, OperationState &result, Value lb, |
317 | Value ub, Value step, ValueRange initArgs, |
318 | BodyBuilderFn bodyBuilder) { |
319 | OpBuilder::InsertionGuard guard(builder); |
320 | |
321 | result.addOperands({lb, ub, step}); |
322 | result.addOperands(initArgs); |
323 | for (Value v : initArgs) |
324 | result.addTypes(v.getType()); |
325 | Type t = lb.getType(); |
326 | Region *bodyRegion = result.addRegion(); |
327 | Block *bodyBlock = builder.createBlock(bodyRegion); |
328 | bodyBlock->addArgument(t, result.location); |
329 | for (Value v : initArgs) |
330 | bodyBlock->addArgument(v.getType(), v.getLoc()); |
331 | |
332 | // Create the default terminator if the builder is not provided and if the |
333 | // iteration arguments are not provided. Otherwise, leave this to the caller |
334 | // because we don't know which values to return from the loop. |
335 | if (initArgs.empty() && !bodyBuilder) { |
336 | ForOp::ensureTerminator(*bodyRegion, builder, result.location); |
337 | } else if (bodyBuilder) { |
338 | OpBuilder::InsertionGuard guard(builder); |
339 | builder.setInsertionPointToStart(bodyBlock); |
340 | bodyBuilder(builder, result.location, bodyBlock->getArgument(0), |
341 | bodyBlock->getArguments().drop_front()); |
342 | } |
343 | } |
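
// Illustrative usage sketch of this builder (variable names are hypothetical):
// a loop with one loop-carried value must create its own terminator in the
// body-builder callback, e.g.
//
//   auto loop = builder.create<scf::ForOp>(
//       loc, lb, ub, step, ValueRange{init},
//       [](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
//         Value next = b.create<arith::AddFOp>(loc, iterArgs[0], iterArgs[0]);
//         b.create<scf::YieldOp>(loc, ValueRange{next});
//       });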
344 | |
345 | LogicalResult ForOp::verify() { |
346 | // Check that the number of init args and op results is the same. |
347 | if (getInitArgs().size() != getNumResults()) |
348 | return emitOpError( |
349 | "mismatch in number of loop-carried values and defined values"); |
350 | |
351 | return success(); |
352 | } |
353 | |
354 | LogicalResult ForOp::verifyRegions() { |
  // Check that the body defines a single block argument for the induction
  // variable.
357 | if (getInductionVar().getType() != getLowerBound().getType()) |
358 | return emitOpError( |
359 | "expected induction variable to be same type as bounds and step"); |
360 | |
361 | if (getNumRegionIterArgs() != getNumResults()) |
362 | return emitOpError( |
363 | "mismatch in number of basic block args and defined values"); |
364 | |
365 | auto initArgs = getInitArgs(); |
366 | auto iterArgs = getRegionIterArgs(); |
367 | auto opResults = getResults(); |
368 | unsigned i = 0; |
369 | for (auto e : llvm::zip(initArgs, iterArgs, opResults)) { |
370 | if (std::get<0>(e).getType() != std::get<2>(e).getType()) |
      return emitOpError() << "types mismatch between " << i
                           << "th iter operand and defined value";
    if (std::get<1>(e).getType() != std::get<2>(e).getType())
      return emitOpError() << "types mismatch between " << i
                           << "th iter region arg and defined value";
376 | |
377 | ++i; |
378 | } |
379 | return success(); |
380 | } |
381 | |
382 | std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() { |
383 | return SmallVector<Value>{getInductionVar()}; |
384 | } |
385 | |
386 | std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() { |
387 | return SmallVector<OpFoldResult>{OpFoldResult(getLowerBound())}; |
388 | } |
389 | |
390 | std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() { |
391 | return SmallVector<OpFoldResult>{OpFoldResult(getStep())}; |
392 | } |
393 | |
394 | std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() { |
395 | return SmallVector<OpFoldResult>{OpFoldResult(getUpperBound())}; |
396 | } |
397 | |
398 | std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); } |
399 | |
/// Promotes the loop body of a forOp to its containing block if it can be
/// determined that the loop has a single iteration.
402 | LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) { |
403 | std::optional<int64_t> tripCount = |
404 | constantTripCount(getLowerBound(), getUpperBound(), getStep()); |
405 | if (!tripCount.has_value() || tripCount != 1) |
406 | return failure(); |
407 | |
408 | // Replace all results with the yielded values. |
409 | auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator()); |
410 | rewriter.replaceAllUsesWith(getResults(), getYieldedValues()); |
411 | |
412 | // Replace block arguments with lower bound (replacement for IV) and |
413 | // iter_args. |
414 | SmallVector<Value> bbArgReplacements; |
415 | bbArgReplacements.push_back(getLowerBound()); |
416 | llvm::append_range(bbArgReplacements, getInitArgs()); |
417 | |
418 | // Move the loop body operations to the loop's containing block. |
419 | rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(), |
420 | getOperation()->getIterator(), bbArgReplacements); |
421 | |
422 | // Erase the old terminator and the loop. |
423 | rewriter.eraseOp(yieldOp); |
424 | rewriter.eraseOp(*this); |
425 | |
426 | return success(); |
427 | } |
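
// For illustration (hedged example, not taken from a test), a loop with a
// statically known trip count of one:
//
//   %r = scf.for %i = %c0 to %c1 step %c1 iter_args(%a = %init) -> (f32) {
//     %v = "test.body"(%i, %a) : (index, f32) -> f32
//     scf.yield %v : f32
//   }
//
// is promoted to
//
//   %v = "test.body"(%c0, %init) : (index, f32) -> f32
//
// with all uses of %r replaced by %v.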
428 | |
429 | /// Prints the initialization list in the form of |
430 | /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>) |
431 | /// where 'inner' values are assumed to be region arguments and 'outer' values |
432 | /// are regular SSA values. |
433 | static void printInitializationList(OpAsmPrinter &p, |
434 | Block::BlockArgListType blocksArgs, |
435 | ValueRange initializers, |
436 | StringRef prefix = "") { |
437 | assert(blocksArgs.size() == initializers.size() && |
438 | "expected same length of arguments and initializers"); |
439 | if (initializers.empty()) |
440 | return; |
441 | |
442 | p << prefix << '('; |
  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
    p << std::get<0>(it) << " = " << std::get<1>(it);
445 | }); |
446 | p << ")"; |
447 | } |
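
// For example (illustrative), with prefix " iter_args" and two initializers
// this prints:
//
//    iter_args(%arg1 = %init1, %arg2 = %init2)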
448 | |
449 | void ForOp::print(OpAsmPrinter &p) { |
  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
    << getUpperBound() << " step " << getStep();
452 | |
453 | printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args"); |
454 | if (!getInitArgs().empty()) |
    p << " -> (" << getInitArgs().getTypes() << ')';
  p << ' ';
  if (Type t = getInductionVar().getType(); !t.isIndex())
    p << " : " << t << ' ';
459 | p.printRegion(getRegion(), |
460 | /*printEntryBlockArgs=*/false, |
461 | /*printBlockTerminators=*/!getInitArgs().empty()); |
462 | p.printOptionalAttrDict((*this)->getAttrs()); |
463 | } |
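
// For reference, the custom form printed above (and parsed below) looks like
// the following illustrative example:
//
//   %res = scf.for %iv = %lb to %ub step %s iter_args(%a = %init) -> (f32) {
//     ...
//     scf.yield %next : f32
//   }
//
// with an optional ` : <type>` before the region when the induction variable
// is not of index type.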
464 | |
465 | ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) { |
466 | auto &builder = parser.getBuilder(); |
467 | Type type; |
468 | |
469 | OpAsmParser::Argument inductionVariable; |
470 | OpAsmParser::UnresolvedOperand lb, ub, step; |
471 | |
472 | // Parse the induction variable followed by '='. |
473 | if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() || |
474 | // Parse loop bounds. |
475 | parser.parseOperand(lb) || parser.parseKeyword("to") || |
476 | parser.parseOperand(ub) || parser.parseKeyword("step") || |
477 | parser.parseOperand(step)) |
478 | return failure(); |
479 | |
480 | // Parse the optional initial iteration arguments. |
481 | SmallVector<OpAsmParser::Argument, 4> regionArgs; |
482 | SmallVector<OpAsmParser::UnresolvedOperand, 4> operands; |
483 | regionArgs.push_back(inductionVariable); |
484 | |
485 | bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args")); |
486 | if (hasIterArgs) { |
487 | // Parse assignment list and results type list. |
488 | if (parser.parseAssignmentList(regionArgs, operands) || |
489 | parser.parseArrowTypeList(result.types)) |
490 | return failure(); |
491 | } |
492 | |
493 | if (regionArgs.size() != result.types.size() + 1) |
494 | return parser.emitError( |
495 | parser.getNameLoc(), |
496 | "mismatch in number of loop-carried values and defined values"); |
497 | |
498 | // Parse optional type, else assume Index. |
499 | if (parser.parseOptionalColon()) |
500 | type = builder.getIndexType(); |
501 | else if (parser.parseType(type)) |
502 | return failure(); |
503 | |
504 | // Set block argument types, so that they are known when parsing the region. |
505 | regionArgs.front().type = type; |
506 | for (auto [iterArg, type] : |
507 | llvm::zip_equal(llvm::drop_begin(regionArgs), result.types)) |
508 | iterArg.type = type; |
509 | |
510 | // Parse the body region. |
511 | Region *body = result.addRegion(); |
512 | if (parser.parseRegion(*body, regionArgs)) |
513 | return failure(); |
514 | ForOp::ensureTerminator(*body, builder, result.location); |
515 | |
516 | // Resolve input operands. This should be done after parsing the region to |
517 | // catch invalid IR where operands were defined inside of the region. |
518 | if (parser.resolveOperand(lb, type, result.operands) || |
519 | parser.resolveOperand(ub, type, result.operands) || |
520 | parser.resolveOperand(step, type, result.operands)) |
521 | return failure(); |
522 | if (hasIterArgs) { |
523 | for (auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs), |
524 | operands, result.types)) { |
525 | Type type = std::get<2>(argOperandType); |
526 | std::get<0>(argOperandType).type = type; |
527 | if (parser.resolveOperand(std::get<1>(argOperandType), type, |
528 | result.operands)) |
529 | return failure(); |
530 | } |
531 | } |
532 | |
533 | // Parse the optional attribute list. |
534 | if (parser.parseOptionalAttrDict(result.attributes)) |
535 | return failure(); |
536 | |
537 | return success(); |
538 | } |
539 | |
540 | SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; } |
541 | |
542 | Block::BlockArgListType ForOp::getRegionIterArgs() { |
543 | return getBody()->getArguments().drop_front(getNumInductionVars()); |
544 | } |
545 | |
546 | MutableArrayRef<OpOperand> ForOp::getInitsMutable() { |
547 | return getInitArgsMutable(); |
548 | } |
549 | |
550 | FailureOr<LoopLikeOpInterface> |
551 | ForOp::replaceWithAdditionalYields(RewriterBase &rewriter, |
552 | ValueRange newInitOperands, |
553 | bool replaceInitOperandUsesInLoop, |
554 | const NewYieldValuesFn &newYieldValuesFn) { |
555 | // Create a new loop before the existing one, with the extra operands. |
556 | OpBuilder::InsertionGuard g(rewriter); |
557 | rewriter.setInsertionPoint(getOperation()); |
558 | auto inits = llvm::to_vector(getInitArgs()); |
559 | inits.append(newInitOperands.begin(), newInitOperands.end()); |
560 | scf::ForOp newLoop = rewriter.create<scf::ForOp>( |
561 | getLoc(), getLowerBound(), getUpperBound(), getStep(), inits, |
562 | [](OpBuilder &, Location, Value, ValueRange) {}); |
563 | newLoop->setAttrs(getPrunedAttributeList(getOperation(), {})); |
564 | |
565 | // Generate the new yield values and append them to the scf.yield operation. |
566 | auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator()); |
567 | ArrayRef<BlockArgument> newIterArgs = |
568 | newLoop.getBody()->getArguments().take_back(newInitOperands.size()); |
569 | { |
570 | OpBuilder::InsertionGuard g(rewriter); |
571 | rewriter.setInsertionPoint(yieldOp); |
572 | SmallVector<Value> newYieldedValues = |
573 | newYieldValuesFn(rewriter, getLoc(), newIterArgs); |
574 | assert(newInitOperands.size() == newYieldedValues.size() && |
575 | "expected as many new yield values as new iter operands"); |
576 | rewriter.modifyOpInPlace(yieldOp, [&]() { |
577 | yieldOp.getResultsMutable().append(newYieldedValues); |
578 | }); |
579 | } |
580 | |
581 | // Move the loop body to the new op. |
582 | rewriter.mergeBlocks(getBody(), newLoop.getBody(), |
583 | newLoop.getBody()->getArguments().take_front( |
584 | getBody()->getNumArguments())); |
585 | |
586 | if (replaceInitOperandUsesInLoop) { |
587 | // Replace all uses of `newInitOperands` with the corresponding basic block |
588 | // arguments. |
589 | for (auto it : llvm::zip(newInitOperands, newIterArgs)) { |
590 | rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it), |
591 | [&](OpOperand &use) { |
592 | Operation *user = use.getOwner(); |
593 | return newLoop->isProperAncestor(user); |
594 | }); |
595 | } |
596 | } |
597 | |
598 | // Replace the old loop. |
599 | rewriter.replaceOp(getOperation(), |
600 | newLoop->getResults().take_front(getNumResults())); |
601 | return cast<LoopLikeOpInterface>(newLoop.getOperation()); |
602 | } |
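
// For illustration (hypothetical values), appending one additional iter_arg
// and yield to
//
//   %r = scf.for %i = %lb to %ub step %s iter_args(%a = %x) -> (f32) {...}
//
// produces
//
//   %r:2 = scf.for %i = %lb to %ub step %s iter_args(%a = %x, %b = %y)
//       -> (f32, f32) {...}
//
// where the new loop's leading results replace the results of the old loop.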
603 | |
604 | ForOp mlir::scf::getForInductionVarOwner(Value val) { |
  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
606 | if (!ivArg) |
607 | return ForOp(); |
608 | assert(ivArg.getOwner() && "unlinked block argument"); |
609 | auto *containingOp = ivArg.getOwner()->getParentOp(); |
610 | return dyn_cast_or_null<ForOp>(containingOp); |
611 | } |
612 | |
613 | OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) { |
614 | return getInitArgs(); |
615 | } |
616 | |
617 | void ForOp::getSuccessorRegions(RegionBranchPoint point, |
618 | SmallVectorImpl<RegionSuccessor> ®ions) { |
  // Both the operation itself and the region may be branching into the body
  // or back into the operation itself. It is possible for the loop not to
  // enter the body.
622 | regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); |
623 | regions.push_back(RegionSuccessor(getResults())); |
624 | } |
625 | |
626 | SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; } |
627 | |
628 | /// Promotes the loop body of a forallOp to its containing block if it can be |
629 | /// determined that the loop has a single iteration. |
630 | LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) { |
631 | for (auto [lb, ub, step] : |
632 | llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) { |
633 | auto tripCount = constantTripCount(lb, ub, step); |
634 | if (!tripCount.has_value() || *tripCount != 1) |
635 | return failure(); |
636 | } |
637 | |
638 | promote(rewriter, *this); |
639 | return success(); |
640 | } |
641 | |
642 | Block::BlockArgListType ForallOp::getRegionIterArgs() { |
643 | return getBody()->getArguments().drop_front(getRank()); |
644 | } |
645 | |
646 | MutableArrayRef<OpOperand> ForallOp::getInitsMutable() { |
647 | return getOutputsMutable(); |
648 | } |
649 | |
650 | /// Promotes the loop body of a scf::ForallOp to its containing block. |
651 | void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) { |
652 | OpBuilder::InsertionGuard g(rewriter); |
653 | scf::InParallelOp terminator = forallOp.getTerminator(); |
654 | |
655 | // Replace block arguments with lower bounds (replacements for IVs) and |
656 | // outputs. |
657 | SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter); |
658 | bbArgReplacements.append(forallOp.getOutputs().begin(), |
659 | forallOp.getOutputs().end()); |
660 | |
661 | // Move the loop body operations to the loop's containing block. |
662 | rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(), |
663 | forallOp->getIterator(), bbArgReplacements); |
664 | |
665 | // Replace the terminator with tensor.insert_slice ops. |
666 | rewriter.setInsertionPointAfter(forallOp); |
667 | SmallVector<Value> results; |
  results.reserve(forallOp.getResults().size());
669 | for (auto &yieldingOp : terminator.getYieldingOps()) { |
670 | auto parallelInsertSliceOp = |
671 | cast<tensor::ParallelInsertSliceOp>(yieldingOp); |
672 | |
673 | Value dst = parallelInsertSliceOp.getDest(); |
674 | Value src = parallelInsertSliceOp.getSource(); |
675 | if (llvm::isa<TensorType>(src.getType())) { |
676 | results.push_back(rewriter.create<tensor::InsertSliceOp>( |
677 | forallOp.getLoc(), dst.getType(), src, dst, |
678 | parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(), |
679 | parallelInsertSliceOp.getStrides(), |
680 | parallelInsertSliceOp.getStaticOffsets(), |
681 | parallelInsertSliceOp.getStaticSizes(), |
682 | parallelInsertSliceOp.getStaticStrides())); |
683 | } else { |
684 | llvm_unreachable("unsupported terminator"); |
685 | } |
686 | } |
687 | rewriter.replaceAllUsesWith(forallOp.getResults(), results); |
688 | |
689 | // Erase the old terminator and the loop. |
  rewriter.eraseOp(terminator);
  rewriter.eraseOp(forallOp);
692 | } |
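
// For illustration, promoting a single-iteration forall with one shared
// output looks roughly like:
//
//   %r = scf.forall (%i) in (1) shared_outs(%o = %init) -> (tensor<8xf32>) {
//     %t = "test.compute"(%o) : (tensor<8xf32>) -> tensor<4xf32>
//     scf.forall.in_parallel {
//       tensor.parallel_insert_slice %t into %o[0] [4] [1]
//           : tensor<4xf32> into tensor<8xf32>
//     }
//   }
//
// becoming
//
//   %t = "test.compute"(%init) : (tensor<8xf32>) -> tensor<4xf32>
//   %r = tensor.insert_slice %t into %init[0] [4] [1]
//       : tensor<4xf32> into tensor<8xf32>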
693 | |
694 | LoopNest mlir::scf::buildLoopNest( |
695 | OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, |
696 | ValueRange steps, ValueRange iterArgs, |
697 | function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> |
698 | bodyBuilder) { |
699 | assert(lbs.size() == ubs.size() && |
700 | "expected the same number of lower and upper bounds"); |
701 | assert(lbs.size() == steps.size() && |
702 | "expected the same number of lower bounds and steps"); |
703 | |
704 | // If there are no bounds, call the body-building function and return early. |
705 | if (lbs.empty()) { |
706 | ValueVector results = |
707 | bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs) |
708 | : ValueVector(); |
709 | assert(results.size() == iterArgs.size() && |
710 | "loop nest body must return as many values as loop has iteration " |
711 | "arguments"); |
712 | return LoopNest{{}, std::move(results)}; |
713 | } |
714 | |
715 | // First, create the loop structure iteratively using the body-builder |
716 | // callback of `ForOp::build`. Do not create `YieldOp`s yet. |
717 | OpBuilder::InsertionGuard guard(builder); |
718 | SmallVector<scf::ForOp, 4> loops; |
719 | SmallVector<Value, 4> ivs; |
720 | loops.reserve(lbs.size()); |
  ivs.reserve(lbs.size());
722 | ValueRange currentIterArgs = iterArgs; |
723 | Location currentLoc = loc; |
724 | for (unsigned i = 0, e = lbs.size(); i < e; ++i) { |
725 | auto loop = builder.create<scf::ForOp>( |
726 | currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs, |
727 | [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, |
728 | ValueRange args) { |
729 | ivs.push_back(iv); |
730 | // It is safe to store ValueRange args because it points to block |
731 | // arguments of a loop operation that we also own. |
732 | currentIterArgs = args; |
733 | currentLoc = nestedLoc; |
734 | }); |
735 | // Set the builder to point to the body of the newly created loop. We don't |
736 | // do this in the callback because the builder is reset when the callback |
737 | // returns. |
738 | builder.setInsertionPointToStart(loop.getBody()); |
739 | loops.push_back(loop); |
740 | } |
741 | |
742 | // For all loops but the innermost, yield the results of the nested loop. |
743 | for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) { |
744 | builder.setInsertionPointToEnd(loops[i].getBody()); |
745 | builder.create<scf::YieldOp>(loc, loops[i + 1].getResults()); |
746 | } |
747 | |
748 | // In the body of the innermost loop, call the body building function if any |
749 | // and yield its results. |
750 | builder.setInsertionPointToStart(loops.back().getBody()); |
751 | ValueVector results = bodyBuilder |
752 | ? bodyBuilder(builder, currentLoc, ivs, |
753 | loops.back().getRegionIterArgs()) |
754 | : ValueVector(); |
755 | assert(results.size() == iterArgs.size() && |
756 | "loop nest body must return as many values as loop has iteration " |
757 | "arguments"); |
758 | builder.setInsertionPointToEnd(loops.back().getBody()); |
759 | builder.create<scf::YieldOp>(loc, results); |
760 | |
761 | // Return the loops. |
762 | ValueVector nestResults; |
763 | llvm::append_range(nestResults, loops.front().getResults()); |
764 | return LoopNest{std::move(loops), std::move(nestResults)}; |
765 | } |
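
// Illustrative usage sketch (names are hypothetical): building a 2-D nest
// with one loop-carried value.
//
//   scf::LoopNest nest = scf::buildLoopNest(
//       builder, loc, ValueRange{lb0, lb1}, ValueRange{ub0, ub1},
//       ValueRange{step0, step1}, ValueRange{init},
//       [](OpBuilder &b, Location loc, ValueRange ivs,
//          ValueRange iterArgs) -> scf::ValueVector {
//         Value next = b.create<arith::AddFOp>(loc, iterArgs[0], iterArgs[0]);
//         return {next};
//       });
//
// nest.loops.front() is the outermost scf.for and nest.results holds the
// results yielded by the outermost loop.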
766 | |
767 | LoopNest mlir::scf::buildLoopNest( |
768 | OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, |
769 | ValueRange steps, |
770 | function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) { |
771 | // Delegate to the main function by wrapping the body builder. |
  return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,
                       [&bodyBuilder](OpBuilder &nestedBuilder,
                                      Location nestedLoc, ValueRange ivs,
                                      ValueRange) -> ValueVector {
776 | if (bodyBuilder) |
777 | bodyBuilder(nestedBuilder, nestedLoc, ivs); |
778 | return {}; |
779 | }); |
780 | } |
781 | |
782 | SmallVector<Value> |
783 | mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, |
784 | OpOperand &operand, Value replacement, |
785 | const ValueTypeCastFnTy &castFn) { |
786 | assert(operand.getOwner() == forOp); |
787 | Type oldType = operand.get().getType(), newType = replacement.getType(); |
788 | |
789 | // 1. Create new iter operands, exactly 1 is replaced. |
790 | assert(operand.getOperandNumber() >= forOp.getNumControlOperands() && |
791 | "expected an iter OpOperand"); |
792 | assert(operand.get().getType() != replacement.getType() && |
793 | "Expected a different type"); |
794 | SmallVector<Value> newIterOperands; |
795 | for (OpOperand &opOperand : forOp.getInitArgsMutable()) { |
796 | if (opOperand.getOperandNumber() == operand.getOperandNumber()) { |
797 | newIterOperands.push_back(replacement); |
798 | continue; |
799 | } |
800 | newIterOperands.push_back(opOperand.get()); |
801 | } |
802 | |
803 | // 2. Create the new forOp shell. |
804 | scf::ForOp newForOp = rewriter.create<scf::ForOp>( |
805 | forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), |
806 | forOp.getStep(), newIterOperands); |
807 | newForOp->setAttrs(forOp->getAttrs()); |
808 | Block &newBlock = newForOp.getRegion().front(); |
809 | SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(), |
810 | newBlock.getArguments().end()); |
811 | |
812 | // 3. Inject an incoming cast op at the beginning of the block for the bbArg |
813 | // corresponding to the `replacement` value. |
814 | OpBuilder::InsertionGuard g(rewriter); |
815 | rewriter.setInsertionPointToStart(&newBlock); |
816 | BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg( |
817 | &newForOp->getOpOperand(operand.getOperandNumber())); |
818 | Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg); |
819 | newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn; |
820 | |
821 | // 4. Steal the old block ops, mapping to the newBlockTransferArgs. |
822 | Block &oldBlock = forOp.getRegion().front(); |
  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
824 | |
825 | // 5. Inject an outgoing cast op at the end of the block and yield it instead. |
826 | auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator()); |
827 | rewriter.setInsertionPoint(clonedYieldOp); |
828 | unsigned yieldIdx = |
829 | newRegionIterArg.getArgNumber() - forOp.getNumInductionVars(); |
830 | Value castOut = castFn(rewriter, newForOp.getLoc(), newType, |
831 | clonedYieldOp.getOperand(yieldIdx)); |
832 | SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands(); |
833 | newYieldOperands[yieldIdx] = castOut; |
834 | rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands); |
  rewriter.eraseOp(clonedYieldOp);
836 | |
837 | // 6. Inject an outgoing cast op after the forOp. |
838 | rewriter.setInsertionPointAfter(newForOp); |
839 | SmallVector<Value> newResults = newForOp.getResults(); |
840 | newResults[yieldIdx] = |
841 | castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]); |
842 | |
843 | return newResults; |
844 | } |
845 | |
846 | namespace { |
847 | // Fold away ForOp iter arguments when: |
848 | // 1) The op yields the iter arguments. |
849 | // 2) The argument's corresponding outer region iterators (inputs) are yielded. |
850 | // 3) The iter arguments have no use and the corresponding (operation) results |
851 | // have no use. |
852 | // |
853 | // These arguments must be defined outside of the ForOp region and can just be |
854 | // forwarded after simplifying the op inits, yields and returns. |
855 | // |
856 | // The implementation uses `inlineBlockBefore` to steal the content of the |
857 | // original ForOp and avoid cloning. |
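//
// For example (illustrative), %unused below is merely forwarded, so the
// corresponding iter_arg/result pair can be dropped:
//
//   %r:2 = scf.for %i = %lb to %ub step %s
//       iter_args(%a = %x, %unused = %y) -> (f32, f32) {
//     %v = "test.use"(%a) : (f32) -> f32
//     scf.yield %v, %unused : f32, f32
//   }
//   "test.consume"(%r#0) : (f32) -> ()
//
// canonicalizes to
//
//   %r = scf.for %i = %lb to %ub step %s iter_args(%a = %x) -> (f32) {
//     %v = "test.use"(%a) : (f32) -> f32
//     scf.yield %v : f32
//   }
//   "test.consume"(%r) : (f32) -> ()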
858 | struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> { |
859 | using OpRewritePattern<scf::ForOp>::OpRewritePattern; |
860 | |
861 | LogicalResult matchAndRewrite(scf::ForOp forOp, |
862 | PatternRewriter &rewriter) const final { |
863 | bool canonicalize = false; |
864 | |
    // An internal flat vector of block transfer arguments
    // `newBlockTransferArgs` keeps the 1-1 mapping of original to transformed
    // block arguments. It plays the role of an IRMapping for the particular
    // use case of calling into `inlineBlockBefore`.
870 | int64_t numResults = forOp.getNumResults(); |
871 | SmallVector<bool, 4> keepMask; |
    keepMask.reserve(numResults);
873 | SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues, |
874 | newResultValues; |
    newBlockTransferArgs.reserve(1 + numResults);
    newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
    newIterArgs.reserve(forOp.getInitArgs().size());
    newYieldValues.reserve(numResults);
    newResultValues.reserve(numResults);
880 | DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg; |
881 | for (auto [init, arg, result, yielded] : |
882 | llvm::zip(forOp.getInitArgs(), // iter from outside |
883 | forOp.getRegionIterArgs(), // iter inside region |
884 | forOp.getResults(), // op results |
885 | forOp.getYieldedValues() // iter yield |
886 | )) { |
      // Forwarded is `true` when:
      // 1) The region `iter` argument is yielded.
      // 2) The corresponding init (input) value is yielded.
      // 3) The region `iter` argument has no use, and the corresponding op
      //    result has no use.
892 | bool forwarded = (arg == yielded) || (init == yielded) || |
893 | (arg.use_empty() && result.use_empty()); |
894 | if (forwarded) { |
895 | canonicalize = true; |
896 | keepMask.push_back(false); |
897 | newBlockTransferArgs.push_back(init); |
898 | newResultValues.push_back(init); |
899 | continue; |
900 | } |
901 | |
902 | // Check if a previous kept argument always has the same values for init |
903 | // and yielded values. |
904 | if (auto it = initYieldToArg.find({init, yielded}); |
905 | it != initYieldToArg.end()) { |
906 | canonicalize = true; |
907 | keepMask.push_back(false); |
908 | auto [sameArg, sameResult] = it->second; |
909 | rewriter.replaceAllUsesWith(arg, sameArg); |
910 | rewriter.replaceAllUsesWith(result, sameResult); |
911 | // The replacement value doesn't matter because there are no uses. |
912 | newBlockTransferArgs.push_back(init); |
913 | newResultValues.push_back(init); |
914 | continue; |
915 | } |
916 | |
917 | // This value is kept. |
918 | initYieldToArg.insert({{init, yielded}, {arg, result}}); |
919 | keepMask.push_back(true); |
920 | newIterArgs.push_back(init); |
921 | newYieldValues.push_back(yielded); |
922 | newBlockTransferArgs.push_back(Value()); // placeholder with null value |
923 | newResultValues.push_back(Value()); // placeholder with null value |
924 | } |
925 | |
926 | if (!canonicalize) |
927 | return failure(); |
928 | |
929 | scf::ForOp newForOp = rewriter.create<scf::ForOp>( |
930 | forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), |
931 | forOp.getStep(), newIterArgs); |
932 | newForOp->setAttrs(forOp->getAttrs()); |
933 | Block &newBlock = newForOp.getRegion().front(); |
934 | |
935 | // Replace the null placeholders with newly constructed values. |
    newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
937 | for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size(); |
938 | idx != e; ++idx) { |
939 | Value &blockTransferArg = newBlockTransferArgs[1 + idx]; |
940 | Value &newResultVal = newResultValues[idx]; |
941 | assert((blockTransferArg && newResultVal) || |
942 | (!blockTransferArg && !newResultVal)); |
943 | if (!blockTransferArg) { |
944 | blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx]; |
945 | newResultVal = newForOp.getResult(collapsedIdx++); |
946 | } |
947 | } |
948 | |
949 | Block &oldBlock = forOp.getRegion().front(); |
950 | assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() && |
951 | "unexpected argument size mismatch"); |
952 | |
953 | // No results case: the scf::ForOp builder already created a zero |
954 | // result terminator. Merge before this terminator and just get rid of the |
955 | // original terminator that has been merged in. |
956 | if (newIterArgs.empty()) { |
957 | auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator()); |
958 | rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs); |
      rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
960 | rewriter.replaceOp(forOp, newResultValues); |
961 | return success(); |
962 | } |
963 | |
964 | // No terminator case: merge and rewrite the merged terminator. |
965 | auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) { |
966 | OpBuilder::InsertionGuard g(rewriter); |
967 | rewriter.setInsertionPoint(mergedTerminator); |
968 | SmallVector<Value, 4> filteredOperands; |
      filteredOperands.reserve(newResultValues.size());
      for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
        if (keepMask[idx])
          filteredOperands.push_back(mergedTerminator.getOperand(idx));
973 | rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(), |
974 | filteredOperands); |
975 | }; |
976 | |
    rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
    auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
    cloneFilteredTerminator(mergedYieldOp);
    rewriter.eraseOp(mergedYieldOp);
981 | rewriter.replaceOp(forOp, newResultValues); |
982 | return success(); |
983 | } |
984 | }; |
985 | |
/// Utility function that tries to compute a constant diff between u and l.
/// Returns std::nullopt when the difference cannot be determined statically.
989 | static std::optional<int64_t> computeConstDiff(Value l, Value u) { |
990 | IntegerAttr clb, cub; |
991 | if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) { |
992 | llvm::APInt lbValue = clb.getValue(); |
993 | llvm::APInt ubValue = cub.getValue(); |
994 | return (ubValue - lbValue).getSExtValue(); |
995 | } |
996 | |
997 | // Else a simple pattern match for x + c or c + x |
998 | llvm::APInt diff; |
999 | if (matchPattern( |
1000 | u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) || |
1001 | matchPattern( |
1002 | u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l)))) |
1003 | return diff.getSExtValue(); |
1004 | return std::nullopt; |
1005 | } |
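
// For example (illustrative): with l = arith.constant 2 and u = arith.constant
// 10 this returns 8; with u defined as `arith.addi %l, %c4` it returns 4; for
// unrelated dynamic values it returns std::nullopt.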
1006 | |
1007 | /// Rewriting pattern that erases loops that are known not to iterate, replaces |
1008 | /// single-iteration loops with their bodies, and removes empty loops that |
1009 | /// iterate at least once and only return values defined outside of the loop. |
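///
/// For example (illustrative), a loop whose bounds are equal has a zero trip
/// count and is replaced by its init values:
///
///   %r = scf.for %i = %c4 to %c4 step %c1 iter_args(%a = %x) -> (f32) {
///     ...
///   }
///
/// folds away, with all uses of %r replaced by %x.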
1010 | struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> { |
1011 | using OpRewritePattern<ForOp>::OpRewritePattern; |
1012 | |
1013 | LogicalResult matchAndRewrite(ForOp op, |
1014 | PatternRewriter &rewriter) const override { |
1015 | // If the upper bound is the same as the lower bound, the loop does not |
1016 | // iterate, just remove it. |
1017 | if (op.getLowerBound() == op.getUpperBound()) { |
1018 | rewriter.replaceOp(op, op.getInitArgs()); |
1019 | return success(); |
1020 | } |
1021 | |
1022 | std::optional<int64_t> diff = |
1023 | computeConstDiff(op.getLowerBound(), op.getUpperBound()); |
1024 | if (!diff) |
1025 | return failure(); |
1026 | |
1027 | // If the loop is known to have 0 iterations, remove it. |
1028 | if (*diff <= 0) { |
1029 | rewriter.replaceOp(op, op.getInitArgs()); |
1030 | return success(); |
1031 | } |
1032 | |
1033 | std::optional<llvm::APInt> maybeStepValue = op.getConstantStep(); |
1034 | if (!maybeStepValue) |
1035 | return failure(); |
1036 | |
1037 | // If the loop is known to have 1 iteration, inline its body and remove the |
1038 | // loop. |
1039 | llvm::APInt stepValue = *maybeStepValue; |
    if (stepValue.sge(*diff)) {
      SmallVector<Value, 4> blockArgs;
      blockArgs.reserve(op.getInitArgs().size() + 1);
      blockArgs.push_back(op.getLowerBound());
1044 | llvm::append_range(blockArgs, op.getInitArgs()); |
1045 | replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs); |
1046 | return success(); |
1047 | } |
1048 | |
    // Now we are left with loops that have more than 1 iteration.
    Block &block = op.getRegion().front();
    if (!llvm::hasSingleElement(block))
1052 | return failure(); |
1053 | // If the loop is empty, iterates at least once, and only returns values |
1054 | // defined outside of the loop, remove it and replace it with yield values. |
1055 | if (llvm::any_of(op.getYieldedValues(), |
1056 | [&](Value v) { return !op.isDefinedOutsideOfLoop(v); })) |
1057 | return failure(); |
1058 | rewriter.replaceOp(op, op.getYieldedValues()); |
1059 | return success(); |
1060 | } |
1061 | }; |
1062 | |
/// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
/// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
1065 | /// |
1066 | /// ``` |
1067 | /// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32> |
1068 | /// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0) |
1069 | /// -> (tensor<?x?xf32>) { |
1070 | /// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32> |
1071 | /// scf.yield %2 : tensor<?x?xf32> |
1072 | /// } |
1073 | /// use_of(%1) |
1074 | /// ``` |
1075 | /// |
1076 | /// folds into: |
1077 | /// |
1078 | /// ``` |
1079 | /// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0) |
1080 | /// -> (tensor<32x1024xf32>) { |
1081 | /// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32> |
1082 | /// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32> |
1083 | /// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32> |
1084 | /// scf.yield %4 : tensor<32x1024xf32> |
1085 | /// } |
1086 | /// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32> |
1087 | /// use_of(%1) |
1088 | /// ``` |
1089 | struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> { |
1090 | using OpRewritePattern<ForOp>::OpRewritePattern; |
1091 | |
1092 | LogicalResult matchAndRewrite(ForOp op, |
1093 | PatternRewriter &rewriter) const override { |
1094 | for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) { |
1095 | OpOperand &iterOpOperand = std::get<0>(it); |
1096 | auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>(); |
1097 | if (!incomingCast || |
1098 | incomingCast.getSource().getType() == incomingCast.getType()) |
1099 | continue; |
      // Skip this operand if the dest type of the cast does not preserve
      // static information present in the source type.
1102 | if (!tensor::preservesStaticInformation( |
1103 | incomingCast.getDest().getType(), |
1104 | incomingCast.getSource().getType())) |
1105 | continue; |
1106 | if (!std::get<1>(it).hasOneUse()) |
1107 | continue; |
1108 | |
1109 | // Create a new ForOp with that iter operand replaced. |
1110 | rewriter.replaceOp( |
1111 | op, replaceAndCastForOpIterArg( |
1112 | rewriter, op, iterOpOperand, incomingCast.getSource(), |
1113 | [](OpBuilder &b, Location loc, Type type, Value source) { |
1114 | return b.create<tensor::CastOp>(loc, type, source); |
1115 | })); |
1116 | return success(); |
1117 | } |
1118 | return failure(); |
1119 | } |
1120 | }; |
1121 | |
1122 | } // namespace |
1123 | |
1124 | void ForOp::getCanonicalizationPatterns(RewritePatternSet &results, |
1125 | MLIRContext *context) { |
1126 | results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>( |
1127 | context); |
1128 | } |
1129 | |
1130 | std::optional<APInt> ForOp::getConstantStep() { |
1131 | IntegerAttr step; |
1132 | if (matchPattern(getStep(), m_Constant(&step))) |
1133 | return step.getValue(); |
1134 | return {}; |
1135 | } |
1136 | |
1137 | std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() { |
1138 | return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable(); |
1139 | } |
1140 | |
1141 | Speculation::Speculatability ForOp::getSpeculatability() { |
1142 | // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start |
1143 | // and End. |
1144 | if (auto constantStep = getConstantStep()) |
1145 | if (*constantStep == 1) |
1146 | return Speculation::RecursivelySpeculatable; |
1147 | |
1148 | // For Step != 1, the loop may not terminate. We can add more smarts here if |
1149 | // needed. |
1150 | return Speculation::NotSpeculatable; |
1151 | } |
1152 | |
1153 | //===----------------------------------------------------------------------===// |
1154 | // ForallOp |
1155 | //===----------------------------------------------------------------------===// |
1156 | |
1157 | LogicalResult ForallOp::verify() { |
1158 | unsigned numLoops = getRank(); |
1159 | // Check number of outputs. |
1160 | if (getNumResults() != getOutputs().size()) |
1161 | return emitOpError("produces ") |
1162 | << getNumResults() << " results, but has only " |
1163 | << getOutputs().size() << " outputs"; |
1164 | |
1165 | // Check that the body defines block arguments for thread indices and outputs. |
1166 | auto *body = getBody(); |
1167 | if (body->getNumArguments() != numLoops + getOutputs().size()) |
1168 | return emitOpError("region expects ") << numLoops << " arguments"; |
1169 | for (int64_t i = 0; i < numLoops; ++i) |
1170 | if (!body->getArgument(i).getType().isIndex()) |
1171 | return emitOpError("expects ") |
1172 | << i << "-th block argument to be an index"; |
1173 | for (unsigned i = 0; i < getOutputs().size(); ++i) |
1174 | if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType()) |
1175 | return emitOpError("type mismatch between ") |
1176 | << i << "-th output and corresponding block argument"; |
1177 | if (getMapping().has_value() && !getMapping()->empty()) { |
1178 | if (static_cast<int64_t>(getMapping()->size()) != numLoops) |
1179 | return emitOpError() << "mapping attribute size must match op rank"; |
1180 | for (auto map : getMapping()->getValue()) { |
1181 | if (!isa<DeviceMappingAttrInterface>(map)) |
1182 | return emitOpError() |
1183 | << getMappingAttrName() << " is not device mapping attribute"; |
1184 | } |
1185 | } |
1186 | |
1187 | // Verify mixed static/dynamic control variables. |
1188 | Operation *op = getOperation(); |
1189 | if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops, |
1190 | getStaticLowerBound(), |
1191 | getDynamicLowerBound()))) |
1192 | return failure(); |
1193 | if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops, |
1194 | getStaticUpperBound(), |
1195 | getDynamicUpperBound()))) |
1196 | return failure(); |
1197 | if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops, |
1198 | getStaticStep(), getDynamicStep()))) |
1199 | return failure(); |
1200 | |
1201 | return success(); |
1202 | } |
1203 | |
1204 | void ForallOp::print(OpAsmPrinter &p) { |
1205 | Operation *op = getOperation(); |
  p << " (" << getInductionVars();
1207 | if (isNormalized()) { |
1208 | p << ") in "; |
1209 | printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(), |
1210 | /*valueTypes=*/{}, /*scalables=*/{}, |
1211 | OpAsmParser::Delimiter::Paren); |
1212 | } else { |
1213 | p << ") = "; |
1214 | printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(), |
1215 | /*valueTypes=*/{}, /*scalables=*/{}, |
1216 | OpAsmParser::Delimiter::Paren); |
1217 | p << " to "; |
1218 | printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(), |
1219 | /*valueTypes=*/{}, /*scalables=*/{}, |
1220 | OpAsmParser::Delimiter::Paren); |
1221 | p << " step "; |
1222 | printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(), |
1223 | /*valueTypes=*/{}, /*scalables=*/{}, |
1224 | OpAsmParser::Delimiter::Paren); |
1225 | } |
1226 | printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs"); |
1227 | p << " "; |
1228 | if (!getRegionOutArgs().empty()) |
    p << "-> (" << getResultTypes() << ") ";
1230 | p.printRegion(getRegion(), |
1231 | /*printEntryBlockArgs=*/false, |
1232 | /*printBlockTerminators=*/getNumResults() > 0); |
1233 | p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(), |
1234 | getStaticLowerBoundAttrName(), |
1235 | getStaticUpperBoundAttrName(), |
1236 | getStaticStepAttrName()}); |
1237 | } |
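
// For reference, the two printed forms look like the following illustrative
// examples: the normalized case (zero lower bounds, unit steps)
//
//   scf.forall (%i, %j) in (8, 16) shared_outs(%o = %init)
//       -> (tensor<8x16xf32>) {...}
//
// and the general case
//
//   scf.forall (%i) = (0) to (16) step (2) shared_outs(%o = %init)
//       -> (tensor<16xf32>) {...}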
1238 | |
1239 | ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) { |
1240 | OpBuilder b(parser.getContext()); |
1241 | auto indexType = b.getIndexType(); |
1242 | |
1243 | // Parse an opening `(` followed by thread index variables followed by `)` |
1244 | // TODO: when we can refer to such "induction variable"-like handles from the |
1245 | // declarative assembly format, we can implement the parser as a custom hook. |
1246 | SmallVector<OpAsmParser::Argument, 4> ivs; |
1247 | if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren)) |
1248 | return failure(); |
1249 | |
1250 | DenseI64ArrayAttr staticLbs, staticUbs, staticSteps; |
1251 | SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs, |
1252 | dynamicSteps; |
1253 | if (succeeded(parser.parseOptionalKeyword("in"))) { |
1254 | // Parse upper bounds. |
1255 | if (parseDynamicIndexList(parser, dynamicUbs, staticUbs, |
1256 | /*valueTypes=*/nullptr, |
1257 | OpAsmParser::Delimiter::Paren) || |
1258 | parser.resolveOperands(dynamicUbs, indexType, result.operands)) |
1259 | return failure(); |
1260 | |
1261 | unsigned numLoops = ivs.size(); |
1262 | staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0)); |
1263 | staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1)); |
1264 | } else { |
1265 | // Parse lower bounds. |
1266 | if (parser.parseEqual() || |
1267 | parseDynamicIndexList(parser, dynamicLbs, staticLbs, |
1268 | /*valueTypes=*/nullptr, |
                              OpAsmParser::Delimiter::Paren) ||
        parser.resolveOperands(dynamicLbs, indexType, result.operands))
1272 | return failure(); |
1273 | |
1274 | // Parse upper bounds. |
1275 | if (parser.parseKeyword("to") || |
1276 | parseDynamicIndexList(parser, dynamicUbs, staticUbs, |
1277 | /*valueTypes=*/nullptr, |
1278 | OpAsmParser::Delimiter::Paren) || |
1279 | parser.resolveOperands(dynamicUbs, indexType, result.operands)) |
1280 | return failure(); |
1281 | |
1282 | // Parse step values. |
1283 | if (parser.parseKeyword("step") || |
1284 | parseDynamicIndexList(parser, dynamicSteps, staticSteps, |
1285 | /*valueTypes=*/nullptr, |
1286 | OpAsmParser::Delimiter::Paren) || |
1287 | parser.resolveOperands(dynamicSteps, indexType, result.operands)) |
1288 | return failure(); |
1289 | } |
1290 | |
1291 | // Parse out operands and results. |
1292 | SmallVector<OpAsmParser::Argument, 4> regionOutArgs; |
1293 | SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands; |
1294 | SMLoc outOperandsLoc = parser.getCurrentLocation(); |
1295 | if (succeeded(parser.parseOptionalKeyword("shared_outs"))) { |
    if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
        parser.parseOptionalArrowTypeList(result.types) ||
        parser.resolveOperands(outOperands, result.types, outOperandsLoc,
                               result.operands))
      return failure();
  }
  if (outOperands.size() != result.types.size())
    return parser.emitError(outOperandsLoc,
                            "mismatch between out operands and types");
1305 | |
1306 | // Parse region. |
1307 | SmallVector<OpAsmParser::Argument, 4> regionArgs; |
1308 | std::unique_ptr<Region> region = std::make_unique<Region>(); |
1309 | for (auto &iv : ivs) { |
1310 | iv.type = b.getIndexType(); |
1311 | regionArgs.push_back(iv); |
1312 | } |
1313 | for (const auto &it : llvm::enumerate(regionOutArgs)) { |
1314 | auto &out = it.value(); |
1315 | out.type = result.types[it.index()]; |
1316 | regionArgs.push_back(out); |
1317 | } |
1318 | if (parser.parseRegion(*region, regionArgs)) |
1319 | return failure(); |
1320 | |
1321 | // Ensure terminator and move region. |
1322 | ForallOp::ensureTerminator(*region, b, result.location); |
1323 | result.addRegion(std::move(region)); |
1324 | |
1325 | // Parse the optional attribute list. |
1326 | if (parser.parseOptionalAttrDict(result.attributes)) |
1327 | return failure(); |
1328 | |
1329 | result.addAttribute("staticLowerBound", staticLbs); |
1330 | result.addAttribute("staticUpperBound", staticUbs); |
1331 | result.addAttribute("staticStep", staticSteps); |
1332 | result.addAttribute("operandSegmentSizes", |
1333 | parser.getBuilder().getDenseI32ArrayAttr( |
1334 | {static_cast<int32_t>(dynamicLbs.size()), |
1335 | static_cast<int32_t>(dynamicUbs.size()), |
1336 | static_cast<int32_t>(dynamicSteps.size()), |
1337 | static_cast<int32_t>(outOperands.size())})); |
1338 | return success(); |
1339 | } |
1340 | |
1341 | // Builder that takes loop bounds. |
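// A minimal usage sketch (illustrative; assumes an OpBuilder `b`, a Location
// `loc`, OpFoldResult vectors `lbs`/`ubs`/`steps`, and an init tensor value):
//   auto forall = b.create<scf::ForallOp>(loc, lbs, ubs, steps,
//                                         /*outputs=*/ValueRange{initTensor},
//                                         /*mapping=*/std::nullopt,
//                                         /*bodyBuilderFn=*/nullptr);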
1342 | void ForallOp::build( |
1343 | mlir::OpBuilder &b, mlir::OperationState &result, |
1344 | ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs, |
1345 | ArrayRef<OpFoldResult> steps, ValueRange outputs, |
1346 | std::optional<ArrayAttr> mapping, |
1347 | function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) { |
1348 | SmallVector<int64_t> staticLbs, staticUbs, staticSteps; |
1349 | SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps; |
1350 | dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs); |
1351 | dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs); |
1352 | dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps); |
1353 | |
1354 | result.addOperands(dynamicLbs); |
1355 | result.addOperands(dynamicUbs); |
1356 | result.addOperands(dynamicSteps); |
1357 | result.addOperands(outputs); |
1358 | result.addTypes(TypeRange(outputs)); |
1359 | |
1360 | result.addAttribute(getStaticLowerBoundAttrName(result.name), |
1361 | b.getDenseI64ArrayAttr(staticLbs)); |
1362 | result.addAttribute(getStaticUpperBoundAttrName(result.name), |
1363 | b.getDenseI64ArrayAttr(staticUbs)); |
1364 | result.addAttribute(getStaticStepAttrName(result.name), |
1365 | b.getDenseI64ArrayAttr(staticSteps)); |
1366 | result.addAttribute( |
1367 | "operandSegmentSizes", |
1368 | b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()), |
1369 | static_cast<int32_t>(dynamicUbs.size()), |
1370 | static_cast<int32_t>(dynamicSteps.size()), |
1371 | static_cast<int32_t>(outputs.size())})); |
1372 | if (mapping.has_value()) { |
1373 | result.addAttribute(ForallOp::getMappingAttrName(result.name), |
1374 | mapping.value()); |
1375 | } |
1376 | |
1377 | Region *bodyRegion = result.addRegion(); |
1378 | OpBuilder::InsertionGuard g(b); |
1379 | b.createBlock(bodyRegion); |
1380 | Block &bodyBlock = bodyRegion->front(); |
1381 | |
1382 | // Add block arguments for indices and outputs. |
1383 | bodyBlock.addArguments( |
1384 | SmallVector<Type>(lbs.size(), b.getIndexType()), |
1385 | SmallVector<Location>(staticLbs.size(), result.location)); |
1386 | bodyBlock.addArguments( |
1387 | TypeRange(outputs), |
1388 | SmallVector<Location>(outputs.size(), result.location)); |
1389 | |
1390 | b.setInsertionPointToStart(&bodyBlock); |
1391 | if (!bodyBuilderFn) { |
1392 | ForallOp::ensureTerminator(*bodyRegion, b, result.location); |
1393 | return; |
1394 | } |
1395 | bodyBuilderFn(b, result.location, bodyBlock.getArguments()); |
1396 | } |
1397 | |
1398 | // Builder that takes only the upper loop bounds; lower bounds default to zero and steps to one. |
1399 | void ForallOp::build( |
1400 | mlir::OpBuilder &b, mlir::OperationState &result, |
1401 | ArrayRef<OpFoldResult> ubs, ValueRange outputs, |
1402 | std::optional<ArrayAttr> mapping, |
1403 | function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) { |
1404 | unsigned numLoops = ubs.size(); |
1405 | SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0)); |
1406 | SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1)); |
1407 | build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn); |
1408 | } |
1409 | |
1410 | // Checks if the lbs are zeros and steps are ones. |
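// For example, `scf.forall (%i, %j) in (%c, %d)` is normalized (implicit zero
// lower bounds and unit steps), whereas a loop with a nonzero lower bound or a
// non-unit step is not.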
1411 | bool ForallOp::isNormalized() { |
1412 | auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) { |
1413 | return llvm::all_of(results, [&](OpFoldResult ofr) { |
1414 | auto intValue = getConstantIntValue(ofr); |
1415 | return intValue.has_value() && intValue == val; |
1416 | }); |
1417 | }; |
1418 | return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1); |
1419 | } |
1420 | |
1421 | InParallelOp ForallOp::getTerminator() { |
1422 | return cast<InParallelOp>(getBody()->getTerminator()); |
1423 | } |
1424 | |
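// Returns the ops in the in_parallel terminator that combine into the given
// shared_outs block argument (currently only tensor.parallel_insert_slice ops
// whose destination is `bbArg`).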
1425 | SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) { |
1426 | SmallVector<Operation *> storeOps; |
1427 | InParallelOp inParallelOp = getTerminator(); |
1428 | for (Operation &yieldOp : inParallelOp.getYieldingOps()) { |
1429 | if (auto parallelInsertSliceOp = |
1430 | dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp); |
1431 | parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) { |
1432 | storeOps.push_back(parallelInsertSliceOp); |
1433 | } |
1434 | } |
1435 | return storeOps; |
1436 | } |
1437 | |
1438 | std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() { |
1439 | return SmallVector<Value>{getBody()->getArguments().take_front(getRank())}; |
1440 | } |
1441 | |
1442 | // Get lower bounds as OpFoldResult. |
1443 | std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() { |
1444 | Builder b(getOperation()->getContext()); |
1445 | return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b); |
1446 | } |
1447 | |
1448 | // Get upper bounds as OpFoldResult. |
1449 | std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() { |
1450 | Builder b(getOperation()->getContext()); |
1451 | return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b); |
1452 | } |
1453 | |
1454 | // Get steps as OpFoldResult. |
1455 | std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() { |
1456 | Builder b(getOperation()->getContext()); |
1457 | return getMixedValues(getStaticStep(), getDynamicStep(), b); |
1458 | } |
1459 | |
1460 | ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) { |
1461 | auto tidxArg = llvm::dyn_cast<BlockArgument>(val); |
1462 | if (!tidxArg) |
1463 | return ForallOp(); |
1464 | assert(tidxArg.getOwner() && "unlinked block argument"); |
1465 | auto *containingOp = tidxArg.getOwner()->getParentOp(); |
1466 | return dyn_cast<ForallOp>(containingOp); |
1467 | } |
1468 | |
1469 | namespace { |
1470 | /// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t). |
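/// Illustrative sketch (SSA names are hypothetical):
///   %res = scf.forall ... shared_outs(%o = %t) -> (tensor<?xf32>) {...}
///   %d = tensor.dim %res, %c0
/// is rewritten so the dim reads the shared output operand instead:
///   %d = tensor.dim %t, %c0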
1471 | struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> { |
1472 | using OpRewritePattern<tensor::DimOp>::OpRewritePattern; |
1473 | |
1474 | LogicalResult matchAndRewrite(tensor::DimOp dimOp, |
1475 | PatternRewriter &rewriter) const final { |
1476 | auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>(); |
1477 | if (!forallOp) |
1478 | return failure(); |
1479 | Value sharedOut = |
1480 | forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource())) |
1481 | ->get(); |
1482 | rewriter.modifyOpInPlace( |
1483 | dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); }); |
1484 | return success(); |
1485 | } |
1486 | }; |
1487 | |
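/// Fold constant dynamic lower bounds, upper bounds, and steps of an
/// scf.forall into their static attribute form. Illustrative sketch:
///   %c64 = arith.constant 64 : index
///   scf.forall (%i) in (%c64) { ... }
/// becomes
///   scf.forall (%i) in (64) { ... }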
1488 | class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> { |
1489 | public: |
1490 | using OpRewritePattern<ForallOp>::OpRewritePattern; |
1491 | |
1492 | LogicalResult matchAndRewrite(ForallOp op, |
1493 | PatternRewriter &rewriter) const override { |
1494 | SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound()); |
1495 | SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound()); |
1496 | SmallVector<OpFoldResult> mixedStep(op.getMixedStep()); |
1497 | if (failed(foldDynamicIndexList(mixedLowerBound)) && |
1498 | failed(foldDynamicIndexList(mixedUpperBound)) && |
1499 | failed(foldDynamicIndexList(mixedStep))) |
1500 | return failure(); |
1501 | |
1502 | rewriter.modifyOpInPlace(op, [&]() { |
1503 | SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep; |
1504 | SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep; |
1505 | dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound, |
1506 | staticLowerBound); |
1507 | op.getDynamicLowerBoundMutable().assign(dynamicLowerBound); |
1508 | op.setStaticLowerBound(staticLowerBound); |
1509 | |
1510 | dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound, |
1511 | staticUpperBound); |
1512 | op.getDynamicUpperBoundMutable().assign(dynamicUpperBound); |
1513 | op.setStaticUpperBound(staticUpperBound); |
1514 | |
1515 | dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep); |
1516 | op.getDynamicStepMutable().assign(dynamicStep); |
1517 | op.setStaticStep(staticStep); |
1518 | |
1519 | op->setAttr(ForallOp::getOperandSegmentSizeAttr(), |
1520 | rewriter.getDenseI32ArrayAttr( |
1521 | {static_cast<int32_t>(dynamicLowerBound.size()), |
1522 | static_cast<int32_t>(dynamicUpperBound.size()), |
1523 | static_cast<int32_t>(dynamicStep.size()), |
1524 | static_cast<int32_t>(op.getNumResults())})); |
1525 | }); |
1526 | return success(); |
1527 | } |
1528 | }; |
1529 | |
1530 | /// The following canonicalization pattern folds the iter arguments of |
1531 | /// scf.forall op if :- |
1532 | /// 1. The corresponding result has zero uses. |
1533 | /// 2. The iter argument is NOT being modified within the loop body |
1534 | /// (i.e. it has no combining store op in the terminator). |
1535 | /// |
1536 | /// Example of first case :- |
1537 | /// INPUT: |
1538 | /// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c) |
1539 | /// { |
1540 | /// ... |
1541 | /// <SOME USE OF %arg0> |
1542 | /// <SOME USE OF %arg1> |
1543 | /// <SOME USE OF %arg2> |
1544 | /// ... |
1545 | /// scf.forall.in_parallel { |
1546 | /// <STORE OP WITH DESTINATION %arg1> |
1547 | /// <STORE OP WITH DESTINATION %arg0> |
1548 | /// <STORE OP WITH DESTINATION %arg2> |
1549 | /// } |
1550 | /// } |
1551 | /// return %res#1 |
1552 | /// |
1553 | /// OUTPUT: |
1554 | /// %res:3 = scf.forall ... shared_outs(%new_arg0 = %b) |
1555 | /// { |
1556 | /// ... |
1557 | /// <SOME USE OF %a> |
1558 | /// <SOME USE OF %new_arg0> |
1559 | /// <SOME USE OF %c> |
1560 | /// ... |
1561 | /// scf.forall.in_parallel { |
1562 | /// <STORE OP WITH DESTINATION %new_arg0> |
1563 | /// } |
1564 | /// } |
1565 | /// return %res |
1566 | /// |
1567 | /// NOTE: 1. All uses of the folded shared_outs (iter argument) within the |
1568 | /// scf.forall are replaced by their corresponding operands. |
1569 | /// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body |
1570 | /// of the scf.forall besides those within the scf.forall.in_parallel terminator, |
1571 | /// this canonicalization remains valid. For more details, please refer |
1572 | /// to : |
1573 | /// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124 |
1574 | /// 3. TODO(avarma): Generalize it for other store ops. Currently it |
1575 | /// handles tensor.parallel_insert_slice ops only. |
1576 | /// |
1577 | /// Example of second case :- |
1578 | /// INPUT: |
1579 | /// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b) |
1580 | /// { |
1581 | /// ... |
1582 | /// <SOME USE OF %arg0> |
1583 | /// <SOME USE OF %arg1> |
1584 | /// ... |
1585 | /// scf.forall.in_parallel { |
1586 | /// <STORE OP WITH DESTINATION %arg1> |
1587 | /// } |
1588 | /// } |
1589 | /// return %res#0, %res#1 |
1590 | /// |
1591 | /// OUTPUT: |
1592 | /// %res = scf.forall ... shared_outs(%new_arg0 = %b) |
1593 | /// { |
1594 | /// ... |
1595 | /// <SOME USE OF %a> |
1596 | /// <SOME USE OF %new_arg0> |
1597 | /// ... |
1598 | /// scf.forall.in_parallel { |
1599 | /// <STORE OP WITH DESTINATION %new_arg0> |
1600 | /// } |
1601 | /// } |
1602 | /// return %a, %res |
1603 | struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> { |
1604 | using OpRewritePattern<ForallOp>::OpRewritePattern; |
1605 | |
1606 | LogicalResult matchAndRewrite(ForallOp forallOp, |
1607 | PatternRewriter &rewriter) const final { |
1608 | // Step 1: For a given i-th result of scf.forall, check the following :- |
1609 | // a. If it has any use. |
1610 | // b. If the corresponding iter argument is being modified within |
1611 | // the loop, i.e. has at least one store op with the iter arg as |
1612 | // its destination operand. For this we use |
1613 | // ForallOp::getCombiningOps(iter_arg). |
1614 | // |
1615 | // Based on the check we maintain the following :- |
1616 | // a. `resultToDelete` - i-th result of scf.forall that'll be |
1617 | // deleted. |
1618 | // b. `resultToReplace` - i-th result of the old scf.forall |
1619 | // whose uses will be replaced by the new scf.forall. |
1620 | // c. `newOuts` - the shared_outs' operand of the new scf.forall |
1621 | // corresponding to the i-th result with at least one use. |
1622 | SetVector<OpResult> resultToDelete; |
1623 | SmallVector<Value> resultToReplace; |
1624 | SmallVector<Value> newOuts; |
1625 | for (OpResult result : forallOp.getResults()) { |
1626 | OpOperand *opOperand = forallOp.getTiedOpOperand(result); |
1627 | BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand); |
1628 | if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) { |
1629 | resultToDelete.insert(result); |
1630 | } else { |
1631 | resultToReplace.push_back(result); |
1632 | newOuts.push_back(opOperand->get()); |
1633 | } |
1634 | } |
1635 | |
1636 | // Return early if all results of scf.forall have at least one use and are |
1637 | // being modified within the loop. |
1638 | if (resultToDelete.empty()) |
1639 | return failure(); |
1640 | |
1641 | // Step 2: For the i-th result, do the following :- |
1642 | // a. Fetch the corresponding BlockArgument. |
1643 | // b. Look for store ops (currently tensor.parallel_insert_slice) |
1644 | // with the BlockArgument as its destination operand. |
1645 | // c. Remove the operations fetched in b. |
1646 | for (OpResult result : resultToDelete) { |
1647 | OpOperand *opOperand = forallOp.getTiedOpOperand(result); |
1648 | BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand); |
1649 | SmallVector<Operation *> combiningOps = |
1650 | forallOp.getCombiningOps(blockArg); |
1651 | for (Operation *combiningOp : combiningOps) |
1652 | rewriter.eraseOp(combiningOp); |
1653 | } |
1654 | |
1655 | // Step 3. Create a new scf.forall op with the new shared_outs' operands |
1656 | // fetched earlier |
1657 | auto newForallOp = rewriter.create<scf::ForallOp>( |
1658 | forallOp.getLoc(), forallOp.getMixedLowerBound(), |
1659 | forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts, |
1660 | forallOp.getMapping(), |
1661 | /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {}); |
1662 | |
1663 | // Step 4. Merge the block of the old scf.forall into the newly created |
1664 | // scf.forall using the new set of arguments. |
1665 | Block *loopBody = forallOp.getBody(); |
1666 | Block *newLoopBody = newForallOp.getBody(); |
1667 | ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments(); |
1668 | // Form initial new bbArg list with just the control operands of the new |
1669 | // scf.forall op. |
1670 | SmallVector<Value> newBlockArgs = |
1671 | llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()), |
1672 | [](BlockArgument b) -> Value { return b; }); |
1673 | Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs(); |
1674 | unsigned index = 0; |
1675 | // Take the new corresponding bbArg if the old bbArg was used as a |
1676 | // destination in the in_parallel op. For all other bbArgs, use the |
1677 | // corresponding init_arg from the old scf.forall op. |
1678 | for (OpResult result : forallOp.getResults()) { |
1679 | if (resultToDelete.count(result)) { |
1680 | newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get()); |
1681 | } else { |
1682 | newBlockArgs.push_back(newSharedOutsArgs[index++]); |
1683 | } |
1684 | } |
1685 | rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs); |
1686 | |
1687 | // Step 5. Replace the uses of the results of the old scf.forall with those |
1688 | // of the new scf.forall. |
1689 | for (auto &&[oldResult, newResult] : |
1690 | llvm::zip(resultToReplace, newForallOp->getResults())) |
1691 | rewriter.replaceAllUsesWith(oldResult, newResult); |
1692 | |
1693 | // Step 6. Replace the uses of those values that either have no use or are |
1694 | // not being modified within the loop with the corresponding |
1695 | // OpOperand. |
1696 | for (OpResult oldResult : resultToDelete) |
1697 | rewriter.replaceAllUsesWith(oldResult, |
1698 | forallOp.getTiedOpOperand(oldResult)->get()); |
1699 | return success(); |
1700 | } |
1701 | }; |
1702 | |
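/// Remove scf.forall dimensions that perform zero or one iteration: a
/// zero-trip loop is replaced by its shared_outs, single-trip dimensions are
/// dropped with their induction variable replaced by the lower bound, and a
/// loop whose dimensions are all single-trip is inlined. Illustrative sketch:
///   scf.forall (%i, %j) in (1, %n) { ... }
/// becomes (with uses of %i replaced by 0)
///   scf.forall (%j) in (%n) { ... }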
1703 | struct ForallOpSingleOrZeroIterationDimsFolder |
1704 | : public OpRewritePattern<ForallOp> { |
1705 | using OpRewritePattern<ForallOp>::OpRewritePattern; |
1706 | |
1707 | LogicalResult matchAndRewrite(ForallOp op, |
1708 | PatternRewriter &rewriter) const override { |
1709 | // Do not fold dimensions if they are mapped to processing units. |
1710 | if (op.getMapping().has_value() && !op.getMapping()->empty()) |
1711 | return failure(); |
1712 | Location loc = op.getLoc(); |
1713 | |
1714 | // Compute new loop bounds that omit all single-iteration loop dimensions. |
1715 | SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds, |
1716 | newMixedSteps; |
1717 | IRMapping mapping; |
1718 | for (auto [lb, ub, step, iv] : |
1719 | llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(), |
1720 | op.getMixedStep(), op.getInductionVars())) { |
1721 | auto numIterations = constantTripCount(lb, ub, step); |
1722 | if (numIterations.has_value()) { |
1723 | // Remove the loop if it performs zero iterations. |
1724 | if (*numIterations == 0) { |
1725 | rewriter.replaceOp(op, op.getOutputs()); |
1726 | return success(); |
1727 | } |
1728 | // Replace the loop induction variable by the lower bound if the loop |
1729 | // performs a single iteration. Otherwise, copy the loop bounds. |
1730 | if (*numIterations == 1) { |
1731 | mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb)); |
1732 | continue; |
1733 | } |
1734 | } |
1735 | newMixedLowerBounds.push_back(lb); |
1736 | newMixedUpperBounds.push_back(ub); |
1737 | newMixedSteps.push_back(step); |
1738 | } |
1739 | |
1740 | // All of the loop dimensions perform a single iteration. Inline loop body. |
1741 | if (newMixedLowerBounds.empty()) { |
1742 | promote(rewriter, op); |
1743 | return success(); |
1744 | } |
1745 | |
1746 | // Exit if none of the loop dimensions perform a single iteration. |
1747 | if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) { |
1748 | return rewriter.notifyMatchFailure( |
1749 | op, "no dimensions have 0 or 1 iterations"); |
1750 | } |
1751 | |
1752 | // Replace the loop by a lower-dimensional loop. |
1753 | ForallOp newOp; |
1754 | newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds, |
1755 | newMixedUpperBounds, newMixedSteps, |
1756 | op.getOutputs(), std::nullopt, nullptr); |
1757 | newOp.getBodyRegion().getBlocks().clear(); |
1758 | // The new loop needs to keep all attributes from the old one, except for |
1759 | // "operandSegmentSizes" and static loop bound attributes which capture |
1760 | // the outdated information of the old iteration domain. |
1761 | SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(), |
1762 | newOp.getStaticLowerBoundAttrName(), |
1763 | newOp.getStaticUpperBoundAttrName(), |
1764 | newOp.getStaticStepAttrName()}; |
1765 | for (const auto &namedAttr : op->getAttrs()) { |
1766 | if (llvm::is_contained(elidedAttrs, namedAttr.getName())) |
1767 | continue; |
1768 | rewriter.modifyOpInPlace(newOp, [&]() { |
1769 | newOp->setAttr(namedAttr.getName(), namedAttr.getValue()); |
1770 | }); |
1771 | } |
1772 | rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(), |
1773 | newOp.getRegion().begin(), mapping); |
1774 | rewriter.replaceOp(op, newOp.getResults()); |
1775 | return success(); |
1776 | } |
1777 | }; |
1778 | |
1779 | /// Replace induction vars of dimensions with a single trip count by their lower bound. |
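/// Illustrative sketch: in `scf.forall (%i, %j) in (1, %n) { ... }`, all uses
/// of %i inside the body are replaced by a constant 0 while the loop itself is
/// left unchanged.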
1780 | struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> { |
1781 | using OpRewritePattern<ForallOp>::OpRewritePattern; |
1782 | |
1783 | LogicalResult matchAndRewrite(ForallOp op, |
1784 | PatternRewriter &rewriter) const override { |
1785 | Location loc = op.getLoc(); |
1786 | bool changed = false; |
1787 | for (auto [lb, ub, step, iv] : |
1788 | llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(), |
1789 | op.getMixedStep(), op.getInductionVars())) { |
1790 | if (iv.hasNUses(0)) |
1791 | continue; |
1792 | auto numIterations = constantTripCount(lb, ub, step); |
1793 | if (!numIterations.has_value() || numIterations.value() != 1) { |
1794 | continue; |
1795 | } |
1796 | rewriter.replaceAllUsesWith( |
1797 | iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb)); |
1798 | changed = true; |
1799 | } |
1800 | return success(changed); |
1801 | } |
1802 | }; |
1803 | |
1804 | struct FoldTensorCastOfOutputIntoForallOp |
1805 | : public OpRewritePattern<scf::ForallOp> { |
1806 | using OpRewritePattern<scf::ForallOp>::OpRewritePattern; |
1807 | |
1808 | struct TypeCast { |
1809 | Type srcType; |
1810 | Type dstType; |
1811 | }; |
1812 | |
1813 | LogicalResult matchAndRewrite(scf::ForallOp forallOp, |
1814 | PatternRewriter &rewriter) const final { |
1815 | llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers; |
1816 | llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs(); |
1817 | for (auto en : llvm::enumerate(newOutputTensors)) { |
1818 | auto castOp = en.value().getDefiningOp<tensor::CastOp>(); |
1819 | if (!castOp) |
1820 | continue; |
1821 | |
1822 | // Only casts that preserve static information, i.e. will make the |
1823 | // loop result type "more" static than before, will be folded. |
1824 | if (!tensor::preservesStaticInformation(castOp.getDest().getType(), |
1825 | castOp.getSource().getType())) { |
1826 | continue; |
1827 | } |
1828 | |
1829 | tensorCastProducers[en.index()] = |
1830 | TypeCast{castOp.getSource().getType(), castOp.getType()}; |
1831 | newOutputTensors[en.index()] = castOp.getSource(); |
1832 | } |
1833 | |
1834 | if (tensorCastProducers.empty()) |
1835 | return failure(); |
1836 | |
1837 | // Create new loop. |
1838 | Location loc = forallOp.getLoc(); |
1839 | auto newForallOp = rewriter.create<ForallOp>( |
1840 | loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(), |
1841 | forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(), |
1842 | [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) { |
1843 | auto castBlockArgs = |
1844 | llvm::to_vector(bbArgs.take_back(forallOp->getNumResults())); |
1845 | for (auto [index, cast] : tensorCastProducers) { |
1846 | Value &oldTypeBBArg = castBlockArgs[index]; |
1847 | oldTypeBBArg = nestedBuilder.create<tensor::CastOp>( |
1848 | nestedLoc, cast.dstType, oldTypeBBArg); |
1849 | } |
1850 | |
1851 | // Move old body into new parallel loop. |
1852 | SmallVector<Value> ivsBlockArgs = |
1853 | llvm::to_vector(bbArgs.take_front(forallOp.getRank())); |
1854 | ivsBlockArgs.append(castBlockArgs); |
1855 | rewriter.mergeBlocks(forallOp.getBody(), |
1856 | bbArgs.front().getParentBlock(), ivsBlockArgs); |
1857 | }); |
1858 | |
1859 | // After `mergeBlocks`, the destinations in the terminator have been mapped to |
1860 | // the tensor.cast old-typed results of the output bbArgs. The destinations |
1861 | // have to be updated to point to the output bbArgs directly. |
1862 | auto terminator = newForallOp.getTerminator(); |
1863 | for (auto [yieldingOp, outputBlockArg] : llvm::zip( |
1864 | terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) { |
1865 | auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp); |
1866 | insertSliceOp.getDestMutable().assign(outputBlockArg); |
1867 | } |
1868 | |
1869 | // Cast results back to the original types. |
1870 | rewriter.setInsertionPointAfter(newForallOp); |
1871 | SmallVector<Value> castResults = newForallOp.getResults(); |
1872 | for (auto &item : tensorCastProducers) { |
1873 | Value &oldTypeResult = castResults[item.first]; |
1874 | oldTypeResult = rewriter.create<tensor::CastOp>(loc, item.second.dstType, |
1875 | oldTypeResult); |
1876 | } |
1877 | rewriter.replaceOp(forallOp, castResults); |
1878 | return success(); |
1879 | } |
1880 | }; |
1881 | |
1882 | } // namespace |
1883 | |
1884 | void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results, |
1885 | MLIRContext *context) { |
1886 | results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp, |
1887 | ForallOpControlOperandsFolder, ForallOpIterArgsFolder, |
1888 | ForallOpSingleOrZeroIterationDimsFolder, |
1889 | ForallOpReplaceConstantInductionVar>(context); |
1890 | } |
1891 | |
1892 | /// Given the branch point (the parent operation or the body region), return |
1893 | /// the successor regions. These are the regions that may be selected next |
1894 | /// during the flow of control. |
1897 | void ForallOp::getSuccessorRegions(RegionBranchPoint point, |
1898 | SmallVectorImpl<RegionSuccessor> ®ions) { |
1899 | // Both the operation itself and the region may be branching into the body or |
1900 | // back into the operation itself. It is possible for the loop not to enter |
1901 | // the body. |
1902 | regions.push_back(RegionSuccessor(&getRegion())); |
1903 | regions.push_back(RegionSuccessor()); |
1904 | } |
1905 | |
1906 | //===----------------------------------------------------------------------===// |
1907 | // InParallelOp |
1908 | //===----------------------------------------------------------------------===// |
1909 | |
1910 | // Build an InParallelOp with a single empty block. |
1911 | void InParallelOp::build(OpBuilder &b, OperationState &result) { |
1912 | OpBuilder::InsertionGuard g(b); |
1913 | Region *bodyRegion = result.addRegion(); |
1914 | b.createBlock(bodyRegion); |
1915 | } |
1916 | |
1917 | LogicalResult InParallelOp::verify() { |
1918 | scf::ForallOp forallOp = |
1919 | dyn_cast<scf::ForallOp>(getOperation()->getParentOp()); |
1920 | if (!forallOp) |
1921 | return this->emitOpError("expected forall op parent"); |
1922 | |
1923 | // TODO: InParallelOpInterface. |
1924 | for (Operation &op : getRegion().front().getOperations()) { |
1925 | if (!isa<tensor::ParallelInsertSliceOp>(op)) { |
1926 | return this->emitOpError("expected only ") |
1927 | << tensor::ParallelInsertSliceOp::getOperationName() << " ops"; |
1928 | } |
1929 | |
1930 | // Verify that inserts are into out block arguments. |
1931 | Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest(); |
1932 | ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs(); |
1933 | if (!llvm::is_contained(regionOutArgs, dest)) |
1934 | return op.emitOpError("may only insert into an output block argument"); |
1935 | } |
1936 | return success(); |
1937 | } |
1938 | |
1939 | void InParallelOp::print(OpAsmPrinter &p) { |
1940 | p << " "; |
1941 | p.printRegion(getRegion(), |
1942 | /*printEntryBlockArgs=*/false, |
1943 | /*printBlockTerminators=*/false); |
1944 | p.printOptionalAttrDict(getOperation()->getAttrs()); |
1945 | } |
1946 | |
1947 | ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) { |
1948 | auto &builder = parser.getBuilder(); |
1949 | |
1950 | SmallVector<OpAsmParser::Argument, 8> regionOperands; |
1951 | std::unique_ptr<Region> region = std::make_unique<Region>(); |
1952 | if (parser.parseRegion(*region, regionOperands)) |
1953 | return failure(); |
1954 | |
1955 | if (region->empty()) |
1956 | OpBuilder(builder.getContext()).createBlock(region.get()); |
1957 | result.addRegion(std::move(region)); |
1958 | |
1959 | // Parse the optional attribute list. |
1960 | if (parser.parseOptionalAttrDict(result.attributes)) |
1961 | return failure(); |
1962 | return success(); |
1963 | } |
1964 | |
1965 | OpResult InParallelOp::getParentResult(int64_t idx) { |
1966 | return getOperation()->getParentOp()->getResult(idx); |
1967 | } |
1968 | |
1969 | SmallVector<BlockArgument> InParallelOp::getDests() { |
1970 | return llvm::to_vector<4>( |
1971 | llvm::map_range(getYieldingOps(), [](Operation &op) { |
1972 | // Add new ops here as needed. |
1973 | auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op); |
1974 | return llvm::cast<BlockArgument>(insertSliceOp.getDest()); |
1975 | })); |
1976 | } |
1977 | |
1978 | llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() { |
1979 | return getRegion().front().getOperations(); |
1980 | } |
1981 | |
1982 | //===----------------------------------------------------------------------===// |
1983 | // IfOp |
1984 | //===----------------------------------------------------------------------===// |
1985 | |
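// Returns true if `a` and `b` sit in different branches of a common enclosing
// scf.if, i.e. at most one of them can execute on any given path. A sketch:
//   scf.if %cond {
//     ... // `a` here
//   } else {
//     ... // `b` here
//   }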
1986 | bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) { |
1987 | assert(a && "expected non-empty operation"); |
1988 | assert(b && "expected non-empty operation"); |
1989 | |
1990 | IfOp ifOp = a->getParentOfType<IfOp>(); |
1991 | while (ifOp) { |
1992 | // Check if b is inside ifOp. (We already know that a is.) |
1993 | if (ifOp->isProperAncestor(b)) |
1994 | // b is contained in ifOp. a and b are in mutually exclusive branches if |
1995 | // they are in different blocks of ifOp. |
1996 | return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) != |
1997 | static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b)); |
1998 | // Check next enclosing IfOp. |
1999 | ifOp = ifOp->getParentOfType<IfOp>(); |
2000 | } |
2001 | |
2002 | // Could not find a common IfOp among a's and b's ancestors. |
2003 | return false; |
2004 | } |
2005 | |
2006 | LogicalResult |
2007 | IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, |
2008 | IfOp::Adaptor adaptor, |
2009 | SmallVectorImpl<Type> &inferredReturnTypes) { |
2010 | if (adaptor.getRegions().empty()) |
2011 | return failure(); |
2012 | Region *r = &adaptor.getThenRegion(); |
2013 | if (r->empty()) |
2014 | return failure(); |
2015 | Block &b = r->front(); |
2016 | if (b.empty()) |
2017 | return failure(); |
2018 | auto yieldOp = llvm::dyn_cast<YieldOp>(b.back()); |
2019 | if (!yieldOp) |
2020 | return failure(); |
2021 | TypeRange types = yieldOp.getOperandTypes(); |
2022 | llvm::append_range(inferredReturnTypes, types); |
2023 | return success(); |
2024 | } |
2025 | |
2026 | void IfOp::build(OpBuilder &builder, OperationState &result, |
2027 | TypeRange resultTypes, Value cond) { |
2028 | return build(builder, result, resultTypes, cond, /*addThenBlock=*/false, |
2029 | /*addElseBlock=*/false); |
2030 | } |
2031 | |
2032 | void IfOp::build(OpBuilder &builder, OperationState &result, |
2033 | TypeRange resultTypes, Value cond, bool addThenBlock, |
2034 | bool addElseBlock) { |
2035 | assert((!addElseBlock || addThenBlock) && |
2036 | "must not create else block w/o then block"); |
2037 | result.addTypes(resultTypes); |
2038 | result.addOperands(cond); |
2039 | |
2040 | // Add regions and blocks. |
2041 | OpBuilder::InsertionGuard guard(builder); |
2042 | Region *thenRegion = result.addRegion(); |
2043 | if (addThenBlock) |
2044 | builder.createBlock(thenRegion); |
2045 | Region *elseRegion = result.addRegion(); |
2046 | if (addElseBlock) |
2047 | builder.createBlock(elseRegion); |
2048 | } |
2049 | |
2050 | void IfOp::build(OpBuilder &builder, OperationState &result, Value cond, |
2051 | bool withElseRegion) { |
2052 | build(builder, result, TypeRange{}, cond, withElseRegion); |
2053 | } |
2054 | |
2055 | void IfOp::build(OpBuilder &builder, OperationState &result, |
2056 | TypeRange resultTypes, Value cond, bool withElseRegion) { |
2057 | result.addTypes(resultTypes); |
2058 | result.addOperands(cond); |
2059 | |
2060 | // Build then region. |
2061 | OpBuilder::InsertionGuard guard(builder); |
2062 | Region *thenRegion = result.addRegion(); |
2063 | builder.createBlock(thenRegion); |
2064 | if (resultTypes.empty()) |
2065 | IfOp::ensureTerminator(*thenRegion, builder, result.location); |
2066 | |
2067 | // Build else region. |
2068 | Region *elseRegion = result.addRegion(); |
2069 | if (withElseRegion) { |
2070 | builder.createBlock(elseRegion); |
2071 | if (resultTypes.empty()) |
2072 | IfOp::ensureTerminator(*elseRegion, builder, result.location); |
2073 | } |
2074 | } |
2075 | |
2076 | void IfOp::build(OpBuilder &builder, OperationState &result, Value cond, |
2077 | function_ref<void(OpBuilder &, Location)> thenBuilder, |
2078 | function_ref<void(OpBuilder &, Location)> elseBuilder) { |
2079 | assert(thenBuilder && "the builder callback for 'then' must be present"); |
2080 | result.addOperands(cond); |
2081 | |
2082 | // Build then region. |
2083 | OpBuilder::InsertionGuard guard(builder); |
2084 | Region *thenRegion = result.addRegion(); |
2085 | builder.createBlock(thenRegion); |
2086 | thenBuilder(builder, result.location); |
2087 | |
2088 | // Build else region. |
2089 | Region *elseRegion = result.addRegion(); |
2090 | if (elseBuilder) { |
2091 | builder.createBlock(elseRegion); |
2092 | elseBuilder(builder, result.location); |
2093 | } |
2094 | |
2095 | // Infer result types. |
2096 | SmallVector<Type> inferredReturnTypes; |
2097 | MLIRContext *ctx = builder.getContext(); |
2098 | auto attrDict = DictionaryAttr::get(ctx, result.attributes); |
2099 | if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict, |
2100 | /*properties=*/nullptr, result.regions, |
2101 | inferredReturnTypes))) { |
2102 | result.addTypes(inferredReturnTypes); |
2103 | } |
2104 | } |
2105 | |
2106 | LogicalResult IfOp::verify() { |
2107 | if (getNumResults() != 0 && getElseRegion().empty()) |
2108 | return emitOpError("must have an else block if defining values"); |
2109 | return success(); |
2110 | } |
2111 | |
2112 | ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { |
2113 | // Create the regions for 'then'. |
2114 | result.regions.reserve(2); |
2115 | Region *thenRegion = result.addRegion(); |
2116 | Region *elseRegion = result.addRegion(); |
2117 | |
2118 | auto &builder = parser.getBuilder(); |
2119 | OpAsmParser::UnresolvedOperand cond; |
2120 | Type i1Type = builder.getIntegerType(1); |
2121 | if (parser.parseOperand(cond) || |
2122 | parser.resolveOperand(cond, i1Type, result.operands)) |
2123 | return failure(); |
2124 | // Parse optional results type list. |
2125 | if (parser.parseOptionalArrowTypeList(result.types)) |
2126 | return failure(); |
2127 | // Parse the 'then' region. |
2128 | if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{})) |
2129 | return failure(); |
2130 | IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location); |
2131 | |
2132 | // If we find an 'else' keyword then parse the 'else' region. |
2133 | if (!parser.parseOptionalKeyword("else")) { |
2134 | if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{})) |
2135 | return failure(); |
2136 | IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location); |
2137 | } |
2138 | |
2139 | // Parse the optional attribute list. |
2140 | if (parser.parseOptionalAttrDict(result.attributes)) |
2141 | return failure(); |
2142 | return success(); |
2143 | } |
2144 | |
2145 | void IfOp::print(OpAsmPrinter &p) { |
2146 | bool printBlockTerminators = false; |
2147 | |
2148 | p << " "<< getCondition(); |
2149 | if (!getResults().empty()) { |
2150 | p << " -> ("<< getResultTypes() << ")"; |
2151 | // Print yield explicitly if the op defines values. |
2152 | printBlockTerminators = true; |
2153 | } |
2154 | p << ' '; |
2155 | p.printRegion(getThenRegion(), |
2156 | /*printEntryBlockArgs=*/false, |
2157 | /*printBlockTerminators=*/printBlockTerminators); |
2158 | |
2159 | // Print the 'else' region if it exists and has a block. |
2160 | auto &elseRegion = getElseRegion(); |
2161 | if (!elseRegion.empty()) { |
2162 | p << " else "; |
2163 | p.printRegion(elseRegion, |
2164 | /*printEntryBlockArgs=*/false, |
2165 | /*printBlockTerminators=*/printBlockTerminators); |
2166 | } |
2167 | |
2168 | p.printOptionalAttrDict((*this)->getAttrs()); |
2169 | } |
2170 | |
2171 | void IfOp::getSuccessorRegions(RegionBranchPoint point, |
2172 | SmallVectorImpl<RegionSuccessor> ®ions) { |
2173 | // The `then` and the `else` region branch back to the parent operation. |
2174 | if (!point.isParent()) { |
2175 | regions.push_back(RegionSuccessor(getResults())); |
2176 | return; |
2177 | } |
2178 | |
2179 | regions.push_back(RegionSuccessor(&getThenRegion())); |
2180 | |
2181 | // Don't consider the else region if it is empty. |
2182 | Region *elseRegion = &this->getElseRegion(); |
2183 | if (elseRegion->empty()) |
2184 | regions.push_back(RegionSuccessor()); |
2185 | else |
2186 | regions.push_back(RegionSuccessor(elseRegion)); |
2187 | } |
2188 | |
2189 | void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands, |
2190 | SmallVectorImpl<RegionSuccessor> ®ions) { |
2191 | FoldAdaptor adaptor(operands, *this); |
2192 | auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition()); |
2193 | if (!boolAttr || boolAttr.getValue()) |
2194 | regions.emplace_back(&getThenRegion()); |
2195 | |
2196 | // If the else region is empty, execution continues after the parent op. |
2197 | if (!boolAttr || !boolAttr.getValue()) { |
2198 | if (!getElseRegion().empty()) |
2199 | regions.emplace_back(&getElseRegion()); |
2200 | else |
2201 | regions.emplace_back(getResults()); |
2202 | } |
2203 | } |
2204 | |
2205 | LogicalResult IfOp::fold(FoldAdaptor adaptor, |
2206 | SmallVectorImpl<OpFoldResult> &results) { |
2207 | // if (!c) then A() else B() -> if c then B() else A() |
2208 | if (getElseRegion().empty()) |
2209 | return failure(); |
2210 | |
2211 | arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>(); |
2212 | if (!xorStmt) |
2213 | return failure(); |
2214 | |
2215 | if (!matchPattern(xorStmt.getRhs(), m_One())) |
2216 | return failure(); |
2217 | |
2218 | getConditionMutable().assign(xorStmt.getLhs()); |
2219 | Block *thenBlock = &getThenRegion().front(); |
2220 | // It would be nicer to use iplist::swap, but that has no implemented |
2221 | // callbacks. See: https://llvm.org/doxygen/ilist_8h_source.html#l00224 |
2222 | getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(), |
2223 | getElseRegion().getBlocks()); |
2224 | getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(), |
2225 | getThenRegion().getBlocks(), thenBlock); |
2226 | return success(); |
2227 | } |
2228 | |
2229 | void IfOp::getRegionInvocationBounds( |
2230 | ArrayRef<Attribute> operands, |
2231 | SmallVectorImpl<InvocationBounds> &invocationBounds) { |
2232 | if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) { |
2233 | // If the condition is known, then one region is known to be executed once |
2234 | // and the other zero times. |
2235 | invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0); |
2236 | invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1); |
2237 | } else { |
2238 | // Non-constant condition. Each region may be executed 0 or 1 times. |
2239 | invocationBounds.assign(2, {0, 1}); |
2240 | } |
2241 | } |
2242 | |
2243 | namespace { |
2244 | // Pattern to remove unused IfOp results. |
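// Illustrative sketch: if %res#1 below is unused,
//   %res:2 = scf.if %cond -> (i32, i32) { ... } else { ... }
// is rebuilt as a single-result scf.if that yields only the used value:
//   %res = scf.if %cond -> (i32) { ... } else { ... }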
2245 | struct RemoveUnusedResults : public OpRewritePattern<IfOp> { |
2246 | using OpRewritePattern<IfOp>::OpRewritePattern; |
2247 | |
2248 | void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults, |
2249 | PatternRewriter &rewriter) const { |
2250 | // Move all operations to the destination block. |
2251 | rewriter.mergeBlocks(source, dest); |
2252 | // Replace the yield op by one that returns only the used values. |
2253 | auto yieldOp = cast<scf::YieldOp>(dest->getTerminator()); |
2254 | SmallVector<Value, 4> usedOperands; |
2255 | llvm::transform(usedResults, std::back_inserter(usedOperands), |
2256 | [&](OpResult result) { |
2257 | return yieldOp.getOperand(result.getResultNumber()); |
2258 | }); |
2259 | rewriter.modifyOpInPlace(yieldOp, |
2260 | [&]() { yieldOp->setOperands(usedOperands); }); |
2261 | } |
2262 | |
2263 | LogicalResult matchAndRewrite(IfOp op, |
2264 | PatternRewriter &rewriter) const override { |
2265 | // Compute the list of used results. |
2266 | SmallVector<OpResult, 4> usedResults; |
2267 | llvm::copy_if(op.getResults(), std::back_inserter(usedResults), |
2268 | [](OpResult result) { return !result.use_empty(); }); |
2269 | |
2270 | // Replace the operation if only a subset of its results have uses. |
2271 | if (usedResults.size() == op.getNumResults()) |
2272 | return failure(); |
2273 | |
2274 | // Compute the result types of the replacement operation. |
2275 | SmallVector<Type, 4> newTypes; |
2276 | llvm::transform(usedResults, std::back_inserter(newTypes), |
2277 | [](OpResult result) { return result.getType(); }); |
2278 | |
2279 | // Create a replacement operation with empty then and else regions. |
2280 | auto newOp = |
2281 | rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition()); |
2282 | rewriter.createBlock(&newOp.getThenRegion()); |
2283 | rewriter.createBlock(&newOp.getElseRegion()); |
2284 | |
2285 | // Move the bodies and replace the terminators (note there is a then and |
2286 | // an else region since the operation returns results). |
2287 | transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter); |
2288 | transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter); |
2289 | |
2290 | // Replace the operation by the new one. |
2291 | SmallVector<Value, 4> repResults(op.getNumResults()); |
2292 | for (const auto &en : llvm::enumerate(usedResults)) |
2293 | repResults[en.value().getResultNumber()] = newOp.getResult(en.index()); |
2294 | rewriter.replaceOp(op, repResults); |
2295 | return success(); |
2296 | } |
2297 | }; |
2298 | |
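/// Fold an scf.if whose condition is a constant: inline the region that is
/// taken, or erase the op if the condition is false and there is no else
/// region. Illustrative sketch:
///   %true = arith.constant true
///   scf.if %true { "foo"() : () -> () }
/// becomes
///   "foo"() : () -> ()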
2299 | struct RemoveStaticCondition : public OpRewritePattern<IfOp> { |
2300 | using OpRewritePattern<IfOp>::OpRewritePattern; |
2301 | |
2302 | LogicalResult matchAndRewrite(IfOp op, |
2303 | PatternRewriter &rewriter) const override { |
2304 | BoolAttr condition; |
2305 | if (!matchPattern(op.getCondition(), m_Constant(&condition))) |
2306 | return failure(); |
2307 | |
2308 | if (condition.getValue()) |
2309 | replaceOpWithRegion(rewriter, op, op.getThenRegion()); |
2310 | else if (!op.getElseRegion().empty()) |
2311 | replaceOpWithRegion(rewriter, op, op.getElseRegion()); |
2312 | else |
2313 | rewriter.eraseOp(op); |
2314 | |
2315 | return success(); |
2316 | } |
2317 | }; |
2318 | |
2319 | /// Hoist any yielded results whose operands are defined outside |
2320 | /// the if, to a select instruction. |
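/// Illustrative sketch (%a and %b are defined above the scf.if):
///   %r = scf.if %cond -> (i32) {
///     scf.yield %a : i32
///   } else {
///     scf.yield %b : i32
///   }
/// becomes (up to a leftover result-less scf.if that later cleanup removes)
///   %r = arith.select %cond, %a, %b : i32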
2321 | struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> { |
2322 | using OpRewritePattern<IfOp>::OpRewritePattern; |
2323 | |
2324 | LogicalResult matchAndRewrite(IfOp op, |
2325 | PatternRewriter &rewriter) const override { |
2326 | if (op->getNumResults() == 0) |
2327 | return failure(); |
2328 | |
2329 | auto cond = op.getCondition(); |
2330 | auto thenYieldArgs = op.thenYield().getOperands(); |
2331 | auto elseYieldArgs = op.elseYield().getOperands(); |
2332 | |
2333 | SmallVector<Type> nonHoistable; |
2334 | for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) { |
2335 | if (&op.getThenRegion() == trueVal.getParentRegion() || |
2336 | &op.getElseRegion() == falseVal.getParentRegion()) |
2337 | nonHoistable.push_back(trueVal.getType()); |
2338 | } |
2339 | // Early exit if there aren't any yielded values we can |
2340 | // hoist outside the if. |
2341 | if (nonHoistable.size() == op->getNumResults()) |
2342 | return failure(); |
2343 | |
2344 | IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond, |
2345 | /*withElseRegion=*/false); |
2346 | if (replacement.thenBlock()) |
2347 | rewriter.eraseBlock(replacement.thenBlock()); |
2348 | replacement.getThenRegion().takeBody(op.getThenRegion()); |
2349 | replacement.getElseRegion().takeBody(op.getElseRegion()); |
2350 | |
2351 | SmallVector<Value> results(op->getNumResults()); |
2352 | assert(thenYieldArgs.size() == results.size()); |
2353 | assert(elseYieldArgs.size() == results.size()); |
2354 | |
2355 | SmallVector<Value> trueYields; |
2356 | SmallVector<Value> falseYields; |
2357 | rewriter.setInsertionPoint(replacement); |
2358 | for (const auto &it : |
2359 | llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) { |
2360 | Value trueVal = std::get<0>(it.value()); |
2361 | Value falseVal = std::get<1>(it.value()); |
2362 | if (&replacement.getThenRegion() == trueVal.getParentRegion() || |
2363 | &replacement.getElseRegion() == falseVal.getParentRegion()) { |
2364 | results[it.index()] = replacement.getResult(trueYields.size()); |
2365 | trueYields.push_back(trueVal); |
2366 | falseYields.push_back(falseVal); |
2367 | } else if (trueVal == falseVal) |
2368 | results[it.index()] = trueVal; |
2369 | else |
2370 | results[it.index()] = rewriter.create<arith::SelectOp>( |
2371 | op.getLoc(), cond, trueVal, falseVal); |
2372 | } |
2373 | |
2374 | rewriter.setInsertionPointToEnd(replacement.thenBlock()); |
2375 | rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields); |
2376 | |
2377 | rewriter.setInsertionPointToEnd(replacement.elseBlock()); |
2378 | rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields); |
2379 | |
2380 | rewriter.replaceOp(op, results); |
2381 | return success(); |
2382 | } |
2383 | }; |
2384 | |
2385 | /// Allow the true region of an if to assume the condition is true |
2386 | /// and vice versa. For example: |
2387 | /// |
2388 | /// scf.if %cmp { |
2389 | /// print(%cmp) |
2390 | /// } |
2391 | /// |
2392 | /// becomes |
2393 | /// |
2394 | /// scf.if %cmp { |
2395 | /// print(true) |
2396 | /// } |
2397 | /// |
2398 | struct ConditionPropagation : public OpRewritePattern<IfOp> { |
2399 | using OpRewritePattern<IfOp>::OpRewritePattern; |
2400 | |
2401 | LogicalResult matchAndRewrite(IfOp op, |
2402 | PatternRewriter &rewriter) const override { |
2403 | // Early exit if the condition is constant since replacing a constant |
2404 | // in the body with another constant isn't a simplification. |
2405 | if (matchPattern(op.getCondition(), m_Constant())) |
2406 | return failure(); |
2407 | |
2408 | bool changed = false; |
2409 | mlir::Type i1Ty = rewriter.getI1Type(); |
2410 | |
2411 | // These variables serve to prevent creating duplicate constants |
2412 | // and hold constant true or false values. |
2413 | Value constantTrue = nullptr; |
2414 | Value constantFalse = nullptr; |
2415 | |
2416 | for (OpOperand &use : |
2417 | llvm::make_early_inc_range(op.getCondition().getUses())) { |
2418 | if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) { |
2419 | changed = true; |
2420 | |
2421 | if (!constantTrue) |
2422 | constantTrue = rewriter.create<arith::ConstantOp>( |
2423 | op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)); |
2424 | |
2425 | rewriter.modifyOpInPlace(use.getOwner(), |
2426 | [&]() { use.set(constantTrue); }); |
2427 | } else if (op.getElseRegion().isAncestor( |
2428 | use.getOwner()->getParentRegion())) { |
2429 | changed = true; |
2430 | |
2431 | if (!constantFalse) |
2432 | constantFalse = rewriter.create<arith::ConstantOp>( |
2433 | op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0)); |
2434 | |
2435 | rewriter.modifyOpInPlace(use.getOwner(), |
2436 | [&]() { use.set(constantFalse); }); |
2437 | } |
2438 | } |
2439 | |
2440 | return success(changed); |
2441 | } |
2442 | }; |
2443 | |
2444 | /// Remove any statements from an if that are equivalent to the condition |
2445 | /// or its negation. For example: |
2446 | /// |
2447 | /// %res:2 = scf.if %cmp { |
2448 | /// yield something(), true |
2449 | /// } else { |
2450 | /// yield something2(), false |
2451 | /// } |
2452 | /// print(%res#1) |
2453 | /// |
2454 | /// becomes |
2455 | /// %res = scf.if %cmp { |
2456 | /// yield something() |
2457 | /// } else { |
2458 | /// yield something2() |
2459 | /// } |
2460 | /// print(%cmp) |
2461 | /// |
2462 | /// Additionally if both branches yield the same value, replace all uses |
2463 | /// of the result with the yielded value. |
2464 | /// |
2465 | /// %res:2 = scf.if %cmp { |
2466 | /// yield something(), %arg1 |
2467 | /// } else { |
2468 | /// yield something2(), %arg1 |
2469 | /// } |
2470 | /// print(%res#1) |
2471 | /// |
2472 | /// becomes |
2473 | /// %res = scf.if %cmp { |
2474 | /// yield something() |
2475 | /// } else { |
2476 | /// yield something2() |
2477 | /// } |
2478 | /// print(%arg1) |
2479 | /// |
2480 | struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> { |
2481 | using OpRewritePattern<IfOp>::OpRewritePattern; |
2482 | |
2483 | LogicalResult matchAndRewrite(IfOp op, |
2484 | PatternRewriter &rewriter) const override { |
2485 | // Early exit if there are no results that could be replaced. |
2486 | if (op.getNumResults() == 0) |
2487 | return failure(); |
2488 | |
2489 | auto trueYield = |
2490 | cast<scf::YieldOp>(op.getThenRegion().back().getTerminator()); |
2491 | auto falseYield = |
2492 | cast<scf::YieldOp>(op.getElseRegion().back().getTerminator()); |
2493 | |
2494 | rewriter.setInsertionPoint(op->getBlock(), |
2495 | op.getOperation()->getIterator()); |
2496 | bool changed = false; |
2497 | Type i1Ty = rewriter.getI1Type(); |
2498 | for (auto [trueResult, falseResult, opResult] : |
2499 | llvm::zip(trueYield.getResults(), falseYield.getResults(), |
2500 | op.getResults())) { |
2501 | if (trueResult == falseResult) { |
2502 | if (!opResult.use_empty()) { |
2503 | opResult.replaceAllUsesWith(trueResult); |
2504 | changed = true; |
2505 | } |
2506 | continue; |
2507 | } |
2508 | |
2509 | BoolAttr trueYield, falseYield; |
2510 | if (!matchPattern(trueResult, m_Constant(&trueYield)) || |
2511 | !matchPattern(falseResult, m_Constant(&falseYield))) |
2512 | continue; |
2513 | |
2514 | bool trueVal = trueYield.getValue(); |
2515 | bool falseVal = falseYield.getValue(); |
2516 | if (!trueVal && falseVal) { |
2517 | if (!opResult.use_empty()) { |
2518 | Dialect *constDialect = trueResult.getDefiningOp()->getDialect(); |
2519 | Value notCond = rewriter.create<arith::XOrIOp>( |
2520 | op.getLoc(), op.getCondition(), |
2521 | constDialect |
2522 | ->materializeConstant(rewriter, |
2523 | rewriter.getIntegerAttr(i1Ty, 1), i1Ty, |
2524 | op.getLoc()) |
2525 | ->getResult(0)); |
2526 | opResult.replaceAllUsesWith(notCond); |
2527 | changed = true; |
2528 | } |
2529 | } |
2530 | if (trueVal && !falseVal) { |
2531 | if (!opResult.use_empty()) { |
2532 | opResult.replaceAllUsesWith(op.getCondition()); |
2533 | changed = true; |
2534 | } |
2535 | } |
2536 | } |
2537 | return success(changed); |
2538 | } |
2539 | }; |
2540 | |
2541 | /// Merge any consecutive scf.if's with the same condition. |
2542 | /// |
2543 | /// scf.if %cond { |
2544 | /// firstCodeTrue();... |
2545 | /// } else { |
2546 | /// firstCodeFalse();... |
2547 | /// } |
2548 | /// %res = scf.if %cond { |
2549 | /// secondCodeTrue();... |
2550 | /// } else { |
2551 | /// secondCodeFalse();... |
2552 | /// } |
2553 | /// |
2554 | /// becomes |
2555 | /// %res = scf.if %cmp { |
2556 | /// firstCodeTrue();... |
2557 | /// secondCodeTrue();... |
2558 | /// } else { |
2559 | /// firstCodeFalse();... |
2560 | /// secondCodeFalse();... |
2561 | /// } |
2562 | struct CombineIfs : public OpRewritePattern<IfOp> { |
2563 | using OpRewritePattern<IfOp>::OpRewritePattern; |
2564 | |
2565 | LogicalResult matchAndRewrite(IfOp nextIf, |
2566 | PatternRewriter &rewriter) const override { |
2567 | Block *parent = nextIf->getBlock(); |
2568 | if (nextIf == &parent->front()) |
2569 | return failure(); |
2570 | |
2571 | auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode()); |
2572 | if (!prevIf) |
2573 | return failure(); |
2574 | |
2575 | // Determine the logical then/else blocks when prevIf's |
2576 | // condition is used. Null means the block does not exist |
2577 | // in that case (e.g. empty else). If neither of these |
2578 | // is set, the two conditions cannot be compared. |
2579 | Block *nextThen = nullptr; |
2580 | Block *nextElse = nullptr; |
2581 | if (nextIf.getCondition() == prevIf.getCondition()) { |
2582 | nextThen = nextIf.thenBlock(); |
2583 | if (!nextIf.getElseRegion().empty()) |
2584 | nextElse = nextIf.elseBlock(); |
2585 | } |
2586 | if (arith::XOrIOp notv = |
2587 | nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) { |
2588 | if (notv.getLhs() == prevIf.getCondition() && |
2589 | matchPattern(notv.getRhs(), m_One())) { |
2590 | nextElse = nextIf.thenBlock(); |
2591 | if (!nextIf.getElseRegion().empty()) |
2592 | nextThen = nextIf.elseBlock(); |
2593 | } |
2594 | } |
2595 | if (arith::XOrIOp notv = |
2596 | prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) { |
2597 | if (notv.getLhs() == nextIf.getCondition() && |
2598 | matchPattern(notv.getRhs(), m_One())) { |
2599 | nextElse = nextIf.thenBlock(); |
2600 | if (!nextIf.getElseRegion().empty()) |
2601 | nextThen = nextIf.elseBlock(); |
2602 | } |
2603 | } |
2604 | |
2605 | if (!nextThen && !nextElse) |
2606 | return failure(); |
2607 | |
2608 | SmallVector<Value> prevElseYielded; |
2609 | if (!prevIf.getElseRegion().empty()) |
2610 | prevElseYielded = prevIf.elseYield().getOperands(); |
2611 | // Replace all uses of return values of op within nextIf with the |
2612 | // corresponding yields |
2613 | for (auto it : llvm::zip(prevIf.getResults(), |
2614 | prevIf.thenYield().getOperands(), prevElseYielded)) |
2615 | for (OpOperand &use : |
2616 | llvm::make_early_inc_range(std::get<0>(it).getUses())) { |
2617 | if (nextThen && nextThen->getParent()->isAncestor( |
2618 | use.getOwner()->getParentRegion())) { |
2619 | rewriter.startOpModification(use.getOwner()); |
2620 | use.set(std::get<1>(it)); |
2621 | rewriter.finalizeOpModification(use.getOwner()); |
2622 | } else if (nextElse && nextElse->getParent()->isAncestor( |
2623 | use.getOwner()->getParentRegion())) { |
2624 | rewriter.startOpModification(use.getOwner()); |
2625 | use.set(std::get<2>(it)); |
2626 | rewriter.finalizeOpModification(use.getOwner()); |
2627 | } |
2628 | } |
2629 | |
2630 | SmallVector<Type> mergedTypes(prevIf.getResultTypes()); |
2631 | llvm::append_range(mergedTypes, nextIf.getResultTypes()); |
2632 | |
2633 | IfOp combinedIf = rewriter.create<IfOp>( |
2634 | nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false); |
2635 | rewriter.eraseBlock(&combinedIf.getThenRegion().back()); |
2636 | |
2637 | rewriter.inlineRegionBefore(prevIf.getThenRegion(), |
2638 | combinedIf.getThenRegion(), |
2639 | combinedIf.getThenRegion().begin()); |
2640 | |
2641 | if (nextThen) { |
2642 | YieldOp thenYield = combinedIf.thenYield(); |
2643 | YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator()); |
2644 | rewriter.mergeBlocks(nextThen, combinedIf.thenBlock()); |
2645 | rewriter.setInsertionPointToEnd(combinedIf.thenBlock()); |
2646 | |
2647 | SmallVector<Value> mergedYields(thenYield.getOperands()); |
2648 | llvm::append_range(mergedYields, thenYield2.getOperands()); |
2649 | rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields); |
2650 | rewriter.eraseOp(thenYield); |
2651 | rewriter.eraseOp(thenYield2); |
2652 | } |
2653 | |
2654 | rewriter.inlineRegionBefore(prevIf.getElseRegion(), |
2655 | combinedIf.getElseRegion(), |
2656 | combinedIf.getElseRegion().begin()); |
2657 | |
2658 | if (nextElse) { |
2659 | if (combinedIf.getElseRegion().empty()) { |
2660 | rewriter.inlineRegionBefore(*nextElse->getParent(), |
2661 | combinedIf.getElseRegion(), |
2662 | combinedIf.getElseRegion().begin()); |
2663 | } else { |
2664 | YieldOp elseYield = combinedIf.elseYield(); |
2665 | YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator()); |
2666 | rewriter.mergeBlocks(nextElse, combinedIf.elseBlock()); |
2667 | |
2668 | rewriter.setInsertionPointToEnd(combinedIf.elseBlock()); |
2669 | |
2670 | SmallVector<Value> mergedElseYields(elseYield.getOperands()); |
2671 | llvm::append_range(mergedElseYields, elseYield2.getOperands()); |
2672 | |
2673 | rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields); |
2674 | rewriter.eraseOp(elseYield); |
2675 | rewriter.eraseOp(elseYield2); |
2676 | } |
2677 | } |
2678 | |
2679 | SmallVector<Value> prevValues; |
2680 | SmallVector<Value> nextValues; |
2681 | for (const auto &pair : llvm::enumerate(combinedIf.getResults())) { |
2682 | if (pair.index() < prevIf.getNumResults()) |
2683 | prevValues.push_back(pair.value()); |
2684 | else |
2685 | nextValues.push_back(pair.value()); |
2686 | } |
2687 | rewriter.replaceOp(prevIf, prevValues); |
2688 | rewriter.replaceOp(nextIf, nextValues); |
2689 | return success(); |
2690 | } |
2691 | }; |
2692 | |
2693 | /// Pattern to remove an empty else branch. |
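/// An illustrative sketch (the pattern only applies when the op has no
/// results and the else block contains nothing but its terminator):
///
///   scf.if %cond {
///     "test.op"() : () -> ()
///   } else {
///   }
///
/// becomes
///
///   scf.if %cond {
///     "test.op"() : () -> ()
///   }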
2694 | struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> { |
2695 | using OpRewritePattern<IfOp>::OpRewritePattern; |
2696 | |
2697 | LogicalResult matchAndRewrite(IfOp ifOp, |
2698 | PatternRewriter &rewriter) const override { |
2699 | // Cannot remove else region when there are operation results. |
2700 | if (ifOp.getNumResults()) |
2701 | return failure(); |
2702 | Block *elseBlock = ifOp.elseBlock(); |
2703 | if (!elseBlock || !llvm::hasSingleElement(*elseBlock)) |
2704 | return failure(); |
2705 | auto newIfOp = rewriter.cloneWithoutRegions(ifOp); |
2706 | rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(), |
2707 | newIfOp.getThenRegion().begin()); |
2708 | rewriter.eraseOp(ifOp); |
2709 | return success(); |
2710 | } |
2711 | }; |
2712 | |
2713 | /// Convert nested `if`s into `arith.andi` + single `if`. |
2714 | /// |
2715 | /// scf.if %arg0 { |
2716 | /// scf.if %arg1 { |
2717 | /// ... |
2718 | /// scf.yield |
2719 | /// } |
2720 | /// scf.yield |
2721 | /// } |
2722 | /// becomes |
2723 | /// |
2724 | /// %0 = arith.andi %arg0, %arg1 |
2725 | /// scf.if %0 { |
2726 | /// ... |
2727 | /// scf.yield |
2728 | /// } |
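///
/// When the ifs yield values, the same rewrite applies as long as the value
/// yielded by the outer else matches the value yielded by the inner else.
/// An illustrative sketch (%v0 and %v1 are assumed to be defined above):
///
/// %r = scf.if %arg0 -> (f32) {
///   %inner = scf.if %arg1 -> (f32) {
///     scf.yield %v0 : f32
///   } else {
///     scf.yield %v1 : f32
///   }
///   scf.yield %inner : f32
/// } else {
///   scf.yield %v1 : f32
/// }
/// becomes
/// %0 = arith.andi %arg0, %arg1
/// %r = scf.if %0 -> (f32) {
///   scf.yield %v0 : f32
/// } else {
///   scf.yield %v1 : f32
/// }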
2729 | struct CombineNestedIfs : public OpRewritePattern<IfOp> { |
2730 | using OpRewritePattern<IfOp>::OpRewritePattern; |
2731 | |
2732 | LogicalResult matchAndRewrite(IfOp op, |
2733 | PatternRewriter &rewriter) const override { |
2734 | auto nestedOps = op.thenBlock()->without_terminator(); |
2735 | // Nested `if` must be the only op in block. |
2736 | if (!llvm::hasSingleElement(nestedOps)) |
2737 | return failure(); |
2738 | |
2739 | // If there is an else block, it can only yield |
2740 | if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock())) |
2741 | return failure(); |
2742 | |
2743 | auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin()); |
2744 | if (!nestedIf) |
2745 | return failure(); |
2746 | |
2747 | if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock())) |
2748 | return failure(); |
2749 | |
2750 | SmallVector<Value> thenYield(op.thenYield().getOperands()); |
2751 | SmallVector<Value> elseYield; |
2752 | if (op.elseBlock()) |
2753 | llvm::append_range(elseYield, op.elseYield().getOperands()); |
2754 | |
2755 | // A list of indices for which we should upgrade the value yielded |
2756 | // in the else to a select. |
2757 | SmallVector<unsigned> elseYieldsToUpgradeToSelect; |
2758 | |
2759 | // If the outer scf.if yields a value produced by the inner scf.if, |
2760 | // only permit combining if the value yielded when the condition |
2761 | // is false in the outer scf.if is the same value yielded when the |
2762 | // inner scf.if condition is false. |
2763 | // Note that the array access to elseYield will not go out of bounds |
2764 | // since it must have the same length as thenYield, as both are yielded |
2765 | // by the same scf.if. |
2766 | for (const auto &tup : llvm::enumerate(thenYield)) { |
2767 | if (tup.value().getDefiningOp() == nestedIf) { |
2768 | auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber(); |
2769 | if (nestedIf.elseYield().getOperand(nestedIdx) != |
2770 | elseYield[tup.index()]) { |
2771 | return failure(); |
2772 | } |
2773 | // If the correctness test passes, we will yield |
2774 | // corresponding value from the inner scf.if |
2775 | thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx); |
2776 | continue; |
2777 | } |
2778 | |
2779 | // Otherwise, we need to ensure the else block of the combined |
2780 | // condition still returns the same value when the outer condition is |
2781 | // true and the inner condition is false. This can be accomplished if |
2782 | // the then value is defined outside the outer scf.if and we replace the |
2783 | // value with a select that considers just the outer condition. Since |
2784 | // the else region contains just the yield, its yielded value is |
2785 | // defined outside the scf.if, by definition. |
2786 | |
2787 | // If the then value is defined within the scf.if, bail. |
2788 | if (tup.value().getParentRegion() == &op.getThenRegion()) { |
2789 | return failure(); |
2790 | } |
2791 | elseYieldsToUpgradeToSelect.push_back(tup.index()); |
2792 | } |
2793 | |
2794 | Location loc = op.getLoc(); |
2795 | Value newCondition = rewriter.create<arith::AndIOp>( |
2796 | loc, op.getCondition(), nestedIf.getCondition()); |
2797 | auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition); |
2798 | Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion()); |
2799 | |
2800 | SmallVector<Value> results; |
2801 | llvm::append_range(results, newIf.getResults()); |
2802 | rewriter.setInsertionPoint(newIf); |
2803 | |
2804 | for (auto idx : elseYieldsToUpgradeToSelect) |
2805 | results[idx] = rewriter.create<arith::SelectOp>( |
2806 | op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]); |
2807 | |
2808 | rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock); |
2809 | rewriter.setInsertionPointToEnd(newIf.thenBlock()); |
2810 | rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield); |
2811 | if (!elseYield.empty()) { |
2812 | rewriter.createBlock(&newIf.getElseRegion()); |
2813 | rewriter.setInsertionPointToEnd(newIf.elseBlock()); |
2814 | rewriter.create<YieldOp>(loc, elseYield); |
2815 | } |
2816 | rewriter.replaceOp(op, results); |
2817 | return success(); |
2818 | } |
2819 | }; |
2820 | |
2821 | } // namespace |
2822 | |
2823 | void IfOp::getCanonicalizationPatterns(RewritePatternSet &results, |
2824 | MLIRContext *context) { |
2825 | results.add<CombineIfs, CombineNestedIfs, ConditionPropagation, |
2826 | ConvertTrivialIfToSelect, RemoveEmptyElseBranch, |
2827 | RemoveStaticCondition, RemoveUnusedResults, |
2828 | ReplaceIfYieldWithConditionOrValue>(context); |
2829 | } |
2830 | |
2831 | Block *IfOp::thenBlock() { return &getThenRegion().back(); } |
2832 | YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); } |
2833 | Block *IfOp::elseBlock() { |
2834 | Region &r = getElseRegion(); |
2835 | if (r.empty()) |
2836 | return nullptr; |
2837 | return &r.back(); |
2838 | } |
2839 | YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); } |
2840 | |
2841 | //===----------------------------------------------------------------------===// |
2842 | // ParallelOp |
2843 | //===----------------------------------------------------------------------===// |
2844 | |
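// A minimal C++ usage sketch of the builder below (illustrative only; `b`,
// `loc`, `lbs`, `ubs` and `steps` stand for an OpBuilder, a Location and
// index-typed value ranges provided by the caller):
//
//   b.create<scf::ParallelOp>(
//       loc, lbs, ubs, steps,
//       [](OpBuilder &nested, Location nestedLoc, ValueRange ivs) {
//         // Emit the loop body here; a terminator is added automatically
//         // when there are no reductions.
//       });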
2845 | void ParallelOp::build( |
2846 | OpBuilder &builder, OperationState &result, ValueRange lowerBounds, |
2847 | ValueRange upperBounds, ValueRange steps, ValueRange initVals, |
2848 | function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> |
2849 | bodyBuilderFn) { |
2850 | result.addOperands(lowerBounds); |
2851 | result.addOperands(upperBounds); |
2852 | result.addOperands(steps); |
2853 | result.addOperands(initVals); |
2854 | result.addAttribute( |
2855 | ParallelOp::getOperandSegmentSizeAttr(), |
2856 | builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()), |
2857 | static_cast<int32_t>(upperBounds.size()), |
2858 | static_cast<int32_t>(steps.size()), |
2859 | static_cast<int32_t>(initVals.size())})); |
2860 | result.addTypes(initVals.getTypes()); |
2861 | |
2862 | OpBuilder::InsertionGuard guard(builder); |
2863 | unsigned numIVs = steps.size(); |
2864 | SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType()); |
2865 | SmallVector<Location, 8> argLocs(numIVs, result.location); |
2866 | Region *bodyRegion = result.addRegion(); |
2867 | Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs); |
2868 | |
2869 | if (bodyBuilderFn) { |
2870 | builder.setInsertionPointToStart(bodyBlock); |
2871 | bodyBuilderFn(builder, result.location, |
2872 | bodyBlock->getArguments().take_front(numIVs), |
2873 | bodyBlock->getArguments().drop_front(numIVs)); |
2874 | } |
2875 | // Add terminator only if there are no reductions. |
2876 | if (initVals.empty()) |
2877 | ParallelOp::ensureTerminator(*bodyRegion, builder, result.location); |
2878 | } |
2879 | |
2880 | void ParallelOp::build( |
2881 | OpBuilder &builder, OperationState &result, ValueRange lowerBounds, |
2882 | ValueRange upperBounds, ValueRange steps, |
2883 | function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) { |
2884 | // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure |
2885 | // we don't capture a reference to a temporary by constructing the lambda at |
2886 | // function level. |
2887 | auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder, |
2888 | Location nestedLoc, ValueRange ivs, |
2889 | ValueRange) { |
2890 | bodyBuilderFn(nestedBuilder, nestedLoc, ivs); |
2891 | }; |
2892 | function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper; |
2893 | if (bodyBuilderFn) |
2894 | wrapper = wrappedBuilderFn; |
2895 | |
2896 | build(builder, result, lowerBounds, upperBounds, steps, ValueRange(), |
2897 | wrapper); |
2898 | } |
2899 | |
2900 | LogicalResult ParallelOp::verify() { |
2901 | // Check that there is at least one value in lowerBound, upperBound and step. |
2902 | // It is sufficient to test only step, because it is ensured already that the |
2903 | // number of elements in lowerBound, upperBound and step are the same. |
2904 | Operation::operand_range stepValues = getStep(); |
2905 | if (stepValues.empty()) |
2906 | return emitOpError( |
2907 | "needs at least one tuple element for lowerBound, upperBound and step"); |
2908 | |
2909 | // Check whether all constant step values are positive. |
2910 | for (Value stepValue : stepValues) |
2911 | if (auto cst = getConstantIntValue(stepValue)) |
2912 | if (*cst <= 0) |
2913 | return emitOpError("constant step operand must be positive"); |
2914 | |
2915 | // Check that the body defines the same number of block arguments as the |
2916 | // number of tuple elements in step. |
2917 | Block *body = getBody(); |
2918 | if (body->getNumArguments() != stepValues.size()) |
2919 | return emitOpError() << "expects the same number of induction variables: " |
2920 | << body->getNumArguments() |
2921 | << " as bound and step values: " << stepValues.size(); |
2922 | for (auto arg : body->getArguments()) |
2923 | if (!arg.getType().isIndex()) |
2924 | return emitOpError( |
2925 | "expects arguments for the induction variable to be of index type"); |
2926 | |
2927 | // Check that the terminator is an scf.reduce op. |
2928 | auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>( |
2929 | *this, getRegion(), "expects body to terminate with 'scf.reduce'"); |
2930 | if (!reduceOp) |
2931 | return failure(); |
2932 | |
2933 | // Check that the number of results is the same as the number of reductions. |
2934 | auto resultsSize = getResults().size(); |
2935 | auto reductionsSize = reduceOp.getReductions().size(); |
2936 | auto initValsSize = getInitVals().size(); |
2937 | if (resultsSize != reductionsSize) |
2938 | return emitOpError() << "expects number of results: " << resultsSize |
2939 | << " to be the same as number of reductions: " |
2940 | << reductionsSize; |
2941 | if (resultsSize != initValsSize) |
2942 | return emitOpError() << "expects number of results: " << resultsSize |
2943 | << " to be the same as number of initial values: " |
2944 | << initValsSize; |
2945 | |
2946 | // Check that the types of the results and reductions are the same. |
2947 | for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) { |
2948 | auto resultType = getOperation()->getResult(i).getType(); |
2949 | auto reductionOperandType = reduceOp.getOperands()[i].getType(); |
2950 | if (resultType != reductionOperandType) |
2951 | return reduceOp.emitOpError() |
2952 | << "expects type of " << i |
2953 | << "-th reduction operand: " << reductionOperandType |
2954 | << " to be the same as the " << i |
2955 | << "-th result type: " << resultType; |
2956 | } |
2957 | return success(); |
2958 | } |
2959 | |
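// An illustrative example of the custom syntax handled by the parser and
// printer below (operation names and types are hypothetical):
//
//   %sum = scf.parallel (%i) = (%lb) to (%ub) step (%s) init (%zero) -> f32 {
//     %elem = memref.load %buf[%i] : memref<?xf32>
//     scf.reduce(%elem : f32) {
//     ^bb0(%lhs: f32, %rhs: f32):
//       %add = arith.addf %lhs, %rhs : f32
//       scf.reduce.return %add : f32
//     }
//   }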
2960 | ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) { |
2961 | auto &builder = parser.getBuilder(); |
2962 | // Parse an opening `(` followed by induction variables followed by `)` |
2963 | SmallVector<OpAsmParser::Argument, 4> ivs; |
2964 | if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren)) |
2965 | return failure(); |
2966 | |
2967 | // Parse loop bounds. |
2968 | SmallVector<OpAsmParser::UnresolvedOperand, 4> lower; |
2969 | if (parser.parseEqual() || |
2970 | parser.parseOperandList(lower, ivs.size(), |
2971 | OpAsmParser::Delimiter::Paren) || |
2972 | parser.resolveOperands(lower, builder.getIndexType(), result.operands)) |
2973 | return failure(); |
2974 | |
2975 | SmallVector<OpAsmParser::UnresolvedOperand, 4> upper; |
2976 | if (parser.parseKeyword("to") || |
2977 | parser.parseOperandList(upper, ivs.size(), |
2978 | OpAsmParser::Delimiter::Paren) || |
2979 | parser.resolveOperands(upper, builder.getIndexType(), result.operands)) |
2980 | return failure(); |
2981 | |
2982 | // Parse step values. |
2983 | SmallVector<OpAsmParser::UnresolvedOperand, 4> steps; |
2984 | if (parser.parseKeyword("step") || |
2985 | parser.parseOperandList(steps, ivs.size(), |
2986 | OpAsmParser::Delimiter::Paren) || |
2987 | parser.resolveOperands(steps, builder.getIndexType(), result.operands)) |
2988 | return failure(); |
2989 | |
2990 | // Parse init values. |
2991 | SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals; |
2992 | if (succeeded(parser.parseOptionalKeyword("init"))) { |
2993 | if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren)) |
2994 | return failure(); |
2995 | } |
2996 | |
2997 | // Parse optional results in case there is a reduce. |
2998 | if (parser.parseOptionalArrowTypeList(result.types)) |
2999 | return failure(); |
3000 | |
3001 | // Now parse the body. |
3002 | Region *body = result.addRegion(); |
3003 | for (auto &iv : ivs) |
3004 | iv.type = builder.getIndexType(); |
3005 | if (parser.parseRegion(*body, ivs)) |
3006 | return failure(); |
3007 | |
3008 | // Set `operandSegmentSizes` attribute. |
3009 | result.addAttribute( |
3010 | ParallelOp::getOperandSegmentSizeAttr(), |
3011 | builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()), |
3012 | static_cast<int32_t>(upper.size()), |
3013 | static_cast<int32_t>(steps.size()), |
3014 | static_cast<int32_t>(initVals.size())})); |
3015 | |
3016 | // Parse attributes. |
3017 | if (parser.parseOptionalAttrDict(result.attributes) || |
3018 | parser.resolveOperands(initVals, result.types, parser.getNameLoc(), |
3019 | result.operands)) |
3020 | return failure(); |
3021 | |
3022 | // Add a terminator if none was parsed. |
3023 | ParallelOp::ensureTerminator(*body, builder, result.location); |
3024 | return success(); |
3025 | } |
3026 | |
3027 | void ParallelOp::print(OpAsmPrinter &p) { |
3028 | p << " (" << getBody()->getArguments() << ") = (" << getLowerBound() |
3029 | << ") to (" << getUpperBound() << ") step (" << getStep() << ")"; |
3030 | if (!getInitVals().empty()) |
3031 | p << " init (" << getInitVals() << ")"; |
3032 | p.printOptionalArrowTypeList(getResultTypes()); |
3033 | p << ' '; |
3034 | p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); |
3035 | p.printOptionalAttrDict( |
3036 | (*this)->getAttrs(), |
3037 | /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr()); |
3038 | } |
3039 | |
3040 | SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; } |
3041 | |
3042 | std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() { |
3043 | return SmallVector<Value>{getBody()->getArguments()}; |
3044 | } |
3045 | |
3046 | std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() { |
3047 | return getLowerBound(); |
3048 | } |
3049 | |
3050 | std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() { |
3051 | return getUpperBound(); |
3052 | } |
3053 | |
3054 | std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() { |
3055 | return getStep(); |
3056 | } |
3057 | |
3058 | ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) { |
3059 | auto ivArg = llvm::dyn_cast<BlockArgument>(val); |
3060 | if (!ivArg) |
3061 | return ParallelOp(); |
3062 | assert(ivArg.getOwner() && "unlinked block argument"); |
3063 | auto *containingOp = ivArg.getOwner()->getParentOp(); |
3064 | return dyn_cast<ParallelOp>(containingOp); |
3065 | } |
3066 | |
3067 | namespace { |
3068 | // Collapse loop dimensions that perform a single iteration. |
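//
// An illustrative sketch (%c0, %c1 and %c4 are index constants): a dimension
// with a static trip count of 1 is removed and its induction variable is
// replaced by its lower bound; a loop with a zero-trip-count dimension is
// replaced by its init values.
//
//   scf.parallel (%i, %j) = (%c0, %c0) to (%c1, %c4) step (%c1, %c1) { ... }
// becomes
//   scf.parallel (%j) = (%c0) to (%c4) step (%c1) { ... }  // %i -> %c0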
3069 | struct ParallelOpSingleOrZeroIterationDimsFolder |
3070 | : public OpRewritePattern<ParallelOp> { |
3071 | using OpRewritePattern<ParallelOp>::OpRewritePattern; |
3072 | |
3073 | LogicalResult matchAndRewrite(ParallelOp op, |
3074 | PatternRewriter &rewriter) const override { |
3075 | Location loc = op.getLoc(); |
3076 | |
3077 | // Compute new loop bounds that omit all single-iteration loop dimensions. |
3078 | SmallVector<Value> newLowerBounds, newUpperBounds, newSteps; |
3079 | IRMapping mapping; |
3080 | for (auto [lb, ub, step, iv] : |
3081 | llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(), |
3082 | op.getInductionVars())) { |
3083 | auto numIterations = constantTripCount(lb, ub, step); |
3084 | if (numIterations.has_value()) { |
3085 | // Remove the loop if it performs zero iterations. |
3086 | if (*numIterations == 0) { |
3087 | rewriter.replaceOp(op, op.getInitVals()); |
3088 | return success(); |
3089 | } |
3090 | // Replace the loop induction variable by the lower bound if the loop |
3091 | // performs a single iteration. Otherwise, copy the loop bounds. |
3092 | if (*numIterations == 1) { |
3093 | mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb)); |
3094 | continue; |
3095 | } |
3096 | } |
3097 | newLowerBounds.push_back(lb); |
3098 | newUpperBounds.push_back(ub); |
3099 | newSteps.push_back(step); |
3100 | } |
3101 | // Exit if none of the loop dimensions perform a single iteration. |
3102 | if (newLowerBounds.size() == op.getLowerBound().size()) |
3103 | return failure(); |
3104 | |
3105 | if (newLowerBounds.empty()) { |
3106 | // All of the loop dimensions perform a single iteration. Inline |
3107 | // loop body and nested ReduceOp's |
3108 | SmallVector<Value> results; |
3109 | results.reserve(op.getInitVals().size()); |
3110 | for (auto &bodyOp : op.getBody()->without_terminator()) |
3111 | rewriter.clone(bodyOp, mapping); |
3112 | auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator()); |
3113 | for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) { |
3114 | Block &reduceBlock = reduceOp.getReductions()[i].front(); |
3115 | auto initValIndex = results.size(); |
3116 | mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]); |
3117 | mapping.map(reduceBlock.getArgument(1), |
3118 | mapping.lookupOrDefault(reduceOp.getOperands()[i])); |
3119 | for (auto &reduceBodyOp : reduceBlock.without_terminator()) |
3120 | rewriter.clone(reduceBodyOp, mapping); |
3121 | |
3122 | auto result = mapping.lookupOrDefault( |
3123 | cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult()); |
3124 | results.push_back(result); |
3125 | } |
3126 | |
3127 | rewriter.replaceOp(op, results); |
3128 | return success(); |
3129 | } |
3130 | // Replace the parallel loop by lower-dimensional parallel loop. |
3131 | auto newOp = |
3132 | rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds, |
3133 | newSteps, op.getInitVals(), nullptr); |
3134 | // Erase the empty block that was inserted by the builder. |
3135 | rewriter.eraseBlock(newOp.getBody()); |
3136 | // Clone the loop body and remap the block arguments of the collapsed loops |
3137 | // (inlining does not support a cancellable block argument mapping). |
3138 | rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(), |
3139 | newOp.getRegion().begin(), mapping); |
3140 | rewriter.replaceOp(op, newOp.getResults()); |
3141 | return success(); |
3142 | } |
3143 | }; |
3144 | |
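/// Merge a parallel loop whose body contains only another parallel loop (and
/// the terminator) into one multi-dimensional parallel loop, provided the
/// inner bounds do not depend on the outer induction variables and neither
/// loop has reductions. An illustrative sketch:
///
///   scf.parallel (%i) = (%lb0) to (%ub0) step (%s0) {
///     scf.parallel (%j) = (%lb1) to (%ub1) step (%s1) {
///       ...
///     }
///   }
/// becomes
///   scf.parallel (%i, %j) = (%lb0, %lb1) to (%ub0, %ub1) step (%s0, %s1) {
///     ...
///   }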
3145 | struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> { |
3146 | using OpRewritePattern<ParallelOp>::OpRewritePattern; |
3147 | |
3148 | LogicalResult matchAndRewrite(ParallelOp op, |
3149 | PatternRewriter &rewriter) const override { |
3150 | Block &outerBody = *op.getBody(); |
3151 | if (!llvm::hasSingleElement(outerBody.without_terminator())) |
3152 | return failure(); |
3153 | |
3154 | auto innerOp = dyn_cast<ParallelOp>(outerBody.front()); |
3155 | if (!innerOp) |
3156 | return failure(); |
3157 | |
3158 | for (auto val : outerBody.getArguments()) |
3159 | if (llvm::is_contained(innerOp.getLowerBound(), val) || |
3160 | llvm::is_contained(innerOp.getUpperBound(), val) || |
3161 | llvm::is_contained(innerOp.getStep(), val)) |
3162 | return failure(); |
3163 | |
3164 | // Reductions are not supported yet. |
3165 | if (!op.getInitVals().empty() || !innerOp.getInitVals().empty()) |
3166 | return failure(); |
3167 | |
3168 | auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/, |
3169 | ValueRange iterVals, ValueRange) { |
3170 | Block &innerBody = *innerOp.getBody(); |
3171 | assert(iterVals.size() == |
3172 | (outerBody.getNumArguments() + innerBody.getNumArguments())); |
3173 | IRMapping mapping; |
3174 | mapping.map(outerBody.getArguments(), |
3175 | iterVals.take_front(outerBody.getNumArguments())); |
3176 | mapping.map(innerBody.getArguments(), |
3177 | iterVals.take_back(innerBody.getNumArguments())); |
3178 | for (Operation &op : innerBody.without_terminator()) |
3179 | builder.clone(op, mapping); |
3180 | }; |
3181 | |
3182 | auto concatValues = [](const auto &first, const auto &second) { |
3183 | SmallVector<Value> ret; |
3184 | ret.reserve(first.size() + second.size()); |
3185 | ret.assign(first.begin(), first.end()); |
3186 | ret.append(second.begin(), second.end()); |
3187 | return ret; |
3188 | }; |
3189 | |
3190 | auto newLowerBounds = |
3191 | concatValues(op.getLowerBound(), innerOp.getLowerBound()); |
3192 | auto newUpperBounds = |
3193 | concatValues(op.getUpperBound(), innerOp.getUpperBound()); |
3194 | auto newSteps = concatValues(op.getStep(), innerOp.getStep()); |
3195 | |
3196 | rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds, |
3197 | newSteps, std::nullopt, |
3198 | bodyBuilder); |
3199 | return success(); |
3200 | } |
3201 | }; |
3202 | |
3203 | } // namespace |
3204 | |
3205 | void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results, |
3206 | MLIRContext *context) { |
3207 | results |
3208 | .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>( |
3209 | context); |
3210 | } |
3211 | |
3212 | /// Given a branch point (the parent operation or one of its regions), |
3213 | /// return the successor regions, i.e. the regions that may be selected |
3214 | /// next during the flow of control. |
3217 | void ParallelOp::getSuccessorRegions( |
3218 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
3219 | // Both the operation itself and the region may be branching into the body or |
3220 | // back into the operation itself. It is possible for loop not to enter the |
3221 | // body. |
3222 | regions.push_back(RegionSuccessor(&getRegion())); |
3223 | regions.push_back(RegionSuccessor()); |
3224 | } |
3225 | |
3226 | //===----------------------------------------------------------------------===// |
3227 | // ReduceOp |
3228 | //===----------------------------------------------------------------------===// |
3229 | |
3230 | void ReduceOp::build(OpBuilder &builder, OperationState &result) {} |
3231 | |
3232 | void ReduceOp::build(OpBuilder &builder, OperationState &result, |
3233 | ValueRange operands) { |
3234 | result.addOperands(operands); |
3235 | for (Value v : operands) { |
3236 | OpBuilder::InsertionGuard guard(builder); |
3237 | Region *bodyRegion = result.addRegion(); |
3238 | builder.createBlock(bodyRegion, {}, |
3239 | ArrayRef<Type>{v.getType(), v.getType()}, |
3240 | {result.location, result.location}); |
3241 | } |
3242 | } |
3243 | |
3244 | LogicalResult ReduceOp::verifyRegions() { |
3245 | // The region of a ReduceOp has two arguments of the same type as its |
3246 | // corresponding operand. |
3247 | for (int64_t i = 0, e = getReductions().size(); i < e; ++i) { |
3248 | auto type = getOperands()[i].getType(); |
3249 | Block &block = getReductions()[i].front(); |
3250 | if (block.empty()) |
3251 | return emitOpError() << i << "-th reduction has an empty body"; |
3252 | if (block.getNumArguments() != 2 || |
3253 | llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) { |
3254 | return arg.getType() != type; |
3255 | })) |
3256 | return emitOpError() << "expected two block arguments with type " << type |
3257 | << " in the " << i << "-th reduction region"; |
3258 | |
3259 | // Check that the block is terminated by a ReduceReturnOp. |
3260 | if (!isa<ReduceReturnOp>(block.getTerminator())) |
3261 | return emitOpError("reduction bodies must be terminated with an " |
3262 | "'scf.reduce.return' op"); |
3263 | } |
3264 | |
3265 | return success(); |
3266 | } |
3267 | |
3268 | MutableOperandRange |
3269 | ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) { |
3270 | // No operands are forwarded to the next iteration. |
3271 | return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0); |
3272 | } |
3273 | |
3274 | //===----------------------------------------------------------------------===// |
3275 | // ReduceReturnOp |
3276 | //===----------------------------------------------------------------------===// |
3277 | |
3278 | LogicalResult ReduceReturnOp::verify() { |
3279 | // The type of the return value should be the same type as the types of the |
3280 | // block arguments of the reduction body. |
3281 | Block *reductionBody = getOperation()->getBlock(); |
3282 | // Should already be verified by an op trait. |
3283 | assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce"); |
3284 | Type expectedResultType = reductionBody->getArgument(0).getType(); |
3285 | if (expectedResultType != getResult().getType()) |
3286 | return emitOpError() << "must have type " << expectedResultType |
3287 | << " (the type of the reduction inputs)"; |
3288 | return success(); |
3289 | } |
3290 | |
3291 | //===----------------------------------------------------------------------===// |
3292 | // WhileOp |
3293 | //===----------------------------------------------------------------------===// |
3294 | |
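// A minimal C++ usage sketch of the builder below (illustrative only;
// `b`, `loc`, `resultTypes` and `inits` are provided by the caller):
//
//   auto whileOp = b.create<scf::WhileOp>(
//       loc, resultTypes, inits,
//       [&](OpBuilder &nested, Location l, ValueRange args) {
//         // "Before" region: compute the condition and create
//         // scf.condition with the values to forward.
//       },
//       [&](OpBuilder &nested, Location l, ValueRange args) {
//         // "After" region: loop body terminated by scf.yield.
//       });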
3295 | void WhileOp::build(::mlir::OpBuilder &odsBuilder, |
3296 | ::mlir::OperationState &odsState, TypeRange resultTypes, |
3297 | ValueRange inits, BodyBuilderFn beforeBuilder, |
3298 | BodyBuilderFn afterBuilder) { |
3299 | odsState.addOperands(inits); |
3300 | odsState.addTypes(resultTypes); |
3301 | |
3302 | OpBuilder::InsertionGuard guard(odsBuilder); |
3303 | |
3304 | // Build before region. |
3305 | SmallVector<Location, 4> beforeArgLocs; |
3306 | beforeArgLocs.reserve(inits.size()); |
3307 | for (Value operand : inits) { |
3308 | beforeArgLocs.push_back(operand.getLoc()); |
3309 | } |
3310 | |
3311 | Region *beforeRegion = odsState.addRegion(); |
3312 | Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{}, |
3313 | inits.getTypes(), beforeArgLocs); |
3314 | if (beforeBuilder) |
3315 | beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments()); |
3316 | |
3317 | // Build after region. |
3318 | SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location); |
3319 | |
3320 | Region *afterRegion = odsState.addRegion(); |
3321 | Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{}, |
3322 | resultTypes, afterArgLocs); |
3323 | |
3324 | if (afterBuilder) |
3325 | afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments()); |
3326 | } |
3327 | |
3328 | ConditionOp WhileOp::getConditionOp() { |
3329 | return cast<ConditionOp>(getBeforeBody()->getTerminator()); |
3330 | } |
3331 | |
3332 | YieldOp WhileOp::getYieldOp() { |
3333 | return cast<YieldOp>(getAfterBody()->getTerminator()); |
3334 | } |
3335 | |
3336 | std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() { |
3337 | return getYieldOp().getResultsMutable(); |
3338 | } |
3339 | |
3340 | Block::BlockArgListType WhileOp::getBeforeArguments() { |
3341 | return getBeforeBody()->getArguments(); |
3342 | } |
3343 | |
3344 | Block::BlockArgListType WhileOp::getAfterArguments() { |
3345 | return getAfterBody()->getArguments(); |
3346 | } |
3347 | |
3348 | Block::BlockArgListType WhileOp::getRegionIterArgs() { |
3349 | return getBeforeArguments(); |
3350 | } |
3351 | |
3352 | OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) { |
3353 | assert(point == getBefore() && |
3354 | "WhileOp is expected to branch only to the first region"); |
3355 | return getInits(); |
3356 | } |
3357 | |
3358 | void WhileOp::getSuccessorRegions(RegionBranchPoint point, |
3359 | SmallVectorImpl<RegionSuccessor> ®ions) { |
3360 | // The parent op always branches to the condition region. |
3361 | if (point.isParent()) { |
3362 | regions.emplace_back(&getBefore(), getBefore().getArguments()); |
3363 | return; |
3364 | } |
3365 | |
3366 | assert(llvm::is_contained({&getAfter(), &getBefore()}, point) && |
3367 | "there are only two regions in a WhileOp"); |
3368 | // The body region always branches back to the condition region. |
3369 | if (point == getAfter()) { |
3370 | regions.emplace_back(&getBefore(), getBefore().getArguments()); |
3371 | return; |
3372 | } |
3373 | |
3374 | regions.emplace_back(getResults()); |
3375 | regions.emplace_back(&getAfter(), getAfter().getArguments()); |
3376 | } |
3377 | |
3378 | SmallVector<Region *> WhileOp::getLoopRegions() { |
3379 | return {&getBefore(), &getAfter()}; |
3380 | } |
3381 | |
3382 | /// Parses a `while` op. |
3383 | /// |
3384 | /// op ::= `scf.while` assignments `:` function-type region `do` region |
3385 | /// `attributes` attribute-dict |
3386 | /// initializer ::= /* empty */ | `(` assignment-list `)` |
3387 | /// assignment-list ::= assignment | assignment `,` assignment-list |
3388 | /// assignment ::= ssa-value `=` ssa-value |
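///
/// For example (illustrative):
///
///   %res = scf.while (%arg0 = %init) : (i32) -> i32 {
///     %cond = "test.condition"(%arg0) : (i32) -> i1
///     scf.condition(%cond) %arg0 : i32
///   } do {
///   ^bb0(%arg1: i32):
///     %next = "test.payload"(%arg1) : (i32) -> i32
///     scf.yield %next : i32
///   }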
3389 | ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) { |
3390 | SmallVector<OpAsmParser::Argument, 4> regionArgs; |
3391 | SmallVector<OpAsmParser::UnresolvedOperand, 4> operands; |
3392 | Region *before = result.addRegion(); |
3393 | Region *after = result.addRegion(); |
3394 | |
3395 | OptionalParseResult listResult = |
3396 | parser.parseOptionalAssignmentList(regionArgs, operands); |
3397 | if (listResult.has_value() && failed(listResult.value())) |
3398 | return failure(); |
3399 | |
3400 | FunctionType functionType; |
3401 | SMLoc typeLoc = parser.getCurrentLocation(); |
3402 | if (failed(parser.parseColonType(functionType))) |
3403 | return failure(); |
3404 | |
3405 | result.addTypes(functionType.getResults()); |
3406 | |
3407 | if (functionType.getNumInputs() != operands.size()) { |
3408 | return parser.emitError(typeLoc) |
3409 | << "expected as many input types as operands " |
3410 | << "(expected " << operands.size() << " got " |
3411 | << functionType.getNumInputs() << ")"; |
3412 | } |
3413 | |
3414 | // Resolve input operands. |
3415 | if (failed(parser.resolveOperands(operands, functionType.getInputs(), |
3416 | parser.getCurrentLocation(), |
3417 | result.operands))) |
3418 | return failure(); |
3419 | |
3420 | // Propagate the types into the region arguments. |
3421 | for (size_t i = 0, e = regionArgs.size(); i != e; ++i) |
3422 | regionArgs[i].type = functionType.getInput(i); |
3423 | |
3424 | return failure(parser.parseRegion(*before, regionArgs) || |
3425 | parser.parseKeyword("do") || parser.parseRegion(*after) || |
3426 | parser.parseOptionalAttrDictWithKeyword(result.attributes)); |
3427 | } |
3428 | |
3429 | /// Prints a `while` op. |
3430 | void scf::WhileOp::print(OpAsmPrinter &p) { |
3431 | printInitializationList(p, getBeforeArguments(), getInits(), " "); |
3432 | p << " : "; |
3433 | p.printFunctionalType(getInits().getTypes(), getResults().getTypes()); |
3434 | p << ' '; |
3435 | p.printRegion(getBefore(), /*printEntryBlockArgs=*/false); |
3436 | p << " do "; |
3437 | p.printRegion(getAfter()); |
3438 | p.printOptionalAttrDictWithKeyword((*this)->getAttrs()); |
3439 | } |
3440 | |
3441 | /// Verifies that two ranges of types match, i.e. have the same number of |
3442 | /// entries and that types are pairwise equals. Reports errors on the given |
3443 | /// operation in case of mismatch. |
3444 | template <typename OpTy> |
3445 | static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, |
3446 | TypeRange right, StringRef message) { |
3447 | if (left.size() != right.size()) |
3448 | return op.emitOpError("expects the same number of ") << message; |
3449 | |
3450 | for (unsigned i = 0, e = left.size(); i < e; ++i) { |
3451 | if (left[i] != right[i]) { |
3452 | InFlightDiagnostic diag = op.emitOpError("expects the same types for ") |
3453 | << message; |
3454 | diag.attachNote() << "for argument " << i << ", found " << left[i] |
3455 | << " and " << right[i]; |
3456 | return diag; |
3457 | } |
3458 | } |
3459 | |
3460 | return success(); |
3461 | } |
3462 | |
3463 | LogicalResult scf::WhileOp::verify() { |
3464 | auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>( |
3465 | *this, getBefore(), |
3466 | "expects the 'before' region to terminate with 'scf.condition'"); |
3467 | if (!beforeTerminator) |
3468 | return failure(); |
3469 | |
3470 | auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>( |
3471 | *this, getAfter(), |
3472 | "expects the 'after' region to terminate with 'scf.yield'"); |
3473 | return success(afterTerminator != nullptr); |
3474 | } |
3475 | |
3476 | namespace { |
3477 | /// Replace uses of the condition within the do block with true, since otherwise |
3478 | /// the block would not be evaluated. |
3479 | /// |
3480 | /// scf.while (..) : (i1, ...) -> ... { |
3481 | /// %condition = call @evaluate_condition() : () -> i1 |
3482 | /// scf.condition(%condition) %condition : i1, ... |
3483 | /// } do { |
3484 | /// ^bb0(%arg0: i1, ...): |
3485 | /// use(%arg0) |
3486 | /// ... |
3487 | /// |
3488 | /// becomes |
3489 | /// scf.while (..) : (i1, ...) -> ... { |
3490 | /// %condition = call @evaluate_condition() : () -> i1 |
3491 | /// scf.condition(%condition) %condition : i1, ... |
3492 | /// } do { |
3493 | /// ^bb0(%arg0: i1, ...): |
3494 | /// use(%true) |
3495 | /// ... |
3496 | struct WhileConditionTruth : public OpRewritePattern<WhileOp> { |
3497 | using OpRewritePattern<WhileOp>::OpRewritePattern; |
3498 | |
3499 | LogicalResult matchAndRewrite(WhileOp op, |
3500 | PatternRewriter &rewriter) const override { |
3501 | auto term = op.getConditionOp(); |
3502 | |
3503 | // This variable caches the constant true value so that we do not |
3504 | // create duplicate constants. |
3505 | Value constantTrue = nullptr; |
3506 | |
3507 | bool replaced = false; |
3508 | for (auto yieldedAndBlockArgs : |
3509 | llvm::zip(term.getArgs(), op.getAfterArguments())) { |
3510 | if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) { |
3511 | if (!std::get<1>(yieldedAndBlockArgs).use_empty()) { |
3512 | if (!constantTrue) |
3513 | constantTrue = rewriter.create<arith::ConstantOp>( |
3514 | op.getLoc(), term.getCondition().getType(), |
3515 | rewriter.getBoolAttr(true)); |
3516 | |
3517 | rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs), |
3518 | constantTrue); |
3519 | replaced = true; |
3520 | } |
3521 | } |
3522 | } |
3523 | return success(replaced); |
3524 | } |
3525 | }; |
3526 | |
3527 | /// Remove loop invariant arguments from `before` block of scf.while. |
3528 | /// A before block argument is considered loop invariant if :- |
3529 | /// 1. i-th yield operand is equal to the i-th while operand. |
3530 | /// 2. i-th yield operand is k-th after block argument which is (k+1)-th |
3531 | /// condition operand AND this (k+1)-th condition operand is equal to i-th |
3532 | /// iter argument/while operand. |
3533 | /// For the arguments which are removed, their uses inside scf.while |
3534 | /// are replaced with their corresponding initial value. |
3535 | /// |
3536 | /// Eg: |
3537 | /// INPUT :- |
3538 | /// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b, |
3539 | /// ..., %argN_before = %N) |
3540 | /// { |
3541 | /// ... |
3542 | /// scf.condition(%cond) %arg1_before, %arg0_before, |
3543 | /// %arg2_before, %arg0_before, ... |
3544 | /// } do { |
3545 | /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2, |
3546 | /// ..., %argK_after): |
3547 | /// ... |
3548 | /// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN |
3549 | /// } |
3550 | /// |
3551 | /// OUTPUT :- |
3552 | /// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before = |
3553 | /// %N) |
3554 | /// { |
3555 | /// ... |
3556 | /// scf.condition(%cond) %b, %a, %arg2_before, %a, ... |
3557 | /// } do { |
3558 | /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2, |
3559 | /// ..., %argK_after): |
3560 | /// ... |
3561 | /// scf.yield %arg1_after, ..., %argN |
3562 | /// } |
3563 | /// |
3564 | /// EXPLANATION: |
3565 | /// We iterate over each yield operand. |
3566 | /// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand |
3567 | /// %arg0_before, which in turn is the 0-th iter argument. So we |
3568 | /// remove 0-th before block argument and yield operand, and replace |
3569 | /// all uses of the 0-th before block argument with its initial value |
3570 | /// %a. |
3571 | /// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial |
3572 | /// value. So we remove this operand and the corresponding before |
3573 | /// block argument and replace all uses of 1-th before block argument |
3574 | /// with %b. |
3575 | struct RemoveLoopInvariantArgsFromBeforeBlock |
3576 | : public OpRewritePattern<WhileOp> { |
3577 | using OpRewritePattern<WhileOp>::OpRewritePattern; |
3578 | |
3579 | LogicalResult matchAndRewrite(WhileOp op, |
3580 | PatternRewriter &rewriter) const override { |
3581 | Block &afterBlock = *op.getAfterBody(); |
3582 | Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments(); |
3583 | ConditionOp condOp = op.getConditionOp(); |
3584 | OperandRange condOpArgs = condOp.getArgs(); |
3585 | Operation *yieldOp = afterBlock.getTerminator(); |
3586 | ValueRange yieldOpArgs = yieldOp->getOperands(); |
3587 | |
3588 | bool canSimplify = false; |
3589 | for (const auto &it : |
3590 | llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) { |
3591 | auto index = static_cast<unsigned>(it.index()); |
3592 | auto [initVal, yieldOpArg] = it.value(); |
3593 | // If i-th yield operand is equal to the i-th operand of the scf.while, |
3594 | // the i-th before block argument is a loop invariant. |
3595 | if (yieldOpArg == initVal) { |
3596 | canSimplify = true; |
3597 | break; |
3598 | } |
3599 | // If the i-th yield operand is k-th after block argument, then we check |
3600 | // if the (k+1)-th condition op operand is equal to either the i-th before |
3601 | // block argument or the initial value of the i-th before block argument. |
3602 | // If either comparison holds, the i-th before block argument is a loop |
3603 | // invariant. |
3604 | auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg); |
3605 | if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) { |
3606 | Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()]; |
3607 | if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) { |
3608 | canSimplify = true; |
3609 | break; |
3610 | } |
3611 | } |
3612 | } |
3613 | |
3614 | if (!canSimplify) |
3615 | return failure(); |
3616 | |
3617 | SmallVector<Value> newInitArgs, newYieldOpArgs; |
3618 | DenseMap<unsigned, Value> beforeBlockInitValMap; |
3619 | SmallVector<Location> newBeforeBlockArgLocs; |
3620 | for (const auto &it : |
3621 | llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) { |
3622 | auto index = static_cast<unsigned>(it.index()); |
3623 | auto [initVal, yieldOpArg] = it.value(); |
3624 | |
3625 | // If i-th yield operand is equal to the i-th operand of the scf.while, |
3626 | // the i-th before block argument is a loop invariant. |
3627 | if (yieldOpArg == initVal) { |
3628 | beforeBlockInitValMap.insert({index, initVal}); |
3629 | continue; |
3630 | } else { |
3631 | // If the i-th yield operand is k-th after block argument, then we check |
3632 | // if the (k+1)-th condition op operand is equal to either the i-th |
3633 | // before block argument or the initial value of the i-th before block |
3634 | // argument. If either comparison holds, the i-th before block argument |
3635 | // is a loop invariant. |
3636 | auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg); |
3637 | if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) { |
3638 | Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()]; |
3639 | if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) { |
3640 | beforeBlockInitValMap.insert({index, initVal}); |
3641 | continue; |
3642 | } |
3643 | } |
3644 | } |
3645 | newInitArgs.emplace_back(initVal); |
3646 | newYieldOpArgs.emplace_back(yieldOpArg); |
3647 | newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc()); |
3648 | } |
3649 | |
3650 | { |
3651 | OpBuilder::InsertionGuard g(rewriter); |
3652 | rewriter.setInsertionPoint(yieldOp); |
3653 | rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs); |
3654 | } |
3655 | |
3656 | auto newWhile = |
3657 | rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs); |
3658 | |
3659 | Block &newBeforeBlock = *rewriter.createBlock( |
3660 | &newWhile.getBefore(), /*insertPt*/ {}, |
3661 | ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs); |
3662 | |
3663 | Block &beforeBlock = *op.getBeforeBody(); |
3664 | SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments()); |
3665 | // For each i-th before block argument we find its replacement value as: |
3666 | // 1. If the i-th before block argument is a loop invariant, we fetch its |
3667 | // initial value from `beforeBlockInitValMap` by querying for key `i`. |
3668 | // 2. Else we fetch j-th new before block argument as the replacement |
3669 | // value of i-th before block argument. |
3670 | for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) { |
3671 | // If the index 'i' argument was a loop invariant, we fetch its initial |
3672 | // value from `beforeBlockInitValMap`. |
3673 | if (beforeBlockInitValMap.count(i) != 0) |
3674 | newBeforeBlockArgs[i] = beforeBlockInitValMap[i]; |
3675 | else |
3676 | newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++); |
3677 | } |
3678 | |
3679 | rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs); |
3680 | rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(), |
3681 | newWhile.getAfter().begin()); |
3682 | |
3683 | rewriter.replaceOp(op, newWhile.getResults()); |
3684 | return success(); |
3685 | } |
3686 | }; |
3687 | |
3688 | /// Remove loop invariant value from result (condition op) of scf.while. |
3689 | /// A value is considered loop invariant if the final value yielded by |
3690 | /// scf.condition is defined outside of the `before` block. We remove the |
3691 | /// corresponding argument in `after` block and replace the use with the value. |
3692 | /// We also replace the use of the corresponding result of scf.while with the |
3693 | /// value. |
3694 | /// |
3695 | /// Eg: |
3696 | /// INPUT :- |
3697 | /// %res_input:K = scf.while <...> iter_args(%arg0_before = , ..., |
3698 | /// %argN_before = %N) { |
3699 | /// ... |
3700 | /// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ... |
3701 | /// } do { |
3702 | /// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after): |
3703 | /// ... |
3704 | /// some_func(%arg1_after) |
3705 | /// ... |
3706 | /// scf.yield %arg0_after, %arg2_after, ..., %argN_after |
3707 | /// } |
3708 | /// |
3709 | /// OUTPUT :- |
3710 | /// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) { |
3711 | /// ... |
3712 | /// scf.condition(%cond) %arg0, %arg1, ..., %argM |
3713 | /// } do { |
3714 | /// ^bb0(%arg0, %arg3, ..., %argM): |
3715 | /// ... |
3716 | /// some_func(%a) |
3717 | /// ... |
3718 | /// scf.yield %arg0, %b, ..., %argN |
3719 | /// } |
3720 | /// |
3721 | /// EXPLANATION: |
3722 | /// 1. The 1-th and 2-th operand of scf.condition are defined outside the |
3723 | /// before block of scf.while, so they get removed. |
3724 | /// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are |
3725 | /// replaced by %b. |
3726 | /// 3. The corresponding after block argument %arg1_after's uses are |
3727 | /// replaced by %a and %arg2_after's uses are replaced by %b. |
3728 | struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> { |
3729 | using OpRewritePattern<WhileOp>::OpRewritePattern; |
3730 | |
3731 | LogicalResult matchAndRewrite(WhileOp op, |
3732 | PatternRewriter &rewriter) const override { |
3733 | Block &beforeBlock = *op.getBeforeBody(); |
3734 | ConditionOp condOp = op.getConditionOp(); |
3735 | OperandRange condOpArgs = condOp.getArgs(); |
3736 | |
3737 | bool canSimplify = false; |
3738 | for (Value condOpArg : condOpArgs) { |
3739 | // Values not defined within the `before` block are considered loop |
3740 | // invariant; if any condition operand is such a value, the op can be |
3741 | // simplified. |
3742 | if (condOpArg.getParentBlock() != &beforeBlock) { |
3743 | canSimplify = true; |
3744 | break; |
3745 | } |
3746 | } |
3747 | |
3748 | if (!canSimplify) |
3749 | return failure(); |
3750 | |
3751 | Block::BlockArgListType afterBlockArgs = op.getAfterArguments(); |
3752 | |
3753 | SmallVector<Value> newCondOpArgs; |
3754 | SmallVector<Type> newAfterBlockType; |
3755 | DenseMap<unsigned, Value> condOpInitValMap; |
3756 | SmallVector<Location> newAfterBlockArgLocs; |
3757 | for (const auto &it : llvm::enumerate(condOpArgs)) { |
3758 | auto index = static_cast<unsigned>(it.index()); |
3759 | Value condOpArg = it.value(); |
3760 | // Values not defined within the `before` block are considered loop |
3761 | // invariant; map the corresponding index to the invariant value. |
3763 | if (condOpArg.getParentBlock() != &beforeBlock) { |
3764 | condOpInitValMap.insert({index, condOpArg}); |
3765 | } else { |
3766 | newCondOpArgs.emplace_back(condOpArg); |
3767 | newAfterBlockType.emplace_back(condOpArg.getType()); |
3768 | newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc()); |
3769 | } |
3770 | } |
3771 | |
3772 | { |
3773 | OpBuilder::InsertionGuard g(rewriter); |
3774 | rewriter.setInsertionPoint(condOp); |
3775 | rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(), |
3776 | newCondOpArgs); |
3777 | } |
3778 | |
3779 | auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType, |
3780 | op.getOperands()); |
3781 | |
3782 | Block &newAfterBlock = |
3783 | *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {}, |
3784 | newAfterBlockType, newAfterBlockArgLocs); |
3785 | |
3786 | Block &afterBlock = *op.getAfterBody(); |
3787 | // Since a new scf.condition op was created, we need to fetch the new |
3788 | // `after` block arguments which will be used while replacing operations of |
3789 | // previous scf.while's `after` blocks. We'd also be fetching new result |
3790 | // values too. |
3791 | SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments()); |
3792 | SmallVector<Value> newWhileResults(afterBlock.getNumArguments()); |
3793 | for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) { |
3794 | Value afterBlockArg, result; |
3795 | // If the index 'i' argument was loop invariant, we fetch its value from |
3796 | // the `condOpInitValMap` map. |
3797 | if (condOpInitValMap.count(i) != 0) { |
3798 | afterBlockArg = condOpInitValMap[i]; |
3799 | result = afterBlockArg; |
3800 | } else { |
3801 | afterBlockArg = newAfterBlock.getArgument(j); |
3802 | result = newWhile.getResult(j); |
3803 | j++; |
3804 | } |
3805 | newAfterBlockArgs[i] = afterBlockArg; |
3806 | newWhileResults[i] = result; |
3807 | } |
3808 | |
3809 | rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs); |
3810 | rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(), |
3811 | newWhile.getBefore().begin()); |
3812 | |
3813 | rewriter.replaceOp(op, newWhileResults); |
3814 | return success(); |
3815 | } |
3816 | }; |
3817 | |
3818 | /// Remove WhileOp results that are also unused in 'after' block. |
3819 | /// |
3820 | /// %0:2 = scf.while () : () -> (i32, i64) { |
3821 | /// %condition = "test.condition"() : () -> i1 |
3822 | /// %v1 = "test.get_some_value"() : () -> i32 |
3823 | /// %v2 = "test.get_some_value"() : () -> i64 |
3824 | /// scf.condition(%condition) %v1, %v2 : i32, i64 |
3825 | /// } do { |
3826 | /// ^bb0(%arg0: i32, %arg1: i64): |
3827 | /// "test.use"(%arg0) : (i32) -> () |
3828 | /// scf.yield |
3829 | /// } |
3830 | /// return %0#0 : i32 |
3831 | /// |
3832 | /// becomes |
3833 | /// %0 = scf.while () : () -> (i32) { |
3834 | /// %condition = "test.condition"() : () -> i1 |
3835 | /// %v1 = "test.get_some_value"() : () -> i32 |
3836 | /// %v2 = "test.get_some_value"() : () -> i64 |
3837 | /// scf.condition(%condition) %v1 : i32 |
3838 | /// } do { |
3839 | /// ^bb0(%arg0: i32): |
3840 | /// "test.use"(%arg0) : (i32) -> () |
3841 | /// scf.yield |
3842 | /// } |
3843 | /// return %0 : i32 |
3844 | struct WhileUnusedResult : public OpRewritePattern<WhileOp> { |
3845 | using OpRewritePattern<WhileOp>::OpRewritePattern; |
3846 | |
3847 | LogicalResult matchAndRewrite(WhileOp op, |
3848 | PatternRewriter &rewriter) const override { |
3849 | auto term = op.getConditionOp(); |
3850 | auto afterArgs = op.getAfterArguments(); |
3851 | auto termArgs = term.getArgs(); |
3852 | |
3853 | // Collect results mapping, new terminator args and new result types. |
3854 | SmallVector<unsigned> newResultsIndices; |
3855 | SmallVector<Type> newResultTypes; |
3856 | SmallVector<Value> newTermArgs; |
3857 | SmallVector<Location> newArgLocs; |
3858 | bool needUpdate = false; |
3859 | for (const auto &it : |
3860 | llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) { |
3861 | auto i = static_cast<unsigned>(it.index()); |
3862 | Value result = std::get<0>(it.value()); |
3863 | Value afterArg = std::get<1>(it.value()); |
3864 | Value termArg = std::get<2>(it.value()); |
3865 | if (result.use_empty() && afterArg.use_empty()) { |
3866 | needUpdate = true; |
3867 | } else { |
3868 | newResultsIndices.emplace_back(i); |
3869 | newTermArgs.emplace_back(termArg); |
3870 | newResultTypes.emplace_back(result.getType()); |
3871 | newArgLocs.emplace_back(result.getLoc()); |
3872 | } |
3873 | } |
3874 | |
3875 | if (!needUpdate) |
3876 | return failure(); |
3877 | |
3878 | { |
3879 | OpBuilder::InsertionGuard g(rewriter); |
3880 | rewriter.setInsertionPoint(term); |
3881 | rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(), |
3882 | newTermArgs); |
3883 | } |
3884 | |
3885 | auto newWhile = |
3886 | rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits()); |
3887 | |
3888 | Block &newAfterBlock = *rewriter.createBlock( |
3889 | &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs); |
3890 | |
3891 | // Build new results list and new after block args (unused entries will be |
3892 | // null). |
3893 | SmallVector<Value> newResults(op.getNumResults()); |
3894 | SmallVector<Value> newAfterBlockArgs(op.getNumResults()); |
3895 | for (const auto &it : llvm::enumerate(newResultsIndices)) { |
3896 | newResults[it.value()] = newWhile.getResult(it.index()); |
3897 | newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index()); |
3898 | } |
3899 | |
3900 | rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(), |
3901 | newWhile.getBefore().begin()); |
3902 | |
3903 | Block &afterBlock = *op.getAfterBody(); |
3904 | rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs); |
3905 | |
3906 | rewriter.replaceOp(op, newResults); |
3907 | return success(); |
3908 | } |
3909 | }; |
3910 | |
3911 | /// Replace operations equivalent to the condition in the do block with true, |
3912 | /// since otherwise the block would not be evaluated. |
3913 | /// |
3914 | /// scf.while (..) : (i32, ...) -> ... { |
3915 | /// %z = ... : i32 |
3916 | /// %condition = cmpi pred %z, %a |
3917 | /// scf.condition(%condition) %z : i32, ... |
3918 | /// } do { |
3919 | /// ^bb0(%arg0: i32, ...): |
3920 | /// %condition2 = cmpi pred %arg0, %a |
3921 | /// use(%condition2) |
3922 | /// ... |
3923 | /// |
3924 | /// becomes |
3925 | /// scf.while (..) : (i32, ...) -> ... { |
3926 | /// %z = ... : i32 |
3927 | /// %condition = cmpi pred %z, %a |
3928 | /// scf.condition(%condition) %z : i32, ... |
3929 | /// } do { |
3930 | /// ^bb0(%arg0: i32, ...): |
3931 | /// use(%true) |
3932 | /// ... |
3933 | struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> { |
3934 | using OpRewritePattern<scf::WhileOp>::OpRewritePattern; |
3935 | |
3936 | LogicalResult matchAndRewrite(scf::WhileOp op, |
3937 | PatternRewriter &rewriter) const override { |
3938 | using namespace scf; |
3939 | auto cond = op.getConditionOp(); |
3940 | auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>(); |
3941 | if (!cmp) |
3942 | return failure(); |
3943 | bool changed = false; |
3944 | for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) { |
3945 | for (size_t opIdx = 0; opIdx < 2; opIdx++) { |
3946 | if (std::get<0>(tup) != cmp.getOperand(opIdx)) |
3947 | continue; |
3948 | for (OpOperand &u : |
3949 | llvm::make_early_inc_range(std::get<1>(tup).getUses())) { |
3950 | auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner()); |
3951 | if (!cmp2) |
3952 | continue; |
3953 | // For a binary operator 1-opIdx gets the other side. |
3954 | if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx)) |
3955 | continue; |
3956 | bool samePredicate; |
3957 | if (cmp2.getPredicate() == cmp.getPredicate()) |
3958 | samePredicate = true; |
3959 | else if (cmp2.getPredicate() == |
3960 | arith::invertPredicate(cmp.getPredicate())) |
3961 | samePredicate = false; |
3962 | else |
3963 | continue; |
3964 | |
3965 | rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate, |
3966 | 1); |
3967 | changed = true; |
3968 | } |
3969 | } |
3970 | } |
3971 | return success(changed); |
3972 | } |
3973 | }; |
3974 | |
3975 | /// Remove unused init/yield args. |
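///
/// An illustrative sketch: a `before` block argument with no uses is removed
/// together with its init operand and its yield operand.
///
///   scf.while (%arg0 = %a, %unused = %b) : (i32, i64) -> ... {
///     ... // %unused is never read
///   } do {
///     ...
///   }
/// becomes
///   scf.while (%arg0 = %a) : (i32) -> ... { ... } do { ... }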
3976 | struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> { |
3977 | using OpRewritePattern<WhileOp>::OpRewritePattern; |
3978 | |
3979 | LogicalResult matchAndRewrite(WhileOp op, |
3980 | PatternRewriter &rewriter) const override { |
3981 | |
3982 | if (!llvm::any_of(op.getBeforeArguments(), |
3983 | [](Value arg) { return arg.use_empty(); })) |
3984 | return rewriter.notifyMatchFailure(op, "No args to remove"); |
3985 | |
3986 | YieldOp yield = op.getYieldOp(); |
3987 | |
3988 | // Collect results mapping, new terminator args and new result types. |
3989 | SmallVector<Value> newYields; |
3990 | SmallVector<Value> newInits; |
3991 | llvm::BitVector argsToErase; |
3992 | |
3993 | size_t argsCount = op.getBeforeArguments().size(); |
3994 | newYields.reserve(argsCount); |
3995 | newInits.reserve(argsCount); |
3996 | argsToErase.reserve(argsCount); |
3997 | for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip( |
3998 | op.getBeforeArguments(), yield.getOperands(), op.getInits())) { |
3999 | if (beforeArg.use_empty()) { |
4000 | argsToErase.push_back(true); |
4001 | } else { |
4002 | argsToErase.push_back(false); |
4003 | newYields.emplace_back(yieldValue); |
4004 | newInits.emplace_back(initValue); |
4005 | } |
4006 | } |
4007 | |
4008 | Block &beforeBlock = *op.getBeforeBody(); |
4009 | Block &afterBlock = *op.getAfterBody(); |
4010 | |
4011 | beforeBlock.eraseArguments(argsToErase); |
4012 | |
4013 | Location loc = op.getLoc(); |
4014 | auto newWhileOp = |
4015 | rewriter.create<WhileOp>(loc, op.getResultTypes(), newInits, |
4016 | /*beforeBody*/ nullptr, /*afterBody*/ nullptr); |
4017 | Block &newBeforeBlock = *newWhileOp.getBeforeBody(); |
4018 | Block &newAfterBlock = *newWhileOp.getAfterBody(); |
4019 | |
4020 | OpBuilder::InsertionGuard g(rewriter); |
4021 | rewriter.setInsertionPoint(yield); |
4022 | rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields); |
4023 | |
4024 | rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, |
4025 | newBeforeBlock.getArguments()); |
4026 | rewriter.mergeBlocks(&afterBlock, &newAfterBlock, |
4027 | newAfterBlock.getArguments()); |
4028 | |
4029 | rewriter.replaceOp(op, newWhileOp.getResults()); |
4030 | return success(); |
4031 | } |
4032 | }; |
4033 | |
4034 | /// Remove duplicated ConditionOp args. |
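///
/// For illustration, a hedged sketch (not taken from the original source): if
/// the `before` block forwards the same value twice,
///
///   scf.condition(%cond) %v, %v : i32, i32
///
/// only one copy is kept, and the `after` block arguments and loop results
/// that referred to the duplicate are remapped to the remaining one.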
4035 | struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> { |
4036 | using OpRewritePattern::OpRewritePattern; |
4037 | |
4038 | LogicalResult matchAndRewrite(WhileOp op, |
4039 | PatternRewriter &rewriter) const override { |
4040 | ConditionOp condOp = op.getConditionOp(); |
4041 | ValueRange condOpArgs = condOp.getArgs(); |
4042 | |
4043 | llvm::SmallPtrSet<Value, 8> argsSet(llvm::from_range, condOpArgs); |
4044 | |
4045 | if (argsSet.size() == condOpArgs.size()) |
4046 | return rewriter.notifyMatchFailure(op, "No results to remove"); |
4047 | |
4048 | llvm::SmallDenseMap<Value, unsigned> argsMap; |
4049 | SmallVector<Value> newArgs; |
4050 | argsMap.reserve(condOpArgs.size()); |
4051 | newArgs.reserve(condOpArgs.size()); |
4052 | for (Value arg : condOpArgs) { |
4053 | if (!argsMap.count(arg)) { |
4054 | auto pos = static_cast<unsigned>(argsMap.size()); |
4055 | argsMap.insert({arg, pos}); |
4056 | newArgs.emplace_back(arg); |
4057 | } |
4058 | } |
4059 | |
4060 | ValueRange argsRange(newArgs); |
4061 | |
4062 | Location loc = op.getLoc(); |
4063 | auto newWhileOp = rewriter.create<scf::WhileOp>( |
4064 | loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr, |
4065 | /*afterBody*/ nullptr); |
4066 | Block &newBeforeBlock = *newWhileOp.getBeforeBody(); |
4067 | Block &newAfterBlock = *newWhileOp.getAfterBody(); |
4068 | |
4069 | SmallVector<Value> afterArgsMapping; |
4070 | SmallVector<Value> resultsMapping; |
4071 | for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) { |
4072 | auto it = argsMap.find(arg); |
4073 | assert(it != argsMap.end()); |
4074 | auto pos = it->second; |
4075 | afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos)); |
4076 | resultsMapping.emplace_back(newWhileOp->getResult(pos)); |
4077 | } |
4078 | |
4079 | OpBuilder::InsertionGuard g(rewriter); |
4080 | rewriter.setInsertionPoint(condOp); |
4081 | rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(), |
4082 | argsRange); |
4083 | |
4084 | Block &beforeBlock = *op.getBeforeBody(); |
4085 | Block &afterBlock = *op.getAfterBody(); |
4086 | |
4087 | rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, |
4088 | newBeforeBlock.getArguments()); |
4089 | rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping); |
4090 | rewriter.replaceOp(op, resultsMapping); |
4091 | return success(); |
4092 | } |
4093 | }; |
4094 | |
4095 | /// If both ranges contain the same values, return mapping indices from args2 |
4096 | /// to args1. Otherwise return std::nullopt. |
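///
/// For example (an illustrative sketch, not from the original source): for
/// args1 = [%a, %b, %c] and args2 = [%b, %c, %a] the result is [1, 2, 0],
/// i.e. args2[k] == args1[ret[k]] for every index k.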
4097 | static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1, |
4098 | ValueRange args2) { |
4099 | if (args1.size() != args2.size()) |
4100 | return std::nullopt; |
4101 | |
4102 | SmallVector<unsigned> ret(args1.size()); |
4103 | for (auto &&[i, arg1] : llvm::enumerate(args1)) { |
4104 | auto it = llvm::find(args2, arg1); |
4105 | if (it == args2.end()) |
4106 | return std::nullopt; |
4107 | |
4108 | ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i); |
4109 | } |
4110 | |
4111 | return ret; |
4112 | } |
4113 | |
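/// Return true if `args` contains the same value more than once.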
4114 | static bool hasDuplicates(ValueRange args) { |
4115 | llvm::SmallDenseSet<Value> set; |
4116 | for (Value arg : args) { |
4117 | if (!set.insert(arg).second) |
4118 | return true; |
4119 | } |
4120 | return false; |
4121 | } |
4122 | |
4123 | /// If the `before` block args are directly forwarded to `scf.condition`, |
4124 | /// rearrange the `scf.condition` args into the same order as the block args. |
4125 | /// Update the `after` block args and op result values accordingly. |
4126 | /// Needed to simplify `scf.while` -> `scf.for` uplifting. |
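///
/// For illustration, a hedged sketch (not taken from the original source):
///
///   scf.while (%arg0 = %x, %arg1 = %y) : (i32, f32) -> (f32, i32) {
///     %cond = ...
///     scf.condition(%cond) %arg1, %arg0 : f32, i32
///   } do { ... }
///
/// is rewritten so that `scf.condition` forwards %arg0, %arg1 in block-arg
/// order, with the `after` block arguments and the op results permuted to
/// match.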
4127 | struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> { |
4128 | using OpRewritePattern::OpRewritePattern; |
4129 | |
4130 | LogicalResult matchAndRewrite(WhileOp loop, |
4131 | PatternRewriter &rewriter) const override { |
4132 | auto oldBefore = loop.getBeforeBody(); |
4133 | ConditionOp oldTerm = loop.getConditionOp(); |
4134 | ValueRange beforeArgs = oldBefore->getArguments(); |
4135 | ValueRange termArgs = oldTerm.getArgs(); |
4136 | if (beforeArgs == termArgs) |
4137 | return failure(); |
4138 | |
4139 | if (hasDuplicates(termArgs)) |
4140 | return failure(); |
4141 | |
4142 | auto mapping = getArgsMapping(beforeArgs, termArgs); |
4143 | if (!mapping) |
4144 | return failure(); |
4145 | |
4146 | { |
4147 | OpBuilder::InsertionGuard g(rewriter); |
4148 | rewriter.setInsertionPoint(oldTerm); |
4149 | rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(), |
4150 | beforeArgs); |
4151 | } |
4152 | |
4153 | auto oldAfter = loop.getAfterBody(); |
4154 | |
4155 | SmallVector<Type> newResultTypes(beforeArgs.size()); |
4156 | for (auto &&[i, j] : llvm::enumerate(*mapping)) |
4157 | newResultTypes[j] = loop.getResult(i).getType(); |
4158 | |
4159 | auto newLoop = rewriter.create<WhileOp>( |
4160 | loop.getLoc(), newResultTypes, loop.getInits(), |
4161 | /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr); |
4162 | auto newBefore = newLoop.getBeforeBody(); |
4163 | auto newAfter = newLoop.getAfterBody(); |
4164 | |
4165 | SmallVector<Value> newResults(beforeArgs.size()); |
4166 | SmallVector<Value> newAfterArgs(beforeArgs.size()); |
4167 | for (auto &&[i, j] : llvm::enumerate(*mapping)) { |
4168 | newResults[i] = newLoop.getResult(j); |
4169 | newAfterArgs[i] = newAfter->getArgument(j); |
4170 | } |
4171 | |
4172 | rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(), |
4173 | newBefore->getArguments()); |
4174 | rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(), |
4175 | newAfterArgs); |
4176 | |
4177 | rewriter.replaceOp(loop, newResults); |
4178 | return success(); |
4179 | } |
4180 | }; |
4181 | } // namespace |
4182 | |
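/// Populate canonicalization patterns for scf.while. These are applied by the
/// greedy pattern driver, e.g. as part of the -canonicalize pass.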
4183 | void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results, |
4184 | MLIRContext *context) { |
4185 | results.add<RemoveLoopInvariantArgsFromBeforeBlock, |
4186 | RemoveLoopInvariantValueYielded, WhileConditionTruth, |
4187 | WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults, |
4188 | WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context); |
4189 | } |
4190 | |
4191 | //===----------------------------------------------------------------------===// |
4192 | // IndexSwitchOp |
4193 | //===----------------------------------------------------------------------===// |
4194 | |
4195 | /// Parse the case regions and values. |
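/// For illustration (a hedged sketch of the textual form handled here), each
/// case consists of the `case` keyword, an integer value, and a region:
///
///   case 2 {
///     ...
///     scf.yield %v : i32
///   }
///
/// The collected case values are returned as a single DenseI64ArrayAttr.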
4196 | static ParseResult |
4197 | parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, |
4198 | SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) { |
4199 | SmallVector<int64_t> caseValues; |
4200 | while (succeeded(p.parseOptionalKeyword("case"))) { |
4201 | int64_t value; |
4202 | Region &region = *caseRegions.emplace_back(std::make_unique<Region>()); |
4203 | if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{})) |
4204 | return failure(); |
4205 | caseValues.push_back(value); |
4206 | } |
4207 | cases = p.getBuilder().getDenseI64ArrayAttr(caseValues); |
4208 | return success(); |
4209 | } |
4210 | |
4211 | /// Print the case regions and values. |
4212 | static void printSwitchCases(OpAsmPrinter &p, Operation *op, |
4213 | DenseI64ArrayAttr cases, RegionRange caseRegions) { |
4214 | for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) { |
4215 | p.printNewline(); |
4216 | p << "case " << value << ' '; |
4217 | p.printRegion(*region, /*printEntryBlockArgs=*/false); |
4218 | } |
4219 | } |
4220 | |
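/// Verify that the number of case values matches the number of case regions,
/// that the case values are unique, and that every region terminates in an
/// scf.yield whose operand types match the result types of the op.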
4221 | LogicalResult scf::IndexSwitchOp::verify() { |
4222 | if (getCases().size() != getCaseRegions().size()) { |
4223 | return emitOpError("has ") |
4224 | << getCaseRegions().size() << " case regions but " |
4225 | << getCases().size() << " case values"; |
4226 | } |
4227 | |
4228 | DenseSet<int64_t> valueSet; |
4229 | for (int64_t value : getCases()) |
4230 | if (!valueSet.insert(value).second) |
4231 | return emitOpError("has duplicate case value: ") << value; |
4232 | auto verifyRegion = [&](Region ®ion, const Twine &name) -> LogicalResult { |
4233 | auto yield = dyn_cast<YieldOp>(region.front().back()); |
4234 | if (!yield) |
4235 | return emitOpError("expected region to end with scf.yield, but got ") |
4236 | << region.front().back().getName(); |
4237 | |
4238 | if (yield.getNumOperands() != getNumResults()) { |
4239 | return (emitOpError("expected each region to return ") |
4240 | << getNumResults() << " values, but " << name << " returns " |
4241 | << yield.getNumOperands()) |
4242 | .attachNote(yield.getLoc()) |
4243 | << "see yield operation here"; |
4244 | } |
4245 | for (auto [idx, result, operand] : |
4246 | llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(), |
4247 | yield.getOperandTypes())) { |
4248 | if (result == operand) |
4249 | continue; |
4250 | return (emitOpError("expected result #") |
4251 | << idx << " of each region to be " << result) |
4252 | .attachNote(yield.getLoc()) |
4253 | << name << " returns " << operand << " here"; |
4254 | } |
4255 | return success(); |
4256 | }; |
4257 | |
4258 | if (failed(verifyRegion(getDefaultRegion(), "default region"))) |
4259 | return failure(); |
4260 | for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions())) |
4261 | if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx)))) |
4262 | return failure(); |
4263 | |
4264 | return success(); |
4265 | } |
4266 | |
4267 | unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); } |
4268 | |
4269 | Block &scf::IndexSwitchOp::getDefaultBlock() { |
4270 | return getDefaultRegion().front(); |
4271 | } |
4272 | |
4273 | Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) { |
4274 | assert(idx < getNumCases() && "case index out-of-bounds"); |
4275 | return getCaseRegions()[idx].front(); |
4276 | } |
4277 | |
4278 | void IndexSwitchOp::getSuccessorRegions( |
4279 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) { |
4280 | // All regions branch back to the parent op. |
4281 | if (!point.isParent()) { |
4282 | successors.emplace_back(getResults()); |
4283 | return; |
4284 | } |
4285 | |
4286 | llvm::append_range(successors, getRegions()); |
4287 | } |
4288 | |
4289 | void IndexSwitchOp::getEntrySuccessorRegions( |
4290 | ArrayRef<Attribute> operands, |
4291 | SmallVectorImpl<RegionSuccessor> &successors) { |
4292 | FoldAdaptor adaptor(operands, *this); |
4293 | |
4294 | // If a constant was not provided, all regions are possible successors. |
4295 | auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg()); |
4296 | if (!arg) { |
4297 | llvm::append_range(successors, getRegions()); |
4298 | return; |
4299 | } |
4300 | |
4301 | // Otherwise, try to find a case with a matching value. If not, the |
4302 | // default region is the only successor. |
4303 | for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) { |
4304 | if (caseValue == arg.getInt()) { |
4305 | successors.emplace_back(&caseRegion); |
4306 | return; |
4307 | } |
4308 | } |
4309 | successors.emplace_back(&getDefaultRegion()); |
4310 | } |
4311 | |
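/// With a constant switch operand, only the matching case region (or the
/// default region when no case value matches) is executed exactly once and
/// all other regions zero times; otherwise each region may execute at most
/// once.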
4312 | void IndexSwitchOp::getRegionInvocationBounds( |
4313 | ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { |
4314 | auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front()); |
4315 | if (!operandValue) { |
4316 | // All regions are invoked at most once. |
4317 | bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1)); |
4318 | return; |
4319 | } |
4320 | |
4321 | unsigned liveIndex = getNumRegions() - 1; |
4322 | const auto *it = llvm::find(getCases(), operandValue.getInt()); |
4323 | if (it != getCases().end()) |
4324 | liveIndex = std::distance(getCases().begin(), it); |
4325 | for (unsigned i = 0, e = getNumRegions(); i < e; ++i) |
4326 | bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex); |
4327 | } |
4328 | |
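/// Fold an `scf.index_switch` whose argument is a known constant by inlining
/// the body of the matching case region (or the default region when no case
/// value matches) and replacing the switch results with the yielded values.
///
/// For illustration, a hedged sketch (not taken from the original source):
///
///   %c1 = arith.constant 1 : index
///   %0 = scf.index_switch %c1 -> i32
///   case 1 { scf.yield %a : i32 }
///   default { scf.yield %b : i32 }
///
/// folds to the contents of `case 1`, with %0 replaced by %a.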
4329 | struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> { |
4330 | using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern; |
4331 | |
4332 | LogicalResult matchAndRewrite(scf::IndexSwitchOp op, |
4333 | PatternRewriter &rewriter) const override { |
4334 | // If `op.getArg()` is a constant, select the case region that matches |
4335 | // the constant value. Use the default region if no match is found. |
4336 | std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg()); |
4337 | if (!maybeCst.has_value()) |
4338 | return failure(); |
4339 | int64_t cst = *maybeCst; |
4340 | int64_t caseIdx, e = op.getNumCases(); |
4341 | for (caseIdx = 0; caseIdx < e; ++caseIdx) { |
4342 | if (cst == op.getCases()[caseIdx]) |
4343 | break; |
4344 | } |
4345 | |
4346 | Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx] |
4347 | : op.getDefaultRegion(); |
4348 | Block &source = r.front(); |
4349 | Operation *terminator = source.getTerminator(); |
4350 | SmallVector<Value> results = terminator->getOperands(); |
4351 | |
4352 | rewriter.inlineBlockBefore(&source, op); |
4353 | rewriter.eraseOp(terminator); |
4354 | // Replace the operation with a potentially empty list of results. |
4355 | // Fold mechanism doesn't support the case where the result list is empty. |
4356 | rewriter.replaceOp(op, results); |
4357 | |
4358 | return success(); |
4359 | } |
4360 | }; |
4361 | |
4362 | void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results, |
4363 | MLIRContext *context) { |
4364 | results.add<FoldConstantCase>(context); |
4365 | } |
4366 | |
4367 | //===----------------------------------------------------------------------===// |
4368 | // TableGen'd op method definitions |
4369 | //===----------------------------------------------------------------------===// |
4370 | |
4371 | #define GET_OP_CLASSES |
4372 | #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc" |
4373 |