//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;

namespace mlir {
namespace tensor {
namespace {

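/// Bufferization of tensor.cast. Replace with memref.cast (or fold the op away
/// if the source buffer already has the requested type). Illustrative sketch
/// only; SSA names are made up and the layouts shown assume the default fully
/// dynamic layout map:
/// ```
///   %1 = tensor.cast %0 : tensor<?xf32> to tensor<4xf32>
/// ```
/// may bufferize to:
/// ```
///   %1 = memref.cast %src
///       : memref<?xf32, strided<[?], offset: ?>>
///         to memref<4xf32, strided<[?], offset: ?>>
/// ```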
struct CastOpInterface
    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                    tensor::CastOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getResult(0), BufferRelation::Equivalent}};
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto castOp = cast<tensor::CastOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        castOp.getSource(), options, state, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    Attribute memorySpace = maybeSrcBufferType->getMemorySpace();

    // Note: `getMemRefTypeWithFullyDynamicLayout` returns an unranked memref
    // type in case the input is an unranked tensor type.

    // Case 1: Casting an unranked tensor.
    if (isa<UnrankedTensorType>(castOp.getSource().getType())) {
      // When casting to a ranked tensor, we cannot infer any static offset or
      // strides from the source. Assume fully dynamic.
      return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
    }

    // Case 2: Casting to an unranked tensor type.
    if (isa<UnrankedTensorType>(castOp.getType())) {
      return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
    }

    // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
    // change.
    auto rankedResultType = cast<RankedTensorType>(castOp.getType());
    return MemRefType::get(
        rankedResultType.getShape(), rankedResultType.getElementType(),
        llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto castOp = cast<tensor::CastOp>(op);

    // The result buffer still has the old (pre-cast) type.
    FailureOr<Value> resultBuffer =
        getBuffer(rewriter, castOp.getSource(), options, state);
    if (failed(resultBuffer))
      return failure();

    // Compute the new type.
    auto resultMemRefType =
        bufferization::getBufferType(castOp.getResult(), options, state);
    if (failed(resultMemRefType))
      return failure();
    if (resultBuffer->getType() == *resultMemRefType) {
      // This cast is a no-op.
      replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
      return success();
    }

    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
                                             *resultMemRefType) &&
           "CastOp::bufferize: cast incompatible");
    replaceOpWithNewBufferizedOp<memref::CastOp>(
        rewriter, op, *resultMemRefType, *resultBuffer);

    return success();
  }
};

/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
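///
/// Illustrative sketch (names made up). If the source layout allows collapsing,
/// ```
///   %1 = tensor.collapse_shape %0 [[0, 1]]
///       : tensor<2x3xf32> into tensor<6xf32>
/// ```
/// bufferizes to a memref.collapse_shape on the source buffer; otherwise the
/// source is first copied into a fresh identity-layout allocation so that the
/// dims become collapsible.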
struct CollapseShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
                                                    tensor::CollapseShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // tensor.collapse_shape may reallocate, at which point the source buffer is
    // copied. I.e., there will be a memory read side effect on the bufferized
    // source. This function conservatively returns "true" because whether a
    // copy will be created or not is not known at this point.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    // TODO: CollapseShapeOp may allocate at runtime.
    return {{op->getOpResult(0), BufferRelation::Equivalent}};
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        collapseShapeOp.getSrc(), options, state, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        srcBufferType, collapseShapeOp.getReassociationIndices());

    if (!canBeCollapsed) {
      // If dims cannot be collapsed, this op bufferizes to a new allocation.
      RankedTensorType tensorResultType = collapseShapeOp.getResultType();
      return bufferization::getMemRefTypeWithStaticIdentityLayout(
          tensorResultType, srcBufferType.getMemorySpace());
    }

    return memref::CollapseShapeOp::computeCollapsedType(
        srcBufferType, collapseShapeOp.getReassociationIndices());
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
    FailureOr<Value> maybeBuffer =
        getBuffer(rewriter, collapseShapeOp.getSrc(), options, state);
    if (failed(maybeBuffer))
      return failure();
    Value buffer = *maybeBuffer;
    auto bufferType = cast<MemRefType>(buffer.getType());

    if (tensorResultType.getRank() == 0) {
      // 0-d collapses must go through a different op builder.
      MemRefType resultType;

      if (bufferType.getLayout().isIdentity()) {
        // Standard layout: result type has no offset.
        MemRefLayoutAttrInterface layout;
        resultType = MemRefType::get({}, tensorResultType.getElementType(),
                                     layout, bufferType.getMemorySpace());
      } else {
        // Source memref has a layout map: result type has the same offset as
        // the source type.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(bufferType.getStridesAndOffset(strides, offset)))
          return failure();
        resultType = MemRefType::get(
            {}, tensorResultType.getElementType(),
            StridedLayoutAttr::get(op->getContext(), offset, {}),
            bufferType.getMemorySpace());
      }

      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
          rewriter, op, resultType, buffer, collapseShapeOp.getReassociation());
      return success();
    }

    // If the dims are not collapsible (due to an incompatible source layout
    // map), force an out-of-place bufferization, i.e., a buffer copy. This
    // newly allocated buffer will have no layout map and thus be collapsible.
    bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
        bufferType, collapseShapeOp.getReassociationIndices());
    if (!canBeCollapsed) {
      // TODO: Create alloc_tensor ops during TensorCopyInsertion.
      AnalysisState analysisState(options);
      FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), collapseShapeOp.getSrc(), options, state);
      if (failed(tensorAlloc))
        return failure();
      auto memrefType =
          MemRefType::get(collapseShapeOp.getSrcType().getShape(),
                          collapseShapeOp.getSrcType().getElementType(),
                          AffineMap(), bufferType.getMemorySpace());
      buffer = rewriter.create<bufferization::ToBufferOp>(
          op->getLoc(), memrefType, *tensorAlloc);
    }

    // Result type is inferred by the builder.
    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
    return success();
  }
};

/// Bufferization of tensor.dim. Replace with memref.dim.
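///
/// Illustrative sketch (names made up):
/// ```
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
/// ```
/// bufferizes to:
/// ```
///   %d = memref.dim %t_buffer, %c0 : memref<?xf32, strided<[?], offset: ?>>
/// ```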
struct DimOpInterface
    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
                                                    tensor::DimOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // The op reads the tensor's metadata but not its contents.
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto dimOp = cast<tensor::DimOp>(op);
    FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options, state);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
                                                dimOp.getIndex());
    return success();
  }
};

/// Bufferization of "tensor.empty". Replace with "bufferization.alloc_tensor".
struct EmptyOpInterface
    : public BufferizableOpInterface::ExternalModel<EmptyOpInterface,
                                                    tensor::EmptyOp> {
  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  bool resultBufferizesToMemoryWrite(Operation *op, OpResult opResult,
                                     const AnalysisState &state) const {
    // The returned tensor does not have specified contents.
    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto emptyOp = cast<tensor::EmptyOp>(op);

    // Optimization: Fold away the op if it has no uses.
    if (op->getUses().empty()) {
      rewriter.eraseOp(op);
      return success();
    }

    // Allocate a tensor. This emits a "bufferization.alloc_tensor" op.
    FailureOr<Value> allocTensor = allocateTensorForShapedValue(
        rewriter, op->getLoc(), emptyOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(allocTensor))
      return failure();
    rewriter.replaceOp(op, *allocTensor);
    return success();
  }
};

/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
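///
/// Illustrative sketch (names made up; syntax abbreviated):
/// ```
///   %1 = tensor.expand_shape %0 [[0, 1]] output_shape [2, 3]
///       : tensor<6xf32> into tensor<2x3xf32>
/// ```
/// bufferizes to the corresponding memref.expand_shape on the source buffer;
/// in contrast to tensor.collapse_shape, no copy is ever needed.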
struct ExpandShapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                    tensor::ExpandShapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // In contrast to tensor.collapse_shape, this op can always be bufferized
    // without a copy.
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getOpResult(0), BufferRelation::Equivalent}};
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        expandShapeOp.getSrc(), options, state, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
    auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
        srcBufferType, expandShapeOp.getResultType().getShape(),
        expandShapeOp.getReassociationIndices());
    if (failed(maybeResultType))
      return failure();
    return *maybeResultType;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
    FailureOr<Value> buffer =
        getBuffer(rewriter, expandShapeOp.getSrc(), options, state);
    if (failed(buffer))
      return failure();

    auto memrefExpandShape = rewriter.create<memref::ExpandShapeOp>(
        op->getLoc(), tensorResultType.getShape(), *buffer,
        expandShapeOp.getReassociationIndices(),
        expandShapeOp.getMixedOutputShape());
    replaceOpWithBufferizedValues(rewriter, op,
                                  memrefExpandShape->getResults());
    return success();
  }
};

/// Bufferization of tensor.extract_slice. Replace with memref.subview.
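///
/// Illustrative sketch (names made up):
/// ```
///   %1 = tensor.extract_slice %0[0] [4] [1] : tensor<8xf32> to tensor<4xf32>
/// ```
/// bufferizes to a subview of the source buffer, e.g.:
/// ```
///   %1 = memref.subview %src[0] [4] [1]
///       : memref<8xf32> to memref<4xf32, strided<[1]>>
/// ```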
struct ExtractSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                    tensor::ExtractSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {{op->getOpResult(0), BufferRelation::Unknown}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    Location loc = extractSliceOp.getLoc();

    // Get source buffer.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractSliceOp.getSource(), options, state);
    if (failed(srcMemref))
      return failure();

    // Take a subview of the source buffer.
    auto resultMemrefType = bufferization::getBufferType(
        extractSliceOp.getResult(), options, state);
    if (failed(resultMemrefType))
      return failure();
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref,
        mixedOffsets, mixedSizes, mixedStrides);

    replaceOpWithBufferizedValues(rewriter, op, subView);
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    assert(value == extractSliceOp.getResult() && "invalid value");
    auto srcMemrefType = bufferization::getBufferType(
        extractSliceOp.getSource(), options, state, invocationStack);
    if (failed(srcMemrefType))
      return failure();
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    return memref::SubViewOp::inferRankReducedResultType(
        extractSliceOp.getType().getShape(),
        llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
        mixedStrides);
  }
};

/// Bufferization of tensor.extract. Replace with memref.load.
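///
/// Illustrative sketch (names made up):
/// ```
///   %e = tensor.extract %t[%i, %j] : tensor<?x?xf32>
/// ```
/// bufferizes to:
/// ```
///   %e = memref.load %t_buffer[%i, %j]
///       : memref<?x?xf32, strided<[?, ?], offset: ?>>
/// ```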
struct ExtractOpInterface
    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
                                                    tensor::ExtractOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, extractOp.getTensor(), options, state);
    if (failed(srcMemref))
      return failure();
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
                                                 extractOp.getIndices());
    return success();
  }
};

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
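//
// For example (illustrative), with shape = [2, 3] the recursion emits one
// memref.store per element in row-major order:
//   (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)
// advancing `elementIt` by one element per store.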
static void createStores(RewriterBase &rewriter, Location loc, int dim,
                         Value buffer, ArrayRef<int64_t> shape,
                         ArrayRef<Value> constants,
                         OperandRange::iterator &elementIt,
                         SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      rewriter.create<memref::StoreOp>(loc, *elementIt, buffer, indices);
      ++elementIt;
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    createStores(rewriter, loc, dim + 1, buffer, shape, constants, elementIt,
                 indices);
  }
}

/// Bufferization of tensor.from_elements.
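///
/// Illustrative sketch (names made up): bufferizes to a new allocation that is
/// filled with one memref.store per element, e.g.
/// ```
///   %t = tensor.from_elements %a, %b : tensor<2xf32>
/// ```
/// becomes, once the intermediate alloc_tensor is materialized, roughly:
/// ```
///   %alloc = memref.alloc() : memref<2xf32>
///   memref.store %a, %alloc[%c0] : memref<2xf32>
///   memref.store %b, %alloc[%c1] : memref<2xf32>
/// ```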
struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {

  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);
    auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());

    // Allocate a buffer for the result.
    Location loc = op->getLoc();
    auto shape = tensorType.getShape();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, fromElementsOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    FailureOr<BaseMemRefType> memrefType =
        bufferization::getBufferType(*tensorAlloc, options, state);
    if (failed(memrefType))
      return failure();
    Value buffer = rewriter.create<bufferization::ToBufferOp>(
        op->getLoc(), *memrefType, *tensorAlloc);

    // Case: tensor<0xelem_type>.
    if (fromElementsOp.getElements().empty()) {
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Case: tensor<elem_type>.
    if (shape.empty()) {
      rewriter.create<memref::StoreOp>(
          loc, fromElementsOp.getElements().front(), buffer);
      replaceOpWithBufferizedValues(rewriter, op, buffer);
      return success();
    }

    // Create constants for the range of possible indices [0, max{shape_i}).
    auto maxDim = *llvm::max_element(shape);
    SmallVector<Value, 2> constants;
    constants.reserve(maxDim);
    for (int i = 0; i < maxDim; ++i)
      constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

    // Traverse all `elements` and create `memref.store` ops.
    auto elementIt = fromElementsOp.getElements().begin();
    SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
    createStores(rewriter, loc, /*dim=*/0, buffer, shape, constants, elementIt,
                 indices);

    replaceOpWithBufferizedValues(rewriter, op, buffer);

    return success();
  }
};

/// Lower the body of a tensor.generate-like op (one index-typed bbArg per dim).
/// Such ops are lowered to linalg.map with the given tensor as a destination.
///
/// Example:
/// ```
/// %r = tensor.generate %x, %y {
///   ^bb0(%arg0: index, %arg1: index):
///   %0 = "some_op"(%arg0, %arg1) : (index, index) -> (index)
///   tensor.yield %0 : index
/// } : tensor<?x?xindex>
/// ```
///
/// Is lowered to:
/// ```
/// linalg.map ins() outs(%dest) {
///   %d0 = linalg.index 0 : index
///   %d1 = linalg.index 1 : index
///   %0 = "some_op"(%d0, %d1) : (index, index) -> (index)
///   linalg.yield %0 : index
/// }
/// ```
static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
                                     Value tensorDestination,
                                     ValueRange dynamicSizes,
                                     Region &generateBody) {
  assert(generateBody.hasOneBlock() && "expected body with single block");
  auto tensorType = cast<RankedTensorType>(tensorDestination.getType());
  assert(generateBody.getNumArguments() == tensorType.getRank() &&
         "rank mismatch");

  // Create linalg::MapOp.
  OpBuilder::InsertionGuard g(rewriter);
  auto linalgOp =
      rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
                                     /*init=*/tensorDestination);
  Block &linalgBody = linalgOp.getMapper().emplaceBlock();

  // Create linalg::IndexOps.
  rewriter.setInsertionPointToStart(&linalgBody);
  SmallVector<Value> indices;
  for (int64_t dim = 0; dim < tensorType.getRank(); ++dim)
    indices.push_back(rewriter.create<linalg::IndexOp>(loc, dim));

  // Move over body.
  rewriter.mergeBlocks(&generateBody.front(), &linalgBody, indices);
  auto yieldOp = cast<tensor::YieldOp>(linalgBody.getTerminator());
  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());

  return linalgOp.getResult()[0];
}

/// Bufferization of tensor.generate.
struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {

  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto generateOp = cast<tensor::GenerateOp>(op);

    auto type = generateOp.getResult().getType();

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpaceFn(type) != Attribute())
      return op->emitError("memory space not implemented yet");

    // Allocate memory.
    Location loc = op->getLoc();
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, generateOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();

    Value result = lowerGenerateLikeOpBody(rewriter, loc, *tensorAlloc,
                                           generateOp.getDynamicExtents(),
                                           generateOp.getBody());
    rewriter.replaceOp(generateOp, result);

    return success();
  }
};

/// Bufferization of tensor.insert. Replace with memref.store.
///
/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
/// implementations for DestinationStyle ops.
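///
/// Illustrative sketch (names made up):
/// ```
///   %r = tensor.insert %f into %t[%i] : tensor<?xf32>
/// ```
/// bufferizes to a store into the destination buffer:
/// ```
///   memref.store %f, %t_buffer[%i] : memref<?xf32, strided<[?], offset: ?>>
/// ```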
struct InsertOpInterface
    : public DstBufferizableOpInterfaceExternalModel<InsertOpInterface,
                                                     tensor::InsertOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto insertOp = cast<tensor::InsertOp>(op);
    FailureOr<Value> destMemref =
        getBuffer(rewriter, insertOp.getDest(), options, state);
    if (failed(destMemref))
      return failure();
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
                                     *destMemref, insertOp.getIndices());
    replaceOpWithBufferizedValues(rewriter, op, *destMemref);
    return success();
  }
};

template <typename InsertOpTy>
static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
                                      OpOperand &opOperand) {
  // The source is always read.
  if (opOperand == insertSliceOp.getSourceMutable())
    return true;

  // For the destination, it depends...
  assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");

  // Dest is not read if it is entirely overwritten. E.g.:
  // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
  bool allOffsetsZero =
      llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroInteger);
  RankedTensorType destType = insertSliceOp.getDestType();
  bool sizesMatchDestSizes =
      areConstantIntValues(insertSliceOp.getMixedSizes(), destType.getShape());
  bool allStridesOne =
      areAllConstantIntValue(insertSliceOp.getMixedStrides(), 1);
  return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
///
/// Note: DstBufferizableOpInterfaceExternalModel provides many default method
/// implementations for DestinationStyle ops.
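///
/// Illustrative sketch (names made up; identity layouts for brevity):
/// ```
///   %r = tensor.insert_slice %src into %dst[0, 4] [4, 4] [1, 1]
///       : tensor<4x4xf32> into tensor<8x8xf32>
/// ```
/// bufferizes to a subview of the destination buffer plus a copy, e.g.:
/// ```
///   %sv = memref.subview %dst_buffer[0, 4] [4, 4] [1, 1]
///       : memref<8x8xf32> to memref<4x4xf32, strided<[8, 1], offset: 4>>
///   memref.copy %src_buffer, %sv
///       : memref<4x4xf32> to memref<4x4xf32, strided<[8, 1], offset: 4>>
/// ```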
struct InsertSliceOpInterface
    : public DstBufferizableOpInterfaceExternalModel<InsertSliceOpInterface,
                                                     tensor::InsertSliceOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return insertSliceOpRequiresRead(cast<tensor::InsertSliceOp>(op),
                                     opOperand);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
    // catastrophically bad scheduling decision.
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    Location loc = insertSliceOp.getLoc();

    // Get destination buffer.
    FailureOr<Value> dstMemref =
        getBuffer(rewriter, insertSliceOp.getDest(), options, state);
    if (failed(dstMemref))
      return failure();

    // Take a subview of the destination buffer.
    auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
    MemRefType subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            insertSliceOp.getSourceType().getShape(), dstMemrefType,
            mixedOffsets, mixedSizes, mixedStrides);
    Value subView = rewriter.create<memref::SubViewOp>(
        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
        mixedStrides);

    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
    FailureOr<Value> srcMemref =
        getBuffer(rewriter, insertSliceOp.getSource(), options, state);
    if (failed(srcMemref))
      return failure();
    if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
      return failure();

    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
    return success();
  }
};

/// Bufferization of tensor.pad. Replace with bufferization.alloc_tensor +
/// linalg.map + insert_slice.
/// For best performance, vectorize before bufferization (better performance in
/// case of padding with a constant).
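///
/// Illustrative sketch (names made up):
/// ```
///   %r = tensor.pad %t low[1] high[2] {
///   ^bb0(%i: index):
///     tensor.yield %cst : f32
///   } : tensor<5xf32> to tensor<8xf32>
/// ```
/// is rewritten, roughly, into an alloc_tensor of the padded shape, a
/// linalg.map that fills it with the pad value, and a tensor.insert_slice that
/// writes %t into the filled tensor at offset 1.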
struct PadOpInterface
    : public BufferizableOpInterface::ExternalModel<PadOpInterface,
                                                    tensor::PadOp> {
  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    // Infer memory space from the source tensor.
    auto padOp = cast<tensor::PadOp>(op);
    auto maybeSrcBufferType = bufferization::getBufferType(
        padOp.getSource(), options, state, invocationStack);
    if (failed(maybeSrcBufferType))
      return failure();
    MemRefLayoutAttrInterface layout;
    return MemRefType::get(padOp.getResultType().getShape(),
                           padOp.getResultType().getElementType(), layout,
                           maybeSrcBufferType->getMemorySpace());
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto padOp = cast<tensor::PadOp>(op);
    Location loc = padOp.getLoc();
    RankedTensorType resultType = padOp.getResultType();
    RankedTensorType srcType = padOp.getSourceType();

    auto toValue = [&](OpFoldResult ofr) {
      if (auto value = dyn_cast<Value>(ofr))
        return value;
      return rewriter
          .create<arith::ConstantIndexOp>(loc, *getConstantIntValue(ofr))
          .getResult();
    };

    // Compute dynamic result dimensions.
    SmallVector<OpFoldResult> mixedLowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> mixedHighPad = padOp.getMixedHighPad();
    SmallVector<Value> dynamicSizes;
    for (int64_t i = 0; i < resultType.getRank(); ++i) {
      if (!resultType.isDynamicDim(i))
        continue;
      Value srcDim = rewriter.create<tensor::DimOp>(loc, padOp.getSource(), i);
      Value lowPad = toValue(mixedLowPad[i]);
      Value highPad = toValue(mixedHighPad[i]);
      AffineExpr s0, s1, s2;
      bindSymbols(op->getContext(), s0, s1, s2);
      AffineExpr sumExpr = s0 + s1 + s2;
      Value sum = rewriter.create<affine::AffineApplyOp>(
          loc, sumExpr, ValueRange{srcDim, lowPad, highPad});
      dynamicSizes.push_back(sum);
    }

    // Allocate a buffer for the padded result.
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, padOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();

    // tensor::PadOp is like tensor::GenerateOp: The only difference is that
    // only a part of the generated tensor is needed. For simplicity, we reuse
    // the same functionality here.
    Value filledBuffer = lowerGenerateLikeOpBody(
        rewriter, loc, *tensorAlloc, dynamicSizes, padOp.getBodyRegion());

    // Create tensor::InsertSliceOp.
    SmallVector<OpFoldResult> sliceSizes =
        getMixedSizes(rewriter, loc, padOp.getSource());
    SmallVector<OpFoldResult> sliceStrides(srcType.getRank(),
                                           rewriter.getIndexAttr(1));
    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        padOp, padOp.getSource(), filledBuffer,
        /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);

    return success();
  }
};

/// Bufferization of tensor.rank. Replace with memref.rank.
struct RankOpInterface
    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
                                                    tensor::RankOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // The op reads the tensor's metadata but not its contents.
    return false;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto rankOp = cast<tensor::RankOp>(op);
    FailureOr<Value> v =
        getBuffer(rewriter, rankOp.getTensor(), options, state);
    if (failed(v))
      return failure();
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 *v);
    return success();
  }
};

/// Bufferization of tensor.reshape. Replace with memref.reshape.
struct ReshapeOpInterface
    : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
                                                    tensor::ReshapeOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Depending on the layout map, the source buffer may have to be copied.
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    return opOperand == reshapeOp.getShapeMutable();
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    // Only the 'source' operand aliases the result.
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    if (reshapeOp.getSourceMutable() != opOperand)
      return {};
    return {{op->getOpResult(0), BufferRelation::Equivalent}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, reshapeOp.getSource(), options, state);
    FailureOr<Value> shapeBuffer =
        getBuffer(rewriter, reshapeOp.getShape(), options, state);
    if (failed(srcBuffer) || failed(shapeBuffer))
      return failure();
    auto maybeResultMemRefType =
        bufferization::getBufferType(reshapeOp.getResult(), options, state);
    if (failed(maybeResultMemRefType))
      return failure();

    // memref.reshape requires the source buffer to have an identity layout.
    // If the source memref does not have an identity layout, copy the source
    // into a new buffer with an identity layout.
    auto srcType = llvm::dyn_cast<MemRefType>(srcBuffer->getType());
    if (srcType && !srcType.getLayout().isIdentity()) {
      FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), reshapeOp.getSource(), options, state);
      if (failed(tensorAlloc))
        return failure();
      auto memrefType = MemRefType::get(
          srcType.getShape(), srcType.getElementType(), AffineMap(),
          cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace());
      srcBuffer = rewriter
                      .create<bufferization::ToBufferOp>(
                          op->getLoc(), memrefType, *tensorAlloc)
                      .getResult();
    }

    replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
        rewriter, op, maybeResultMemRefType.value(), *srcBuffer, *shapeBuffer);
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
    assert(value == reshapeOp.getResult() && "unexpected value provided");
    auto maybeSourceBufferType = bufferization::getBufferType(
        reshapeOp.getSource(), options, state, invocationStack);
    if (failed(maybeSourceBufferType))
      return failure();
    return getMemRefTypeWithStaticIdentityLayout(
        reshapeOp.getResult().getType(),
        cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace());
  }
};

/// Analysis of ParallelInsertSliceOp.
struct ParallelInsertSliceOpInterface
    : public BufferizableOpInterface::ExternalModel<
          ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return opOperand == cast<ParallelInsertSliceOp>(op).getSourceMutable();
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
    return opOperand == parallelInsertSliceOp.getDestMutable();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
    ParallelCombiningOpInterface parallelCombiningParent =
        parallelInsertSliceOp.getParallelCombiningParent();

    // Bufferize the op outside of the parallel combining terminator.
    rewriter.setInsertionPoint(parallelCombiningParent);

    // Get source and destination buffers.
    FailureOr<Value> destBuffer =
        getBuffer(rewriter, parallelInsertSliceOp.getDest(), options, state);
    if (failed(destBuffer))
      return failure();
    FailureOr<Value> srcBuffer =
        getBuffer(rewriter, parallelInsertSliceOp.getSource(), options, state);
    if (failed(srcBuffer))
      return failure();

    // Take a subview of the destination buffer.
    auto destBufferType = cast<MemRefType>(destBuffer->getType());
    MemRefType subviewMemRefType =
        memref::SubViewOp::inferRankReducedResultType(
            parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
            parallelInsertSliceOp.getMixedOffsets(),
            parallelInsertSliceOp.getMixedSizes(),
            parallelInsertSliceOp.getMixedStrides());
    Value subview = rewriter.create<memref::SubViewOp>(
        parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
        parallelInsertSliceOp.getMixedOffsets(),
        parallelInsertSliceOp.getMixedSizes(),
        parallelInsertSliceOp.getMixedStrides());

    // This memcpy will fold away if everything bufferizes in-place.
    if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
                                    *srcBuffer, subview)))
      return failure();

    // In case the source was allocated in the same block, make sure that the
    // deallocation op (if any) appears after the memcpy. By default, deallocs
    // are placed before the terminator, but this does not work for ForallOp
    // because the terminator does more than just yielding a value.
    //
    // Note: This is not a problem for the destination buffer because these are
    // assumed to always bufferize in-place.
    for (Operation *user : srcBuffer->getUsers()) {
      if (hasEffect<MemoryEffects::Free>(user)) {
        if (user->getBlock() == parallelCombiningParent->getBlock())
          rewriter.moveOpBefore(user, user->getBlock()->getTerminator());
        break;
      }
    }

    // Delete the op.
    rewriter.eraseOp(op);
    return success();
  }

  /// tensor.parallel_insert_slice op has implicit in-place behavior. We
  /// shouldn't create a copy to resolve conflicts.
  LogicalResult
  resolveConflicts(Operation *op, RewriterBase &rewriter,
                   const AnalysisState &analysisState,
                   const BufferizationState &bufferizationState) const {
    return success();
  }
};

/// Bufferization of tensor.splat. Bufferizes to a new allocation that is filled
/// with a linalg.map. Similar to tensor.generate.
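///
/// Illustrative sketch (names made up): `%t = tensor.splat %v : tensor<4xf32>`
/// becomes an alloc_tensor of the result shape plus a linalg.map with no inputs
/// whose body simply yields %v into every element.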
struct SplatOpInterface
    : public BufferizableOpInterface::ExternalModel<SplatOpInterface,
                                                    tensor::SplatOp> {

  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto splatOp = cast<tensor::SplatOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, splatOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();

    // Create linalg::MapOp.
    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
      return op->emitError("memory space not implemented yet");

    auto linalgOp =
        rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
                                       /*init=*/*tensorAlloc);
    Block &linalgBody = linalgOp.getMapper().emplaceBlock();

    // Fill the body of the linalg.map: yield the splat value for every element.
    rewriter.setInsertionPointToStart(&linalgBody);
    rewriter.create<linalg::YieldOp>(loc, splatOp.getInput());
    rewriter.replaceOp(splatOp, linalgOp.getResult()[0]);

    return success();
  }
};

/// Bufferization of tensor.concat. Bufferizes to a new allocation that is
/// filled with copy ops. Similar to tensor.from_elements, but using memref.copy
/// on subviews instead of memref.store.
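///
/// Illustrative sketch (names made up):
/// ```
///   %r = tensor.concat dim(0) %a, %b
///       : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
/// ```
/// bufferizes to an allocation of the result shape plus one subview + copy per
/// input, with the subview offsets along dim 0 accumulating (0, then 2).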
struct ConcatOpInterface
    : public BufferizableOpInterface::ExternalModel<ConcatOpInterface,
                                                    tensor::ConcatOp> {

  bool bufferizesToAllocation(Operation *op, Value value) const { return true; }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto concatOp = cast<tensor::ConcatOp>(op);

    // Allocate memory.
    Location loc = op->getLoc();
    FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, concatOp.getResult(), options, state,
        /*copy=*/false);
    if (failed(tensorAlloc))
      return failure();
    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
      return op->emitError("memory space not implemented yet");

    MemRefLayoutAttrInterface layout;
    MemRefType memrefType =
        MemRefType::get(concatOp.getResultType().getShape(),
                        concatOp.getResultType().getElementType(), layout);
    Value dstBuffer = rewriter.create<bufferization::ToBufferOp>(
        op->getLoc(), memrefType, *tensorAlloc);

    // Extract the dimension for the concat op
    uint64_t concatDim = concatOp.getDim();
    bool dynamicConcatDim = false;

    SmallVector<OpFoldResult> offsets(tensorType.getRank(),
                                      rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(tensorType.getRank(),
                                      rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> sizes;

    for (const auto &[dimIdx, dimSize] :
         llvm::enumerate(tensorType.getShape())) {
      if (dimSize == ShapedType::kDynamic) {
        auto dimOp = rewriter.create<memref::DimOp>(loc, dstBuffer, dimIdx);
        sizes.push_back(dimOp.getResult());
        if (dimIdx == concatDim)
          dynamicConcatDim = true;
      } else {
        sizes.push_back(rewriter.getIndexAttr(dimSize));
      }
    }

    int64_t concatDimOffset = 0;
    std::optional<Value> dynamicOffset;
    std::optional<Value> dynamicSize;
    if (dynamicConcatDim) {
      // One or more operands have dynamic size, so we must accumulate the
      // offset with arith ops.
      dynamicOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    }

    for (auto operand : concatOp.getInputs()) {
      // Get the buffer for the operand.
      FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options, state);
      if (failed(srcBuffer))
        return failure();

      // Each operand may have a different size along the concat dimension,
      // so the offset on that axis must accumulate through the loop, and the
      // size must change to the size of the current operand.
      auto operandTensorType = cast<RankedTensorType>(operand.getType());
      int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);

      if (dynamicConcatDim) {
        offsets[concatDim] = dynamicOffset.value();
        dynamicSize = rewriter.create<memref::DimOp>(loc, *srcBuffer, concatDim)
                          .getResult();
        sizes[concatDim] = dynamicSize.value();
      } else {
        sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
        offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
      }

      // Create a subview of the destination buffer.
      auto dstMemrefType = cast<MemRefType>(memrefType);
      MemRefType subviewMemRefType =
          memref::SubViewOp::inferRankReducedResultType(
              operandTensorType.getShape(), dstMemrefType, offsets, sizes,
              strides);
      Value subview = rewriter.create<memref::SubViewOp>(
          loc, subviewMemRefType, dstBuffer, offsets, sizes, strides);

      // Copy the source buffer into the destination subview.
      if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
        return failure();

      if (dynamicConcatDim) {
        dynamicOffset = rewriter.create<arith::AddIOp>(
            loc, dynamicOffset.value(), dynamicSize.value());
      } else {
        concatDimOffset += operandConcatDimSize;
      }
    }

    replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
    return success();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir

void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    CastOp::attachInterface<CastOpInterface>(*ctx);
    CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
    ConcatOp::attachInterface<ConcatOpInterface>(*ctx);
    DimOp::attachInterface<DimOpInterface>(*ctx);
    EmptyOp::attachInterface<EmptyOpInterface>(*ctx);
    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
    ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
    ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
    FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
    GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
    InsertOp::attachInterface<InsertOpInterface>(*ctx);
    InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
    PadOp::attachInterface<PadOpInterface>(*ctx);
    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
        *ctx);
    RankOp::attachInterface<RankOpInterface>(*ctx);
    ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
    SplatOp::attachInterface<SplatOpInterface>(*ctx);

    // Load additional dialects of which ops may get created.
    ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
  });

  // Bufferization requires SubsetInsertionOpInterface models. Make sure that
  // they are registered.
  tensor::registerSubsetOpInterfaceExternalModels(registry);
}