//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
  assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  // TODO: In case `type` has a layout map that is not the fully dynamic
  // one, we may not be able to cast the buffer. In that case, the loop
  // iter_arg's layout map must be changed (see uses of `castBuffer`).
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}

/// Helper function for loop bufferization. Return "true" if the given value
/// is guaranteed to not alias with an external tensor apart from values in
/// `exceptions`. A value is external if it is defined outside of the given
/// region or if it is an entry block argument of the region.
static bool doesNotAliasExternalValue(Value value, Region *region,
                                      ValueRange exceptions,
                                      const OneShotAnalysisState &state) {
  assert(region->getBlocks().size() == 1 &&
         "expected region with single block");
  bool result = true;
  state.applyOnAliases(value, [&](Value alias) {
    if (llvm::is_contained(exceptions, alias))
      return;
    Region *aliasRegion = alias.getParentRegion();
    if (isa<BlockArgument>(alias) && !region->isProperAncestor(aliasRegion))
      result = false;
    if (isa<OpResult>(alias) && !region->isAncestor(aliasRegion))
      result = false;
  });
  return result;
}

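  // Inspect all aliases of `value`. A block argument passes the check only if
  // it belongs to a region strictly nested under `region` (an entry bbArg of
  // `region` itself counts as external); an OpResult must be defined inside
  // `region`.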
/// Bufferization of scf.condition.
struct ConditionOpInterface
    : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
                                                    scf::ConditionOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Condition operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto conditionOp = cast<scf::ConditionOp>(op);
    auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());

    SmallVector<Value> newArgs;
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
        if (failed(maybeBuffer))
          return failure();
        FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
            whileOp.getAfterArguments()[it.index()], options);
        if (failed(resultType))
          return failure();
        Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
        newArgs.push_back(buffer);
      } else {
        newArgs.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::ConditionOp>(
        rewriter, op, conditionOp.getCondition(), newArgs);
    return success();
  }
};

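    // The arguments of scf.condition flow into the "after" region (and become
    // the op results), so each tensor operand is cast to the bufferized type
    // of the corresponding "after" bbArg, which may differ in its layout map.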
/// Return the unique scf.yield op. If there are multiple or no scf.yield ops,
/// return an empty op.
static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
  scf::YieldOp result;
  for (Block &block : executeRegionOp.getRegion()) {
    if (auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) {
      if (result)
        return {};
      result = yieldOp;
    }
  }
  return result;
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
          ExecuteRegionOpInterface, scf::ExecuteRegionOp> {

  static bool supportsUnstructuredControlFlow() { return true; }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    // TODO: scf.execute_region with multiple yields are not supported.
    if (!getUniqueYieldOp(executeRegionOp))
      return op->emitOpError("op without unique scf.yield is not supported");
    return success();
  }

  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    if (auto bbArg = dyn_cast<BlockArgument>(value))
      return getAliasingBranchOpOperands(op, bbArg, state);

    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto it = llvm::find(op->getOpResults(), value);
    assert(it != op->getOpResults().end() && "invalid value");
    size_t resultNum = std::distance(op->getOpResults().begin(), it);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    // Note: If there is no unique scf.yield op, `verifyAnalysis` will fail.
    if (!yieldOp)
      return {};
    return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    TypeRange newResultTypes(yieldOp.getResults());

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Bufferize every block.
    for (Block &block : newOp.getRegion())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                        options)))
        return failure();

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (isa<TensorType>(it.value())) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the
    // then/else branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), value));
    OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
    OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
    return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false},
            {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : ifOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(ifOp);
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
    assert(value.getDefiningOp() == op && "invalid value");

    // Determine buffer types of the true/false branches.
    auto opResult = cast<OpResult>(value);
    auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
    auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
    BaseMemRefType thenBufferType, elseBufferType;
    if (isa<BaseMemRefType>(thenValue.getType())) {
      // True branch was already bufferized.
      thenBufferType = cast<BaseMemRefType>(thenValue.getType());
    } else {
      auto maybeBufferType =
          bufferization::getBufferType(thenValue, options, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      thenBufferType = *maybeBufferType;
    }
    if (isa<BaseMemRefType>(elseValue.getType())) {
      // False branch was already bufferized.
      elseBufferType = cast<BaseMemRefType>(elseValue.getType());
    } else {
      auto maybeBufferType =
          bufferization::getBufferType(elseValue, options, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      elseBufferType = *maybeBufferType;
    }

    // Best case: Both branches have the exact same buffer type.
    if (thenBufferType == elseBufferType)
      return thenBufferType;

    // Memory space mismatch.
    if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
      return op->emitError("inconsistent memory space on then/else branches");

    // Layout maps are different: Promote to fully dynamic layout map.
    return getMemRefTypeWithFullyDynamicLayout(
        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
  }
};

/// Bufferization of scf.index_switch. Replace with a new scf.index_switch that
/// yields memrefs.
struct IndexSwitchOpInterface
    : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
                                                    scf::IndexSwitchOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IndexSwitchOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. This is similar to IfOps.
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    int64_t resultNum = cast<OpResult>(value).getResultNumber();
    AliasingOpOperandList result;
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldOp =
          cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
      result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum),
                                        BufferRelation::Equivalent,
                                        /*isDefinite=*/false));
    }
    auto defaultYieldOp =
        cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());
    result.addAlias(AliasingOpOperand(&defaultYieldOp->getOpOperand(resultNum),
                                      BufferRelation::Equivalent,
                                      /*isDefinite=*/false));
    return result;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto switchOp = cast<scf::IndexSwitchOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : switchOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(switchOp);
    auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>(
        switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(),
        switchOp.getCases().size());

    // Move over blocks.
    for (auto [src, dest] :
         llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
      rewriter.inlineRegionBefore(src, dest, dest.begin());
    rewriter.inlineRegionBefore(switchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion().begin());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    assert(value.getDefiningOp() == op && "invalid value");
    int64_t resultNum = cast<OpResult>(value).getResultNumber();

    // Helper function to get buffer type of a case.
    SmallVector<BaseMemRefType> yieldedTypes;
    auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
      auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
      Value yieldedValue = yieldOp->getOperand(resultNum);
      if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
        return bufferType;
      auto maybeBufferType =
          bufferization::getBufferType(yieldedValue, options, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      return maybeBufferType;
    };

    // Compute buffer type of the default case.
    auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
    if (failed(maybeBufferType))
      return failure();
    BaseMemRefType bufferType = *maybeBufferType;

    // Compute buffer types of all other cases.
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
      if (failed(yieldedBufferType))
        return failure();

      // Best case: Both branches have the exact same buffer type.
      if (bufferType == *yieldedBufferType)
        continue;

      // Memory space mismatch.
      if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
        return op->emitError("inconsistent memory space on switch cases");

      // Layout maps are different: Promote to fully dynamic layout map.
      bufferType = getMemRefTypeWithFullyDynamicLayout(
          cast<TensorType>(value.getType()), bufferType.getMemorySpace());
    }

    return bufferType;
  }
};

/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (isa<TensorType>(it.value().getType()))
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!isa<TensorType>(bbArgs[i].getType()) ||
        !isa<TensorType>(yieldedValues[i].getType()))
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}

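  // The number of bbArgs and yielded values can differ (e.g., for the "before"
  // region of scf.while), so only the overlapping prefix is compared.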
/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
           const BufferizationOptions &options) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (isa<TensorType>(opOperand.get().getType())) {
      FailureOr<Value> resultBuffer =
          getBuffer(rewriter, opOperand.get(), options);
      if (failed(resultBuffer))
        return failure();
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}

/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
static SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(bbArgs)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      result.push_back(
          rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val)
              .getResult());
    } else {
      result.push_back(val);
    }
  }
  return result;
}

/// Compute the bufferized type of a loop iter_arg. This type must be equal to
/// the bufferized type of the corresponding init_arg and the bufferized type
/// of the corresponding yielded value.
///
/// This function uses bufferization::getBufferType to compute the bufferized
/// type of the init_arg and of the yielded value. (The computation of the
/// bufferized yielded value type usually requires computing the bufferized
/// type of the iter_arg again; the implementation of getBufferType traces back
/// the use-def chain of the given value and computes a buffer type along the
/// way.) If both buffer types are equal, no casts are needed and the computed
/// buffer type can be used directly. Otherwise, the buffer types can only
/// differ in their layout map and a cast must be inserted.
static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
    Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
    const BufferizationOptions &options, SmallVector<Value> &invocationStack) {
  // Determine the buffer type of the init_arg.
  auto initArgBufferType =
      bufferization::getBufferType(initArg, options, invocationStack);
  if (failed(initArgBufferType))
    return failure();

  if (llvm::count(invocationStack, iterArg) >= 2) {
    // If the iter_arg is already twice on the invocation stack, just take the
    // type of the init_arg. This is to avoid infinite loops when calculating
    // the buffer type. This will most likely result in computing a memref type
    // with a fully dynamic layout map.

    // Note: For more precise layout map computation, a fixpoint iteration
    // could be done (i.e., re-computing the yielded buffer type until the
    // bufferized iter_arg type no longer changes). This current implementation
    // immediately switches to a fully dynamic layout map when a mismatch
    // between bufferized init_arg type and bufferized yield value type is
    // detected.
    return *initArgBufferType;
  }

  // Compute the buffer type of the yielded value.
  BaseMemRefType yieldedValueBufferType;
  if (isa<BaseMemRefType>(yieldedValue.getType())) {
    // scf.yield was already bufferized.
    yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
  } else {
    // Note: This typically triggers a recursive call for the buffer type of
    // the iter_arg.
    auto maybeBufferType =
        bufferization::getBufferType(yieldedValue, options, invocationStack);
    if (failed(maybeBufferType))
      return failure();
    yieldedValueBufferType = *maybeBufferType;
  }

  // If yielded type and init_arg type are the same, use that type directly.
  if (*initArgBufferType == yieldedValueBufferType)
    return yieldedValueBufferType;

  // If there is a mismatch between the yielded buffer type and the init_arg
  // buffer type, the buffer type must be promoted to a fully dynamic layout
  // map.
  auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType);
  auto iterTensorType = cast<TensorType>(iterArg.getType());
  auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType);
  if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
    return loopOp->emitOpError(
        "init_arg and yielded value bufferize to inconsistent memory spaces");
#ifndef NDEBUG
  if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
    assert(
        llvm::all_equal({yieldedRankedBufferType.getShape(),
                         cast<MemRefType>(initBufferType).getShape(),
                         cast<RankedTensorType>(iterTensorType).getShape()}) &&
        "expected same shape");
  }
#endif // NDEBUG
  return getMemRefTypeWithFullyDynamicLayout(
      iterTensorType, yieldedBufferType.getMemorySpace());
}

/// Return `true` if the given loop may have 0 iterations.
bool mayHaveZeroIterations(scf::ForOp forOp) {
  std::optional<int64_t> lb = getConstantIntValue(forOp.getLowerBound());
  std::optional<int64_t> ub = getConstantIntValue(forOp.getUpperBound());
  if (!lb.has_value() || !ub.has_value())
    return true;
  return *ub <= *lb;
}

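  // Both bounds are constant: since the step of an scf.for is required to be
  // positive, the loop executes at least once iff lb < ub.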
/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);

    // If the loop has zero iterations, the results of the op are their
    // corresponding init_args, meaning that the init_args bufferize to a read.
    if (mayHaveZeroIterations(forOp))
      return true;

    // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
    // its matching bbArg may.
    return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    OpResult opResult = forOp.getTiedLoopResult(&opOperand);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
    return equivalentYield ? BufferRelation::Equivalent
                           : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

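    // Only an "Equivalent" relation is a definite alias; with an "Unknown"
    // relation the result may instead be a newly allocated buffer that does
    // not alias the init_arg at all.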
  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg (or with a
    // newly allocated buffer) and not with other buffers defined outside of
    // the loop. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand.
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, does it alias with something defined outside of
    // the loop?
    SmallVector<Value> yieldValues;
    for (const auto it : llvm::enumerate(yieldOp.getResults())) {
      // Note: `state` is guaranteed to be a `OneShotAnalysisState`, but this
      // type cannot be used in the signature of `resolveConflicts` because
      // the op interface is in the "IR" build unit and the
      // `OneShotAnalysisState` is defined in the "Transforms" build unit.
      if (!indices.contains(it.index()) ||
          doesNotAliasExternalValue(
              it.value(), &forOp.getRegion(),
              /*exceptions=*/forOp.getRegionIterArg(it.index()),
              static_cast<const OneShotAnalysisState &>(state))) {
        yieldValues.push_back(it.value());
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, yieldOp.getLoc(), it.value(), state.getOptions());
      if (failed(alloc))
        return failure();
      yieldValues.push_back(*alloc);
    }

    rewriter.modifyOpInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto forOp = cast<scf::ForOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    if (auto opResult = dyn_cast<OpResult>(value)) {
      // The type of an OpResult must match the corresponding iter_arg type.
      BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
      return bufferization::getBufferType(bbArg, options, invocationStack);
    }

    // Compute result/argument number.
    BlockArgument bbArg = cast<BlockArgument>(value);
    unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();

    // Compute the bufferized type.
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    Value yieldedValue = yieldOp.getOperand(resultNum);
    BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
    Value initArg = forOp.getInitArgs()[resultNum];
    return computeLoopRegionIterArgBufferType(
        op, iterArg, initArg, yieldedValue, options, invocationStack);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = forOp.getBody();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, forOp.getInitArgsMutable(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value result = forOp->getResult(it.index());
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(result.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(result, options);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), castedInitArgs);
    newForOp->setAttrs(forOp->getAttrs());
    Block *loopBody = newForOp.getBody();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

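    // Prepend the induction variable: `mergeBlocks` expects one replacement
    // value per old bbArg (the IV followed by the iter_args).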
  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!isa<TensorType>(opResult.getType()))
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match. They may not even have the
    // same type. The number of OpResults and OpOperands can also differ.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    OpResult opResult = whileOp->getResult(idx);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::Unknown;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::Unknown;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer
    // copy. (New buffer copies do not alias with any buffer.)
    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);

    // Update "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!isa<TensorType>(value.getType()) ||
          (equivalentYieldsAfter.contains(idx) &&
           equivalentYieldsBefore.contains(idx))) {
        beforeYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, conditionOp.getLoc(), value, state.getOptions());
      if (failed(alloc))
        return failure();
      beforeYieldValues.push_back(*alloc);
    }
    rewriter.modifyOpInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto whileOp = cast<scf::WhileOp>(op);

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different
    // args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp.getInitsMutable(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value beforeArg = whileOp.getBeforeArguments()[it.index()];
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(beforeArg.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(beforeArg, options);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          if (!isa<TensorType>(bbArg.getType()))
            return bbArg.getType();
          // TODO: error handling
          return llvm::cast<Type>(
              *bufferization::getBufferType(bbArg, options));
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(castedInitArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(
        whileOp.getLoc(), argsTypesAfter, castedInitArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
                                          whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
        rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs = getBbArgReplacements(
        rewriter, newWhileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto whileOp = cast<scf::WhileOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    // Case 1: Block argument of the "before" region.
    if (auto bbArg = dyn_cast<BlockArgument>(value)) {
      if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
        Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
        auto yieldOp = whileOp.getYieldOp();
        Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
        return computeLoopRegionIterArgBufferType(
            op, bbArg, initArg, yieldedValue, options, invocationStack);
      }
    }

    // Case 2: OpResult of the loop or block argument of the "after" region.
    // The bufferized "after" bbArg type can be directly computed from the
    // bufferized "before" bbArg type.
    unsigned resultNum;
    if (auto opResult = dyn_cast<OpResult>(value)) {
      resultNum = opResult.getResultNumber();
    } else if (cast<BlockArgument>(value).getOwner()->getParent() ==
               &whileOp.getAfter()) {
      resultNum = cast<BlockArgument>(value).getArgNumber();
    } else {
      llvm_unreachable("invalid value");
    }
    Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
    if (!isa<TensorType>(conditionYieldedVal.getType())) {
      // scf.condition was already bufferized.
      return cast<BaseMemRefType>(conditionYieldedVal.getType());
    }
    return bufferization::getBufferType(conditionYieldedVal, options,
                                        invocationStack);
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Block *block = conditionOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Block *block = yieldOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    if (auto ifOp = dyn_cast<scf::IfOp>(op->getParentOp())) {
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent, /*isDefinite=*/false}};
    }
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent}};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
             scf::WhileOp>(yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");

    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
        if (failed(maybeBuffer))
          return failure();
        Value buffer = *maybeBuffer;
        // We may have to cast the value before yielding it.
        if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
                yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
              yieldOp->getParentOp()->getResult(it.index()), options);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        } else if (auto whileOp =
                       dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
              whileOp.getBeforeArguments()[it.index()], options);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        }
        newResults.push_back(buffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
    return success();
  }
};

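    // Only for scf.if and scf.execute_region parents are the yielded values
    // reported to alias the parent op's results here; for the other supported
    // parent ops, aliasing is modeled by the parent op's own interface.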
/// Return `true` if the given loop may have 0 iterations.
bool mayHaveZeroIterations(scf::ForallOp forallOp) {
  for (auto [lb, ub] : llvm::zip(forallOp.getMixedLowerBound(),
                                 forallOp.getMixedUpperBound())) {
    std::optional<int64_t> lbConst = getConstantIntValue(lb);
    std::optional<int64_t> ubConst = getConstantIntValue(ub);
    if (!lbConst.has_value() || !ubConst.has_value() || *lbConst >= *ubConst)
      return true;
  }
  return false;
}

/// Bufferization of ForallOp. This also bufferizes the terminator of the
/// region. There are op interfaces for the terminators (InParallelOp
/// and ParallelInsertSliceOp), but these are only used during analysis. Not
/// for bufferization.
struct ForallOpInterface
    : public BufferizableOpInterface::ExternalModel<ForallOpInterface,
                                                    ForallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forallOp = cast<ForallOp>(op);

    // If the loop has zero iterations, the results of the op are their
    // corresponding shared_outs, meaning that the shared_outs bufferize to a
    // read.
    if (mayHaveZeroIterations(forallOp))
      return true;

    // scf::ForallOp alone doesn't bufferize to a memory read, one of the
    // uses of its matching bbArg may.
    return state.isValueRead(forallOp.getTiedBlockArgument(&opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Outputs of scf::ForallOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forallOp = cast<ForallOp>(op);
    return {
        {{forallOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}}};
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard guard(rewriter);
    auto forallOp = cast<ForallOp>(op);
    int64_t rank = forallOp.getRank();

    // Get buffers for all output operands.
    SmallVector<Value> buffers;
    for (Value out : forallOp.getOutputs()) {
      FailureOr<Value> buffer = getBuffer(rewriter, out, options);
      if (failed(buffer))
        return failure();
      buffers.push_back(*buffer);
    }

    // Use buffers instead of block arguments.
    rewriter.setInsertionPointToStart(forallOp.getBody());
    for (const auto &it : llvm::zip(
             forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
      BlockArgument bbArg = std::get<0>(it);
      Value buffer = std::get<1>(it);
      Value bufferAsTensor =
          rewriter.create<ToTensorOp>(forallOp.getLoc(), buffer);
      bbArg.replaceAllUsesWith(bufferAsTensor);
    }

    // Create new ForallOp without any results and drop the automatically
    // introduced terminator.
    rewriter.setInsertionPoint(forallOp);
    ForallOp newForallOp;
    newForallOp = rewriter.create<ForallOp>(
        forallOp.getLoc(), forallOp.getMixedLowerBound(),
        forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
        /*outputs=*/ValueRange(), forallOp.getMapping());

    rewriter.eraseOp(newForallOp.getBody()->getTerminator());

    // Move over block contents of the old op.
    SmallVector<Value> replacementBbArgs;
    replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(),
                             newForallOp.getBody()->getArguments().end());
    replacementBbArgs.append(forallOp.getOutputs().size(), Value());
    rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
                         replacementBbArgs);

    // Remove the old op and replace all of its uses.
    replaceOpWithBufferizedValues(rewriter, op, buffers);

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto forallOp = cast<ForallOp>(op);

    if (auto bbArg = dyn_cast<BlockArgument>(value))
      // A tensor block argument has the same bufferized type as the
      // corresponding output operand.
      return bufferization::getBufferType(
          forallOp.getTiedOpOperand(bbArg)->get(), options, invocationStack);

    // The bufferized result type is the same as the bufferized type of the
    // corresponding output operand.
    return bufferization::getBufferType(
        forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
        invocationStack);
  }

  bool isRepetitiveRegion(Operation *op, unsigned index) const {
    auto forallOp = cast<ForallOp>(op);

    // This op is repetitive if its body may execute more than once. If any of
    // the loop control values are dynamic, it is conservatively considered so.
    for (auto [lb, ub, step] :
         llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
                   forallOp.getMixedStep())) {
      std::optional<int64_t> lbConstant = getConstantIntValue(lb);
      if (!lbConstant)
        return true;

      std::optional<int64_t> ubConstant = getConstantIntValue(ub);
      if (!ubConstant)
        return true;

      std::optional<int64_t> stepConstant = getConstantIntValue(step);
      if (!stepConstant)
        return true;

      if (*lbConstant + *stepConstant < *ubConstant)
        return true;
    }
    return false;
  }

  bool isParallelRegion(Operation *op, unsigned index) const {
    return isRepetitiveRegion(op, index);
  }
};

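      // All loop control values along this dimension are static: the region
      // repeats iff a second iteration is still in bounds, i.e. lb + step < ub.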
/// Nothing to do for InParallelOp.
struct InParallelOpInterface
    : public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
                                                    InParallelOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options) const {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
    return failure();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx);
    ForallOp::attachInterface<ForallOpInterface>(*ctx);
    InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}