//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
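///
/// Illustrative sketch (types invented for exposition, not from a test case):
/// casting a buffer with a static identity layout to the fully dynamic layout
/// expected by a loop iter_arg would produce something like
///   %cast = memref.cast %buf
///       : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>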
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
  assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  // TODO: In case `type` has a layout map that is not the fully dynamic
  // one, we may not be able to cast the buffer. In that case, the loop
  // iter_arg's layout map must be changed (see uses of `castBuffer`).
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}

/// Helper function for loop bufferization. Return "true" if the given value
/// is guaranteed to not alias with an external tensor apart from values in
/// `exceptions`. A value is external if it is defined outside of the given
/// region or if it is an entry block argument of the region.
static bool doesNotAliasExternalValue(Value value, Region *region,
                                      ValueRange exceptions,
                                      const OneShotAnalysisState &state) {
  assert(llvm::hasSingleElement(region->getBlocks()) &&
         "expected region with single block");
  bool result = true;
  state.applyOnAliases(value, [&](Value alias) {
    if (llvm::is_contained(exceptions, alias))
      return;
    Region *aliasRegion = alias.getParentRegion();
    if (isa<BlockArgument>(alias) && !region->isProperAncestor(aliasRegion))
      result = false;
    if (isa<OpResult>(alias) && !region->isAncestor(aliasRegion))
      result = false;
  });
  return result;
}

/// Bufferization of scf.condition.
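///
/// Illustrative sketch (simplified IR, not from an actual test case): an op
/// such as
///   scf.condition(%cmp) %t : tensor<5xf32>
/// is replaced with one that forwards the corresponding buffer, cast to the
/// bufferized type of the matching "after" block argument:
///   scf.condition(%cmp) %buf : memref<5xf32, strided<[?], offset: ?>>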
struct ConditionOpInterface
    : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
                                                    scf::ConditionOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Condition operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto conditionOp = cast<scf::ConditionOp>(op);
    auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());

    SmallVector<Value> newArgs;
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer =
            getBuffer(rewriter, value, options, state);
        if (failed(maybeBuffer))
          return failure();
        FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
            whileOp.getAfterArguments()[it.index()], options, state);
        if (failed(resultType))
          return failure();
        Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
        newArgs.push_back(buffer);
      } else {
        newArgs.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::ConditionOp>(
        rewriter, op, conditionOp.getCondition(), newArgs);
    return success();
  }
};

/// Return the unique scf.yield op. If there are multiple or no scf.yield ops,
/// return an empty op.
static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
  scf::YieldOp result;
  for (Block &block : executeRegionOp.getRegion()) {
    if (auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) {
      if (result)
        return {};
      result = yieldOp;
    }
  }
  return result;
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
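///
/// Illustrative sketch (simplified IR, assuming a single scf.yield): an op
/// such as
///   %r = scf.execute_region -> tensor<5xf32> {
///     scf.yield %t : tensor<5xf32>
///   }
/// is recreated with memref result types, and a bufferization.to_tensor op is
/// inserted after it so that remaining tensor users keep working.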
struct ExecuteRegionOpInterface
    : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
          ExecuteRegionOpInterface, scf::ExecuteRegionOp> {

  static bool supportsUnstructuredControlFlow() { return true; }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    // TODO: scf.execute_region with multiple yields are not supported.
    if (!getUniqueYieldOp(executeRegionOp))
      return op->emitOpError("op without unique scf.yield is not supported");
    return success();
  }

  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    if (auto bbArg = dyn_cast<BlockArgument>(value))
      return getAliasingBranchOpOperands(op, bbArg, state);

    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto it = llvm::find(op->getOpResults(), value);
    assert(it != op->getOpResults().end() && "invalid value");
    size_t resultNum = std::distance(op->getOpResults().begin(), it);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    // Note: If there is no unique scf.yield op, `verifyAnalysis` will fail.
    if (!yieldOp)
      return {};
    return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    TypeRange newResultTypes(yieldOp.getResults());

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Bufferize every block.
    for (Block &block : newOp.getRegion())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                        options, state)))
        return failure();

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (isa<TensorType>(it.value())) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), it.value(),
            newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
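///
/// Illustrative sketch (simplified IR, not from an actual test case):
///   %r = scf.if %cond -> (tensor<5xf32>) {
///     scf.yield %a : tensor<5xf32>
///   } else {
///     scf.yield %b : tensor<5xf32>
///   }
/// becomes roughly
///   %r = scf.if %cond -> (memref<5xf32, strided<[?], offset: ?>>) { ... }
/// where the result type is the common buffer type of both yielded values,
/// promoted to a fully dynamic layout if the two branches disagree.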
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the
    // then/else branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), value));
    OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
    OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
    return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false},
            {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : ifOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options, state);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(ifOp);
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
    assert(value.getDefiningOp() == op && "invalid value");

    // Determine buffer types of the true/false branches.
    auto opResult = cast<OpResult>(value);
    auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
    auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
    BaseMemRefType thenBufferType, elseBufferType;
    if (isa<BaseMemRefType>(thenValue.getType())) {
      // True branch was already bufferized.
      thenBufferType = cast<BaseMemRefType>(thenValue.getType());
    } else {
      auto maybeBufferType = bufferization::getBufferType(
          thenValue, options, state, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      thenBufferType = *maybeBufferType;
    }
    if (isa<BaseMemRefType>(elseValue.getType())) {
      // False branch was already bufferized.
      elseBufferType = cast<BaseMemRefType>(elseValue.getType());
    } else {
      auto maybeBufferType = bufferization::getBufferType(
          elseValue, options, state, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      elseBufferType = *maybeBufferType;
    }

    // Best case: Both branches have the exact same buffer type.
    if (thenBufferType == elseBufferType)
      return thenBufferType;

    // Memory space mismatch.
    if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
      return op->emitError("inconsistent memory space on then/else branches");

    // Layout maps are different: Promote to fully dynamic layout map.
    return getMemRefTypeWithFullyDynamicLayout(
        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
  }
};

/// Bufferization of scf.index_switch. Replace with a new scf.index_switch that
/// yields memrefs.
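///
/// Illustrative sketch (simplified IR): analogous to scf.if, e.g.
///   %r = scf.index_switch %idx -> tensor<5xf32>
///   case 0 { scf.yield %a : tensor<5xf32> }
///   default { scf.yield %b : tensor<5xf32> }
/// is recreated with a memref result type that every case can be cast to.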
struct IndexSwitchOpInterface
    : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
                                                    scf::IndexSwitchOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IndexSwitchOps do not have tensor OpOperands. The yielded value can be
    // any SSA value. This is similar to IfOps.
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    int64_t resultNum = cast<OpResult>(value).getResultNumber();
    AliasingOpOperandList result;
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldOp =
          cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
      result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum),
                                        BufferRelation::Equivalent,
                                        /*isDefinite=*/false));
    }
    auto defaultYieldOp =
        cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());
    result.addAlias(AliasingOpOperand(&defaultYieldOp->getOpOperand(resultNum),
                                      BufferRelation::Equivalent,
                                      /*isDefinite=*/false));
    return result;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto switchOp = cast<scf::IndexSwitchOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : switchOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options, state);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(switchOp);
    auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>(
        switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(),
        switchOp.getCases().size());

    // Move over blocks.
    for (auto [src, dest] :
         llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
      rewriter.inlineRegionBefore(src, dest, dest.begin());
    rewriter.inlineRegionBefore(switchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion().begin());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    assert(value.getDefiningOp() == op && "invalid value");
    int64_t resultNum = cast<OpResult>(value).getResultNumber();

    // Helper function to get buffer type of a case.
    auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
      auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
      Value yieldedValue = yieldOp->getOperand(resultNum);
      if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
        return bufferType;
      auto maybeBufferType = bufferization::getBufferType(
          yieldedValue, options, state, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      return maybeBufferType;
    };

    // Compute buffer type of the default case.
    auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
    if (failed(maybeBufferType))
      return failure();
    BaseMemRefType bufferType = *maybeBufferType;

    // Compute buffer types of all other cases.
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
      if (failed(yieldedBufferType))
        return failure();

      // Best case: Both branches have the exact same buffer type.
      if (bufferType == *yieldedBufferType)
        continue;

      // Memory space mismatch.
      if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
        return op->emitError("inconsistent memory space on switch cases");

      // Layout maps are different: Promote to fully dynamic layout map.
      bufferType = getMemRefTypeWithFullyDynamicLayout(
          cast<TensorType>(value.getType()), bufferType.getMemorySpace());
    }

    return bufferType;
  }
};

/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (isa<TensorType>(it.value().getType()))
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!isa<TensorType>(bbArgs[i].getType()) ||
        !isa<TensorType>(yieldedValues[i].getType()))
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}

/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
           const BufferizationOptions &options, BufferizationState &state) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (isa<TensorType>(opOperand.get().getType())) {
      FailureOr<Value> resultBuffer =
          getBuffer(rewriter, opOperand.get(), options, state);
      if (failed(resultBuffer))
        return failure();
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}

/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
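///
/// Illustrative sketch: if the new loop has a memref iter_arg %arg1 whose old
/// counterpart had type tensor<5xf32>, this helper produces
///   %t = bufferization.to_tensor %arg1 : ...
/// so that the old, still tensor-based loop body can be merged into the new
/// loop unchanged.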
static SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     Block::BlockArgListType oldBbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(bbArgs)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      result.push_back(rewriter
                           .create<bufferization::ToTensorOp>(
                               val.getLoc(), oldBbArgs[idx].getType(), val)
                           .getResult());
    } else {
      result.push_back(val);
    }
  }
  return result;
}

/// Compute the bufferized type of a loop iter_arg. This type must be equal to
/// the bufferized type of the corresponding init_arg and the bufferized type
/// of the corresponding yielded value.
///
/// This function uses bufferization::getBufferType to compute the bufferized
/// type of the init_arg and of the yielded value. (The computation of the
/// bufferized yielded value type usually requires computing the bufferized
/// type of the iter_arg again; the implementation of getBufferType traces back
/// the use-def chain of the given value and computes a buffer type along the
/// way.) If both buffer types are equal, no casts are needed and the computed
/// buffer type can be used directly. Otherwise, the buffer types can only
/// differ in their layout map and a cast must be inserted.
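///
/// Illustrative sketch (types invented for exposition): if the init_arg
/// bufferizes to
///   memref<5xf32>
/// but the yielded value bufferizes to
///   memref<5xf32, strided<[?], offset: ?>>
/// the iter_arg is assigned the fully dynamic layout type, so that both the
/// init buffer and the yielded buffer can be cast to it.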
static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
    Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
    const BufferizationOptions &options, const BufferizationState &state,
    SmallVector<Value> &invocationStack) {
  // Determine the buffer type of the init_arg.
  auto initArgBufferType =
      bufferization::getBufferType(initArg, options, state, invocationStack);
  if (failed(initArgBufferType))
    return failure();

  if (llvm::count(invocationStack, iterArg) >= 2) {
    // If the iter_arg is already twice on the invocation stack, just take the
    // type of the init_arg. This is to avoid infinite loops when calculating
    // the buffer type. This will most likely result in computing a memref type
    // with a fully dynamic layout map.

    // Note: For more precise layout map computation, a fixpoint iteration
    // could be done (i.e., re-computing the yielded buffer type until the
    // bufferized iter_arg type no longer changes). This current implementation
    // immediately switches to a fully dynamic layout map when a mismatch
    // between bufferized init_arg type and bufferized yield value type is
    // detected.
    return *initArgBufferType;
  }

  // Compute the buffer type of the yielded value.
  BaseMemRefType yieldedValueBufferType;
  if (isa<BaseMemRefType>(yieldedValue.getType())) {
    // scf.yield was already bufferized.
    yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
  } else {
    // Note: This typically triggers a recursive call for the buffer type of
    // the iter_arg.
    auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
                                                        state, invocationStack);
    if (failed(maybeBufferType))
      return failure();
    yieldedValueBufferType = *maybeBufferType;
  }

  // If yielded type and init_arg type are the same, use that type directly.
  if (*initArgBufferType == yieldedValueBufferType)
    return yieldedValueBufferType;

  // If there is a mismatch between the yielded buffer type and the init_arg
  // buffer type, the buffer type must be promoted to a fully dynamic layout
  // map.
  auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType);
  auto iterTensorType = cast<TensorType>(iterArg.getType());
  auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType);
  if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
    return loopOp->emitOpError(
        "init_arg and yielded value bufferize to inconsistent memory spaces");
#ifndef NDEBUG
  if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
    assert(
        llvm::all_equal({yieldedRankedBufferType.getShape(),
                         cast<MemRefType>(initBufferType).getShape(),
                         cast<RankedTensorType>(iterTensorType).getShape()}) &&
        "expected same shape");
  }
#endif // NDEBUG
  return getMemRefTypeWithFullyDynamicLayout(
      iterTensorType, yieldedBufferType.getMemorySpace());
}

/// Return `true` if the given loop may have 0 iterations.
bool mayHaveZeroIterations(scf::ForOp forOp) {
  std::optional<int64_t> lb = getConstantIntValue(forOp.getLowerBound());
  std::optional<int64_t> ub = getConstantIntValue(forOp.getUpperBound());
  if (!lb.has_value() || !ub.has_value())
    return true;
  return *ub <= *lb;
}

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
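///
/// Illustrative sketch (simplified IR, not from an actual test case):
///   %r = scf.for %i = %lb to %ub step %s iter_args(%t = %init)
///       -> (tensor<5xf32>) {
///     ...
///     scf.yield %t2 : tensor<5xf32>
///   }
/// becomes roughly
///   %m = scf.for %i = %lb to %ub step %s iter_args(%b = %init_buffer)
///       -> (memref<5xf32, strided<[?], offset: ?>>) {
///     %t = bufferization.to_tensor %b : ...
///     ...
///     scf.yield %b2 : memref<5xf32, strided<[?], offset: ?>>
///   }
/// with the init buffer cast to the iter_arg buffer type where necessary.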
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);

    // If the loop has zero iterations, the results of the op are their
    // corresponding init_args, meaning that the init_args bufferize to a read.
    if (mayHaveZeroIterations(forOp))
      return true;

    // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
    // its matching bbArg may.
    return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    OpResult opResult = forOp.getTiedLoopResult(&opOperand);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
    return equivalentYield ? BufferRelation::Equivalent
                           : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult
  resolveConflicts(Operation *op, RewriterBase &rewriter,
                   const AnalysisState &analysisState,
                   const BufferizationState &bufferizationState) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
            rewriter, analysisState, bufferizationState)))
      return failure();

    if (analysisState.getOptions().copyBeforeWrite)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg (or with a
    // newly allocated buffer) and not with other buffers defined outside of
    // the loop. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand.
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, does it alias with something defined outside of
    // the loop?
    SmallVector<Value> yieldValues;
    for (const auto it : llvm::enumerate(yieldOp.getResults())) {
      // Note: `state` is guaranteed to be a `OneShotAnalysisState`, but this
      // type cannot be used in the signature of `resolveConflicts` because the
      // op interface is in the "IR" build unit and the `OneShotAnalysisState`
      // is defined in the "Transforms" build unit.
      if (!indices.contains(it.index()) ||
          doesNotAliasExternalValue(
              it.value(), &forOp.getRegion(),
              /*exceptions=*/forOp.getRegionIterArg(it.index()),
              static_cast<const OneShotAnalysisState &>(analysisState))) {
        yieldValues.push_back(it.value());
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, yieldOp.getLoc(), it.value(), analysisState.getOptions(),
          bufferizationState);
      if (failed(alloc))
        return failure();
      yieldValues.push_back(*alloc);
    }

    rewriter.modifyOpInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto forOp = cast<scf::ForOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    if (auto opResult = dyn_cast<OpResult>(value)) {
      // The type of an OpResult must match the corresponding iter_arg type.
      BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
      return bufferization::getBufferType(bbArg, options, state,
                                          invocationStack);
    }

    // Compute result/argument number.
    BlockArgument bbArg = cast<BlockArgument>(value);
    unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();

    // Compute the bufferized type.
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    Value yieldedValue = yieldOp.getOperand(resultNum);
    BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
    Value initArg = forOp.getInitArgs()[resultNum];
    return computeLoopRegionIterArgBufferType(
        op, iterArg, initArg, yieldedValue, options, state, invocationStack);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = forOp.getBody();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, forOp.getInitArgsMutable(), options, state);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value result = forOp->getResult(it.index());
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(result.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(result, options, state);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), castedInitArgs);
    newForOp->setAttrs(forOp->getAttrs());
    Block *loopBody = newForOp.getBody();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(),
                             forOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!isa<TensorType>(opResult.getType()))
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
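///
/// Illustrative sketch (simplified IR): an op such as
///   %r = scf.while (%t = %init) : (tensor<5xf32>) -> tensor<5xf32> {
///     ...
///     scf.condition(%cmp) %t1 : tensor<5xf32>
///   } do {
///   ^bb0(%t2: tensor<5xf32>):
///     ...
///     scf.yield %t3 : tensor<5xf32>
///   }
/// is rebuilt on memrefs, with bufferization.to_tensor ops wrapping the new
/// block arguments and casts inserted wherever the layouts disagree.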
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match. They may not even have the
    // same type. The number of OpResults and OpOperands can also differ.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    OpResult opResult = whileOp->getResult(idx);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::Unknown;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::Unknown;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult
  resolveConflicts(Operation *op, RewriterBase &rewriter,
                   const AnalysisState &analysisState,
                   const BufferizationState &bufferizationState) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
            rewriter, analysisState, bufferizationState)))
      return failure();

    if (analysisState.getOptions().copyBeforeWrite)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer
    // copy. (New buffer copies do not alias with any buffer.)
    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), analysisState);
    DenseSet<int64_t> equivalentYieldsAfter =
        getEquivalentBuffers(whileOp.getAfterArguments(),
                             whileOp.getYieldOp().getResults(), analysisState);

    // Update "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!isa<TensorType>(value.getType()) ||
          (equivalentYieldsAfter.contains(idx) &&
           equivalentYieldsBefore.contains(idx))) {
        beforeYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, conditionOp.getLoc(), value, analysisState.getOptions(),
          bufferizationState);
      if (failed(alloc))
        return failure();
      beforeYieldValues.push_back(*alloc);
    }
    rewriter.modifyOpInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different
    // args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp.getInitsMutable(), options, state);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value beforeArg = whileOp.getBeforeArguments()[it.index()];
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(beforeArg.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(beforeArg, options, state);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          if (!isa<TensorType>(bbArg.getType()))
            return bbArg.getType();
          // TODO: error handling
          return llvm::cast<Type>(
              *bufferization::getBufferType(bbArg, options, state));
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(castedInitArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(
        whileOp.getLoc(), argsTypesAfter, castedInitArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
                                          whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs =
        getBbArgReplacements(rewriter, newWhileOp.getBeforeArguments(),
                             whileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs =
        getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(),
                             whileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto whileOp = cast<scf::WhileOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    // Case 1: Block argument of the "before" region.
    if (auto bbArg = dyn_cast<BlockArgument>(value)) {
      if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
        Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
        auto yieldOp = whileOp.getYieldOp();
        Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
        return computeLoopRegionIterArgBufferType(
            op, bbArg, initArg, yieldedValue, options, state, invocationStack);
      }
    }

    // Case 2: OpResult of the loop or block argument of the "after" region.
    // The bufferized "after" bbArg type can be directly computed from the
    // bufferized "before" bbArg type.
    unsigned resultNum;
    if (auto opResult = dyn_cast<OpResult>(value)) {
      resultNum = opResult.getResultNumber();
    } else if (cast<BlockArgument>(value).getOwner()->getParent() ==
               &whileOp.getAfter()) {
      resultNum = cast<BlockArgument>(value).getArgNumber();
    } else {
      llvm_unreachable("invalid value");
    }
    Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
    if (!isa<TensorType>(conditionYieldedVal.getType())) {
      // scf.condition was already bufferized.
      return cast<BaseMemRefType>(conditionYieldedVal.getType());
    }
    return bufferization::getBufferType(conditionYieldedVal, options, state,
                                        invocationStack);
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Block *block = conditionOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Block *block = yieldOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
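///
/// Illustrative sketch (simplified IR): a yielded tensor is replaced with its
/// buffer, cast to the bufferized type expected by the parent op, e.g.
///   scf.yield %t : tensor<5xf32>
/// becomes
///   scf.yield %buf : memref<5xf32, strided<[?], offset: ?>>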
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    if (auto ifOp = dyn_cast<scf::IfOp>(op->getParentOp())) {
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent, /*isDefinite=*/false}};
    }
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent}};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
             scf::WhileOp>(yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");

    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer =
            getBuffer(rewriter, value, options, state);
        if (failed(maybeBuffer))
          return failure();
        Value buffer = *maybeBuffer;
        // We may have to cast the value before yielding it.
        if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
                yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
              yieldOp->getParentOp()->getResult(it.index()), options, state);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        } else if (auto whileOp =
                       dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
              whileOp.getBeforeArguments()[it.index()], options, state);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        }
        newResults.push_back(buffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
    return success();
  }
};

/// Bufferization of ForallOp. This also bufferizes the terminator of the
/// region. There are op interfaces for the terminators (InParallelOp
/// and ParallelInsertSliceOp), but these are only used during analysis. Not
/// for bufferization.
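///
/// Illustrative sketch (simplified IR): the shared_outs disappear and writes
/// go directly into the output buffers, e.g.
///   %r = scf.forall (%i) in (%n) shared_outs(%o = %t) -> (tensor<?xf32>) {
///     ...
///   }
/// becomes roughly
///   scf.forall (%i) in (%n) {
///     ... // writes go into the buffer of %t
///   }
/// and uses of %r are replaced with that buffer (re-wrapped in to_tensor where
/// tensor users remain).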
1216 | struct ForallOpInterface |
1217 | : public BufferizableOpInterface::ExternalModel<ForallOpInterface, |
1218 | ForallOp> { |
1219 | bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, |
1220 | const AnalysisState &state) const { |
1221 | // All tensor operands to `scf.forall` are `shared_outs` and all |
1222 | // shared outs are assumed to be read by the loop. This does not |
1223 | // account for the case where the entire value is over-written, |
1224 | // but being conservative here. |
1225 | return true; |
1226 | } |
1227 | |
1228 | bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, |
1229 | const AnalysisState &state) const { |
1230 | // Outputs of scf::ForallOps are always considered as a write. |
1231 | return true; |
1232 | } |
1233 | |
1234 | AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, |
1235 | const AnalysisState &state) const { |
1236 | auto forallOp = cast<ForallOp>(op); |
    return {
        {{forallOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}}};
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard guard(rewriter);
    auto forallOp = cast<ForallOp>(op);
    int64_t rank = forallOp.getRank();
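    // Note: the body's leading `rank` block arguments are the induction
    // variables; the remaining arguments correspond to the shared_outs and
    // are obtained via `drop_front(rank)` below.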

    // Get buffers for all output operands.
    SmallVector<Value> buffers;
    for (Value out : forallOp.getOutputs()) {
      FailureOr<Value> buffer = getBuffer(rewriter, out, options, state);
      if (failed(buffer))
        return failure();
      buffers.push_back(*buffer);
    }

    // Use buffers instead of block arguments.
    rewriter.setInsertionPointToStart(forallOp.getBody());
    for (const auto &it : llvm::zip(
             forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
      BlockArgument bbArg = std::get<0>(it);
      Value buffer = std::get<1>(it);
      Value bufferAsTensor = rewriter.create<ToTensorOp>(
          forallOp.getLoc(), bbArg.getType(), buffer);
      bbArg.replaceAllUsesWith(bufferAsTensor);
    }
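    // Illustrative sketch (assumed IR, not from this file): a body use such
    // as
    //   %t = tensor.extract_slice %shared_out_bbarg ...
    // now reads from a
    //   %tensor = bufferization.to_tensor %buffer
    // inserted at the top of the body, so tensor-typed body ops keep
    // verifying until they are bufferized themselves.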

    // Create new ForallOp without any results and drop the automatically
    // introduced terminator.
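    // (The bufferized loop needs no results: the shared_outs now refer to
    // buffers that are updated in place, so there is nothing to yield.)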
    rewriter.setInsertionPoint(forallOp);
    ForallOp newForallOp;
    newForallOp = rewriter.create<ForallOp>(
        forallOp.getLoc(), forallOp.getMixedLowerBound(),
        forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
        /*outputs=*/ValueRange(), forallOp.getMapping());

    // Keep discardable attributes from the original op.
    newForallOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());

    rewriter.eraseOp(newForallOp.getBody()->getTerminator());

    // Move over block contents of the old op.
    SmallVector<Value> replacementBbArgs;
    replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(),
                             newForallOp.getBody()->getArguments().end());
    replacementBbArgs.append(forallOp.getOutputs().size(), Value());
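    // The old shared_outs block arguments have no counterpart in the new op;
    // their uses were already rerouted to `to_tensor` values above, so null
    // placeholders are sufficient here.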
    rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
                         replacementBbArgs);

    // Remove the old op and replace all of its uses.
    replaceOpWithBufferizedValues(rewriter, op, buffers);

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto forallOp = cast<ForallOp>(op);

    if (auto bbArg = dyn_cast<BlockArgument>(value))
      // A tensor block argument has the same bufferized type as the
      // corresponding output operand.
      return bufferization::getBufferType(
          forallOp.getTiedOpOperand(bbArg)->get(), options, state,
          invocationStack);

    // The bufferized result type is the same as the bufferized type of the
    // corresponding output operand.
    return bufferization::getBufferType(
        forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
        state, invocationStack);
  }

  bool isRepetitiveRegion(Operation *op, unsigned index) const {
    auto forallOp = cast<ForallOp>(op);

    // The region is repetitive if the loop may execute more than one
    // iteration: if any lower bound, upper bound or step is dynamic, or if
    // the static values admit at least two iterations.
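    // Worked example (illustrative): lb = 0, ub = 8, step = 4 gives two
    // iterations (0 and 4) because 0 + 4 < 8, so the region is repetitive;
    // lb = 0, ub = 4, step = 4 gives a single iteration and does not make
    // the region repetitive on its own.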
    for (auto [lb, ub, step] :
         llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
                   forallOp.getMixedStep())) {
      std::optional<int64_t> lbConstant = getConstantIntValue(lb);
      if (!lbConstant)
        return true;

      std::optional<int64_t> ubConstant = getConstantIntValue(ub);
      if (!ubConstant)
        return true;

      std::optional<int64_t> stepConstant = getConstantIntValue(step);
      if (!stepConstant)
        return true;

      if (*lbConstant + *stepConstant < *ubConstant)
        return true;
    }
    return false;
  }

  bool isParallelRegion(Operation *op, unsigned index) const {
    return isRepetitiveRegion(op, index);
  }
};

/// Nothing to do for InParallelOp.
struct InParallelOpInterface
    : public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
                                                    InParallelOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
    return failure();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx);
    ForallOp::attachInterface<ForallOpInterface>(*ctx);
    InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}
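
// Illustrative usage (a sketch under assumptions, not part of the original
// file): clients typically register these models on a DialectRegistry before
// creating the context, e.g.:
//
//   DialectRegistry registry;
//   registry.insert<scf::SCFDialect>();
//   scf::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);
//
// After that, one-shot bufferization can query BufferizableOpInterface on
// scf ops such as scf.for, scf.while, and scf.forall.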