//===- ConvertToDestinationStyle.cpp - Convert non-DPS to DPS ops ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns to convert non-DPS ops to DPS ops. New
// tensor.empty ops are inserted as a destination. Such tensor.empty can be
// eliminated with "empty tensor elimination", allowing them to bufferize
// without an allocation (assuming there are no further conflicts).
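//
// For example (an illustrative sketch, not taken from an actual test case),
// a tensor.from_elements op such as
//   %t = tensor.from_elements %a, %b : tensor<2xf32>
// is rewritten into destination-passing style as
//   %empty = tensor.empty() : tensor<2xf32>
//   %0 = tensor.insert %a into %empty[%c0] : tensor<2xf32>
//   %t = tensor.insert %b into %0[%c1] : tensor<2xf32>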
//
//===----------------------------------------------------------------------===//
//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::tensor;

// Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements().
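// For example (illustrative), for a tensor<2x3xf32> the elements are inserted
// at indices (0,0), (0,1), (0,2), (1,0), (1,1), (1,2), in that order, yielding
// a chain of tensor.insert ops threaded through `destination`.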
static Value createInserts(RewriterBase &rewriter, Location loc, int dim,
                           Value destination, ArrayRef<int64_t> shape,
                           ArrayRef<Value> constants,
                           OperandRange::iterator &elementIt,
                           SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      destination = rewriter.create<tensor::InsertOp>(loc, *elementIt,
                                                      destination, indices);
      ++elementIt;
    }
    return destination;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    destination = createInserts(rewriter, loc, dim + 1, destination, shape,
                                constants, elementIt, indices);
  }
  return destination;
}

/// Create a memcpy from the given source tensor to the given destination
/// memref. The copy op type can be specified in the `options`.
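/// For example (a sketch), with MemcpyOp::MaterializeInDestination the copy
/// is emitted as
///   bufferization.materialize_in_destination %src in writable %dest
///       : (tensor<?xf32>, memref<?xf32>) -> ()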
static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource,
                         Value memrefDest,
                         const linalg::BufferizeToAllocationOptions &options) {
  auto tensorType = dyn_cast<RankedTensorType>(tensorSource.getType());
  assert(tensorType && "expected ranked tensor");
  assert(isa<MemRefType>(memrefDest.getType()) && "expected ranked memref");

  switch (options.memcpyOp) {
  case linalg::BufferizeToAllocationOptions::MemcpyOp::
      MaterializeInDestination: {
    // Note: This is the preferred way of memcpy'ing because no layout map
    // and/or memory space needs to be specified for the source.
    auto materializeOp = b.create<bufferization::MaterializeInDestinationOp>(
        loc, tensorSource, memrefDest);
    materializeOp.setWritable(true);
  } break;
  case linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy: {
    // TODO: Support custom memory space on source.
    // We do not know the layout map of the source yet, so use a fully dynamic
    // layout for best compatibility.
    Value toMemref = b.create<bufferization::ToMemrefOp>(
        loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType),
        tensorSource, /*readOnly=*/true);
    b.create<memref::CopyOp>(loc, toMemref, memrefDest);
  } break;
  case linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy: {
    // TODO: Support custom memory space on source.
    // We do not know the layout map of the source yet, so use a fully dynamic
    // layout for best compatibility.
    Value toMemref = b.create<bufferization::ToMemrefOp>(
        loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType),
        tensorSource, /*readOnly=*/true);
    b.create<linalg::CopyOp>(loc, toMemref, memrefDest);
  } break;
  }
}

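/// Rewrite the body of `padOp` as an op that writes the padding value into
/// `dest`: a linalg.fill if the yielded value is a constant or invariant, a
/// linalg.generic otherwise. For example (illustrative), padding with a
/// constant %cst lowers to
///   %0 = linalg.fill ins(%cst : f32) outs(%dest : tensor<10x10xf32>)
///       -> tensor<10x10xf32>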
static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
                                               Location loc, PadOp padOp,
                                               Value dest) {
  OpBuilder::InsertionGuard g(rewriter);
  RankedTensorType resultType = padOp.getResultType();

  // Examine the yielded value to decide if a linalg.generic is needed or a
  // linalg.fill is sufficient.
  Value yieldedValue =
      cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
  Attribute constYieldedValue;
  // Is the yielded value a bbArg defined outside of the PadOp?
  bool outsideBbArg =
      isa<BlockArgument>(yieldedValue) &&
      cast<BlockArgument>(yieldedValue).getOwner()->getParentOp() !=
          padOp.getOperation();
  // Is the yielded value an OpResult defined outside of the PadOp?
  bool outsideOpResult =
      isa<OpResult>(yieldedValue) &&
      yieldedValue.getDefiningOp()->getParentOp() != padOp.getOperation();
  bool invariantYieldedValue = outsideBbArg || outsideOpResult;
  if (matchPattern(yieldedValue, m_Constant(&constYieldedValue))) {
    // Padding with a constant: Create linalg.fill.
    Dialect *arithDialect =
        rewriter.getContext()->getLoadedDialect<arith::ArithDialect>();
    Value fillValue =
        arithDialect
            ->materializeConstant(rewriter, constYieldedValue,
                                  yieldedValue.getType(), yieldedValue.getLoc())
            ->getResult(0);
    auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange(fillValue),
                                                  ValueRange(dest));
    return fillOp;
  }

  if (invariantYieldedValue) {
    // Padding with an invariant value.
    auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange(yieldedValue),
                                                  ValueRange(dest));
    return fillOp;
  }

  // Create linalg.generic.
  SmallVector<utils::IteratorType> iteratorTypes(resultType.getRank(),
                                                 utils::IteratorType::parallel);
  SmallVector<AffineMap> indexingMaps(
      1, rewriter.getMultiDimIdentityMap(resultType.getRank()));
  auto genericOp = rewriter.create<linalg::GenericOp>(
      loc, resultType, /*inputs=*/ValueRange(),
      /*outputs=*/ValueRange{dest}, /*indexingMaps=*/
      indexingMaps, iteratorTypes);
  Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
                                     resultType.getElementType(), loc);
  rewriter.setInsertionPointToStart(body);
  SmallVector<Value> bbArgReplacements;
  for (int64_t i = 0; i < resultType.getRank(); ++i)
    bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
  rewriter.mergeBlocks(padOp.getBody(), body, bbArgReplacements);

  // Update terminator.
  auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
  return genericOp;
}

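/// Return the dynamic extents of the given tensor value, preferring reified
/// result shapes over freshly created tensor.dim ops. For example
/// (illustrative), for a value of type tensor<?x10x?xf32>, two index values
/// are returned: one for dim 0 and one for dim 2.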
static SmallVector<Value> reifyOrComputeDynamicSizes(OpBuilder &b,
                                                     Value value) {
  auto tensorType = cast<RankedTensorType>(value.getType());
  if (tensorType.hasStaticShape())
    return {};

  // Try to reify dynamic sizes.
  ReifiedRankedShapedTypeDims reifiedShape;
  if (isa<OpResult>(value) &&
      succeeded(reifyResultShapes(b, value.getDefiningOp(), reifiedShape))) {
    SmallVector<Value> dynSizes;
    for (int64_t i = 0; i < tensorType.getRank(); ++i) {
      if (tensorType.isDynamicDim(i))
        dynSizes.push_back(
            reifiedShape[cast<OpResult>(value).getResultNumber()][i]
                .get<Value>());
    }
    return dynSizes;
  }

  // Create tensor.dim ops.
  SmallVector<Value> dynSizes;
  for (int64_t i = 0; i < tensorType.getRank(); ++i) {
    if (tensorType.isDynamicDim(i))
      dynSizes.push_back(
          b.create<DimOp>(value.getLoc(), value,
                          b.create<arith::ConstantIndexOp>(value.getLoc(), i)));
  }
  return dynSizes;
}

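/// Allocate a buffer (memref.alloc or memref.alloca, depending on `options`)
/// whose type matches the given tensor value, with a static identity layout.
/// E.g. (illustrative), for a value of type tensor<10x?xf32> this creates
///   %alloc = memref.alloc(%sz) : memref<10x?xf32>
/// where %sz is the reified or computed extent of the dynamic dimension.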
static Value
createAllocationForTensor(RewriterBase &rewriter, Location loc, Value value,
                          const linalg::BufferizeToAllocationOptions &options,
                          Attribute memorySpace = {}) {
  OpBuilder::InsertionGuard g(rewriter);
  auto tensorType = cast<RankedTensorType>(value.getType());

  // Create buffer allocation.
  auto memrefType =
      cast<MemRefType>(bufferization::getMemRefTypeWithStaticIdentityLayout(
          tensorType, memorySpace));
  SmallVector<Value> dynamicSizes = reifyOrComputeDynamicSizes(rewriter, value);

  Value alloc;
  if (options.allocOp ==
      linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc) {
    alloc = rewriter.create<memref::AllocOp>(loc, memrefType, dynamicSizes);
    if (options.emitDealloc) {
      // Place deallocation at the end of the block.
      rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
      rewriter.create<memref::DeallocOp>(loc, alloc);
    }
  } else if (options.allocOp ==
             linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca) {
    alloc = rewriter.create<memref::AllocaOp>(loc, memrefType, dynamicSizes);
    // No dealloc is needed.
  }

  return alloc;
}

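/// Materialize tensor.pad into a new allocation: the padding value is written
/// with linalg.fill (or linalg.generic), the source is copied into a subview
/// of the buffer, and the result is exposed again as a tensor. Sketch of the
/// resulting IR (illustrative; names are hypothetical):
///   %alloc = memref.alloc(...) : memref<10x?xf32>
///   linalg.fill ins(%cst : f32) outs(%alloc : memref<10x?xf32>)
///   %subview = memref.subview %alloc [%low0, %low1] [...] [1, 1]
///   bufferization.materialize_in_destination %src in writable %subview
///   %t = bufferization.to_tensor %alloc restrict writable : memref<10x?xf32>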
Value linalg::bufferizeToAllocation(
    RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options,
    PadOp padOp, Attribute memorySpace, Operation *insertionPoint) {
  // tensor.pad does not have a destination operand.
  assert(!options.bufferizeDestinationOnly && "invalid options");

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(insertionPoint ? insertionPoint : padOp);
  Location loc = padOp.getLoc();

  // Create buffer allocation.
  Value alloc = createAllocationForTensor(rewriter, loc, padOp.getResult(),
                                          options, memorySpace);
  rewriter.setInsertionPoint(padOp);

  if (!padOp.hasZeroLowPad() || !padOp.hasZeroHighPad()) {
    // Create linalg.fill or linalg.generic. Not needed if there is no padding.
    Operation *fillOp =
        movePaddingToFillOrGenericOp(rewriter, loc, padOp, alloc);
    rewriter.setInsertionPointAfter(fillOp);
  }

  // Create memcpy.
  SmallVector<OpFoldResult> sizes =
      getMixedSizes(rewriter, loc, padOp.getSource());
  SmallVector<OpFoldResult> strides(padOp.getResultType().getRank(),
                                    rewriter.getIndexAttr(1));
  Value subview = rewriter.create<memref::SubViewOp>(
      loc, alloc, /*offsets=*/padOp.getMixedLowPad(), sizes, strides);
  createMemcpy(rewriter, loc, padOp.getSource(), subview, options);

  // Create bufferization.to_tensor with "restrict" and "writable". The
  // returned tensor is a new buffer allocation, so it does not alias with any
  // buffer.
  Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
      loc, alloc, /*restrict=*/true, /*writable=*/true);
  rewriter.replaceOp(padOp, toTensorOp);
  return alloc;
}

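/// Bufferize the maskable op inside a vector.mask to a new allocation, then
/// bufferize the vector.mask op and its terminator themselves. E.g.
/// (illustrative sketch), for
///   %0 = vector.mask %m { vector.transfer_write %v, %t[%c0] ... } ...
/// the destination %t is replaced by a to_tensor of a fresh allocation before
/// the mask op is bufferized.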
Value linalg::bufferizeToAllocation(
    RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options,
    vector::MaskOp maskOp, Attribute memorySpace, Operation *insertionPoint) {
  assert(llvm::range_size(maskOp.getMaskBlock()->without_terminator()) == 1 &&
         "expected single masked op");
  OpBuilder::InsertionGuard g(rewriter);
  bufferization::BufferizationOptions bufferizationOptions;
  Operation *yieldOp = maskOp.getMaskRegion().front().getTerminator();
  assert(isa<vector::YieldOp>(yieldOp) && "expected yield op terminator");

  // Bufferize maskable op. By default, place the buffer allocation right
  // before the mask op.
  Value alloc = bufferizeToAllocation(
      rewriter, options, maskOp.getMaskableOp(), memorySpace,
      /*insertionPoint=*/insertionPoint ? insertionPoint : maskOp);

  if (options.bufferizeDestinationOnly)
    return alloc;

  // Bufferize terminator.
  rewriter.setInsertionPoint(yieldOp);
  if (failed(cast<bufferization::BufferizableOpInterface>(yieldOp).bufferize(
          rewriter, bufferizationOptions)))
    return nullptr;

  // Erase dead to_tensor ops inside of the mask op. This is necessary because
  // there may only be one op (apart from the terminator) inside the mask op.
  // TODO: Remove dead to_tensor ops more aggressively during bufferization.
  SmallVector<Operation *> toTensorOps;
  maskOp.walk([&](bufferization::ToTensorOp toTensorOp) {
    if (toTensorOp->getUses().empty())
      toTensorOps.push_back(toTensorOp.getOperation());
  });
  for (Operation *op : toTensorOps)
    rewriter.eraseOp(op);

  // Bufferize mask op.
  SmallVector<OpOperand *> resultUses;
  for (Value result : maskOp.getResults())
    if (isa<TensorType>(result.getType()))
      for (OpOperand &use : result.getUses())
        resultUses.push_back(&use);
  rewriter.setInsertionPoint(maskOp);
  if (failed(cast<bufferization::BufferizableOpInterface>(maskOp.getOperation())
                 .bufferize(rewriter, bufferizationOptions)))
    return nullptr;

  // Set "restrict" attribute, indicating that no other tensor aliases with
  // this tensor. That is because we just allocated a new buffer for the
  // tensor.
  for (OpOperand *resultUse : resultUses) {
    auto toTensorOp =
        resultUse->get().getDefiningOp<bufferization::ToTensorOp>();
    assert(toTensorOp && "expected to_tensor op");
    rewriter.modifyOpInPlace(toTensorOp, [&]() {
      toTensorOp.setRestrict(true);
      toTensorOp.setWritable(true);
    });
  }

  return alloc;
}

Value linalg::bufferizeToAllocation(
    RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options,
    bufferization::AllocTensorOp allocTensorOp, Attribute memorySpace,
    Operation *insertionPoint) {
  Location loc = allocTensorOp.getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(insertionPoint ? insertionPoint : allocTensorOp);
  bufferization::BufferizationOptions bufferizationOptions;

  // Create buffer allocation.
  Value alloc = createAllocationForTensor(
      rewriter, loc, allocTensorOp.getResult(), options, memorySpace);

  // Create bufferization.to_tensor with "restrict" and "writable". The
  // returned tensor is a new buffer allocation, so it does not alias with any
  // buffer.
  Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
      loc, alloc, /*restrict=*/true, /*writable=*/true);
  rewriter.replaceOp(allocTensorOp, toTensorOp);
  return alloc;
}

/// Lower tensor.from_elements to a sequence of chained tensor.insert.
FailureOr<Operation *> mlir::linalg::rewriteInDestinationPassingStyle(
    RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) {
  Location loc = fromElementsOp.getLoc();
  RankedTensorType tensorType =
      cast<RankedTensorType>(fromElementsOp.getType());
  auto shape = tensorType.getShape();

  // Create tensor.empty.
  auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType, ValueRange());

  // Case: tensor<elem_type>.
  if (shape.empty()) {
    Operation *res = rewriter.replaceOpWithNewOp<tensor::InsertOp>(
        fromElementsOp, fromElementsOp.getElements().front(),
        emptyOp.getResult(), ValueRange());
    return res;
  }

  // Create constants for the range of possible indices [0, max{shape_i}).
  auto maxDim = *llvm::max_element(shape);
  SmallVector<Value, 2> constants;
  constants.reserve(maxDim);
  for (int i = 0; i < maxDim; ++i)
    constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

  // Traverse all elements and create tensor.insert ops.
  auto elementIt = fromElementsOp.getElements().begin();
  SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
  Value result = createInserts(rewriter, loc, /*dim=*/0, emptyOp.getResult(),
                               shape, constants, elementIt, indices);

  // Replace tensor.from_elements.
  rewriter.replaceOp(fromElementsOp, result);
  return result.getDefiningOp();
}

/// Lower tensor.generate to linalg.generic.
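/// E.g. (an illustrative sketch):
///   %0 = tensor.generate %sz {
///   ^bb0(%i: index, %j: index):
///     ...
///     tensor.yield %elem : f32
///   } : tensor<16x?xf32>
/// becomes a linalg.generic on a tensor.empty destination, with the block
/// arguments replaced by linalg.index ops.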
FailureOr<Operation *>
mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
                                               tensor::GenerateOp generateOp) {
  // Only ops with exactly one block are supported.
  if (!generateOp.getBody().hasOneBlock())
    return failure();

  Location loc = generateOp.getLoc();
  RankedTensorType tensorType = cast<RankedTensorType>(generateOp.getType());

  // Create tensor.empty.
  auto emptyOp =
      rewriter.create<EmptyOp>(loc, tensorType, generateOp.getDynamicExtents());

  // Create linalg.generic.
  SmallVector<utils::IteratorType> iteratorTypes(tensorType.getRank(),
                                                 utils::IteratorType::parallel);
  SmallVector<AffineMap> indexingMaps(
      1, rewriter.getMultiDimIdentityMap(tensorType.getRank()));
  auto genericOp = rewriter.create<linalg::GenericOp>(
      loc, tensorType, /*inputs=*/ValueRange(),
      /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/
      indexingMaps, iteratorTypes);
  Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
                                     tensorType.getElementType(), loc);
  rewriter.setInsertionPointToStart(body);
  SmallVector<Value> bbArgReplacements;
  for (int64_t i = 0; i < tensorType.getRank(); ++i)
    bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
  rewriter.mergeBlocks(&generateOp.getBody().front(), body, bbArgReplacements);

  // Update terminator.
  auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());

  // Replace tensor.generate.
  rewriter.replaceOp(generateOp, genericOp->getResult(0));
  return genericOp.getOperation();
}

/// Lower tensor.pad to linalg.generic + tensor.insert_slice.
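/// E.g. (an illustrative sketch):
///   %0 = tensor.pad %src low[%l0, %l1] high[%h0, %h1] {
///     ...
///     tensor.yield %cst : f32
///   } : tensor<?x?xf32> to tensor<?x?xf32>
/// becomes a linalg.fill (or linalg.generic) on a tensor.empty destination,
/// followed by a tensor.insert_slice of %src at offsets [%l0, %l1].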
FailureOr<Operation *>
mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
                                               tensor::PadOp padOp) {
  // Only ops with exactly one block are supported.
  if (!padOp.getBodyRegion().hasOneBlock())
    return failure();

  // Create tensor.empty.
  Location loc = padOp.getLoc();
  RankedTensorType resultType = padOp.getResultType();
  ReifiedRankedShapedTypeDims reifiedShape;
  if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
    return rewriter.notifyMatchFailure(
        padOp, "failed to reify tensor.pad op result shape");
  SmallVector<Value> dynamicSizes;
  for (int64_t i = 0; i < resultType.getRank(); ++i)
    if (resultType.isDynamicDim(i))
      dynamicSizes.push_back(reifiedShape[0][i].get<Value>());

  // If the `padOp` has a nofold attribute and all paddings are known to be 0,
  // explicitly insert a `linalg.copy`.
  if (padOp.getNofoldAttr() &&
      llvm::all_of(padOp.getMixedLowPad(), isZeroIndex) &&
      llvm::all_of(padOp.getMixedHighPad(), isZeroIndex)) {
    using bufferization::AllocTensorOp;
    Value allocated =
        rewriter.create<AllocTensorOp>(loc, resultType, dynamicSizes);
    auto copyOp = rewriter.replaceOpWithNewOp<linalg::CopyOp>(
        padOp, padOp.getSource(), allocated);
    return copyOp.getOperation();
  }

  Value empty = rewriter.create<EmptyOp>(loc, resultType, dynamicSizes);
  // Create linalg.fill or linalg.generic.
  Operation *fillOp = movePaddingToFillOrGenericOp(rewriter, loc, padOp, empty);
  rewriter.setInsertionPointAfter(fillOp);

  // Create tensor::InsertSliceOp.
  SmallVector<OpFoldResult> sliceSizes =
      getMixedSizes(rewriter, loc, padOp.getSource());
  SmallVector<OpFoldResult> sliceStrides(resultType.getRank(),
                                         rewriter.getIndexAttr(1));
  auto insertSliceOp = rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.getSource(), fillOp->getResult(0),
      /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
  return insertSliceOp.getOperation();
}

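/// Generic entry point: dispatch to the specialized overloads above for
/// tensor.pad, vector.mask and bufferization.alloc_tensor; otherwise,
/// bufferize the single out-of-place operand of a bufferizable op by
/// replacing it with a bufferization.to_tensor of a fresh buffer (initialized
/// by a memcpy if needed) and then bufferizing the op itself in place.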
Value linalg::bufferizeToAllocation(
    RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options,
    Operation *op, Attribute memorySpace, Operation *insertionPoint) {
  using namespace bufferization;

  // Call specialized overload for certain ops.
  if (auto padOp = dyn_cast<tensor::PadOp>(op))
    return bufferizeToAllocation(rewriter, options, padOp, memorySpace);
  if (auto maskOp = dyn_cast<vector::MaskOp>(op))
    return bufferizeToAllocation(rewriter, options, maskOp, memorySpace);
  if (auto allocTensorOp = dyn_cast<bufferization::AllocTensorOp>(op))
    return bufferizeToAllocation(rewriter, options, allocTensorOp, memorySpace);

  // Only bufferizable ops are supported.
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  BufferizationOptions bufferizationOptions;
  AnalysisState state(bufferizationOptions);

#ifndef NDEBUG
  if (!options.bufferizeDestinationOnly) {
    // Ops with nested tensor ops are not supported yet. At the moment, this
    // function just bufferizes the given op itself, but not its body.
    op->walk([&](Operation *nestedOp) {
      if (op == nestedOp)
        return;
      if (llvm::any_of(nestedOp->getOperands(),
                       [](Value v) { return isa<TensorType>(v.getType()); }))
        llvm_unreachable("ops with nested tensor ops are not supported yet");
      if (llvm::any_of(nestedOp->getResults(),
                       [](Value v) { return isa<TensorType>(v.getType()); }))
        llvm_unreachable("ops with nested tensor ops are not supported yet");
    });
  }
#endif // NDEBUG

  // Gather tensor results.
  SmallVector<OpResult> tensorResults;
  for (OpResult result : op->getResults()) {
    if (!isa<TensorType>(result.getType()))
      continue;
    // Unranked tensors are not supported.
    if (!isa<RankedTensorType>(result.getType()))
      return nullptr;
    // Ops that bufferize to an allocation are not supported.
    if (bufferizableOp.bufferizesToAllocation(result))
      return nullptr;
    tensorResults.push_back(result);
  }

  // Gather all operands that should bufferize to a new allocation. I.e.,
  // bufferize out-of-place.
  SmallVector<OpOperand *> outOfPlaceOperands, resultUses;
  auto addOutOfPlaceOperand = [&](OpOperand *operand) {
    if (!llvm::is_contained(outOfPlaceOperands, operand))
      outOfPlaceOperands.push_back(operand);
  };
  for (OpResult result : tensorResults) {
    AliasingOpOperandList aliasingOperands =
        state.getAliasingOpOperands(result);
    for (const AliasingOpOperand &operand : aliasingOperands) {
      addOutOfPlaceOperand(operand.opOperand);
      for (OpOperand &resultUse : result.getUses())
        resultUses.push_back(&resultUse);
    }
  }
  for (OpOperand &operand : op->getOpOperands()) {
    if (!state.bufferizesToMemoryWrite(operand))
      continue;
    if (!isa<RankedTensorType>(operand.get().getType()))
      continue;
    addOutOfPlaceOperand(&operand);
  }
  // TODO: Support multiple buffers.
  if (outOfPlaceOperands.size() != 1)
    return nullptr;

  // Allocate buffers.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(insertionPoint ? insertionPoint : op);
  SmallVector<Value> allocs;
  for (OpOperand *operand : outOfPlaceOperands) {
    Value alloc = createAllocationForTensor(
        rewriter, op->getLoc(), operand->get(), options, memorySpace);
    allocs.push_back(alloc);
    if (!state.findDefinitions(operand->get()).empty()) {
      // Initialize buffer with a copy of the operand data. Not needed if the
      // tensor is uninitialized.
      createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options);
    }
    rewriter.modifyOpInPlace(op, [&]() {
      auto toTensorOp = rewriter.create<ToTensorOp>(op->getLoc(), alloc);
      operand->set(toTensorOp);
      if (options.bufferizeDestinationOnly) {
        rewriter.modifyOpInPlace(toTensorOp, [&]() {
          toTensorOp.setRestrict(true);
          toTensorOp.setWritable(true);
        });
      }
    });
  }

  if (options.bufferizeDestinationOnly)
    return allocs.front();

  // Bufferize the op.
  rewriter.setInsertionPoint(op);
  if (failed(bufferizableOp.bufferize(rewriter, bufferizationOptions)))
    return nullptr;

  // Set "restrict" attribute, indicating that no other tensor aliases with
  // this tensor. That is because we just allocated a new buffer for the
  // tensor.
  for (OpOperand *resultUse : resultUses) {
    auto toTensorOp = resultUse->get().getDefiningOp<ToTensorOp>();
    assert(toTensorOp && "expected to_tensor op");
    rewriter.modifyOpInPlace(toTensorOp, [&]() {
      toTensorOp.setRestrict(true);
      toTensorOp.setWritable(true);
    });
  }
  return allocs.front();
}

namespace {

template <typename OpTy>
LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
                                                 PatternRewriter &rewriter) {
  return linalg::rewriteInDestinationPassingStyle(rewriter, op);
}

} // namespace

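// Typical usage (a hypothetical sketch, not part of this file): clients
// collect these patterns and run them with the greedy driver, e.g.
//   RewritePatternSet patterns(ctx);
//   linalg::populateConvertToDestinationStylePatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
//     signalPassFailure();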
void linalg::populateConvertToDestinationStylePatterns(
    RewritePatternSet &patterns) {
  patterns.add(rewriteOpInDestinationPassingStyle<tensor::FromElementsOp>);
  patterns.add(rewriteOpInDestinationPassingStyle<tensor::GenerateOp>);
  patterns.add(rewriteOpInDestinationPassingStyle<tensor::PadOp>);
}