1//===- ConvertToDestinationStyle.cpp - Convert non-DPS to DPS ops ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains patterns to convert non-DPS ops to DPS ops. New
10// tensor.empty ops are inserted as a destination. Such tensor.empty can be
11// eliminated with "empty tensor elimination", allowing them to bufferize
12// without an allocation (assuming there are no further conflicts).
13//
14//===----------------------------------------------------------------------===//
15//
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
18#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
19#include "mlir/Dialect/Linalg/IR/Linalg.h"
20#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
21#include "mlir/Dialect/Tensor/IR/Tensor.h"
22#include "mlir/Dialect/Utils/StaticValueUtils.h"
23#include "mlir/IR/Matchers.h"
24#include "mlir/IR/PatternMatch.h"
25#include "llvm/ADT/STLExtras.h"
26
27using namespace mlir;
28using namespace mlir::tensor;
29
30// Implements backtracking to traverse indices of the output buffer while
31// iterating over op.elements().
32static Value createInserts(RewriterBase &rewriter, Location loc, int dim,
33 Value destination, ArrayRef<int64_t> shape,
34 ArrayRef<Value> constants,
35 OperandRange::iterator &elementIt,
36 SmallVectorImpl<Value> &indices) {
37 if (dim == static_cast<int>(shape.size()) - 1) {
38 for (int i = 0; i < shape.back(); ++i) {
39 indices.back() = constants[i];
40 destination = rewriter.create<tensor::InsertOp>(location: loc, args: *elementIt,
41 args&: destination, args&: indices);
42 ++elementIt;
43 }
44 return destination;
45 }
46 for (int i = 0; i < shape[dim]; ++i) {
47 indices[dim] = constants[i];
48 destination = createInserts(rewriter, loc, dim: dim + 1, destination, shape,
49 constants, elementIt, indices);
50 }
51 return destination;
52}
53
54/// Create a memcpy from the given source tensor to the given destination
55/// memref. The copy op type can be specified in the `options`.
56static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource,
57 Value memrefDest,
58 const linalg::BufferizeToAllocationOptions &options) {
59 auto tensorType = dyn_cast<RankedTensorType>(Val: tensorSource.getType());
60 assert(tensorType && "expected ranked tensor");
61 assert(isa<MemRefType>(memrefDest.getType()) && "expected ranked memref");
62
63 switch (options.memcpyOp) {
64 case linalg::BufferizeToAllocationOptions::MemcpyOp::
65 MaterializeInDestination: {
66 // Note: This is the preferred way of memcpy'ing because no layout map
67 // and/or memory space must be specified for the source.
68 auto materializeOp = b.create<bufferization::MaterializeInDestinationOp>(
69 location: loc, args&: tensorSource, args&: memrefDest);
70 materializeOp.setWritable(true);
71 } break;
72 case linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy: {
73 // TODO: Support custom memory space on source.
74 // We do not know the layout map of the source yet, so use a fully dynamic
75 // layout for best compatibility.
76 Value toBuffer = b.create<bufferization::ToBufferOp>(
77 location: loc, args: bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType),
78 args&: tensorSource, /*readOnly=*/args: true);
79 b.create<memref::CopyOp>(location: loc, args&: toBuffer, args&: memrefDest);
80 } break;
81 case linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy: {
82 // TODO: Support custom memory space on source.
83 // We do not know the layout map of the source yet, so use a fully dynamic
84 // layout for best compatibility.
85 Value toBuffer = b.create<bufferization::ToBufferOp>(
86 location: loc, args: bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType),
87 args&: tensorSource, /*readOnly=*/args: true);
88 b.create<linalg::CopyOp>(location: loc, args&: toBuffer, args&: memrefDest);
89 } break;
90 };
91}
92
93static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
94 Location loc, PadOp padOp,
95 Value dest) {
96 OpBuilder::InsertionGuard g(rewriter);
97 RankedTensorType resultType = padOp.getResultType();
98
99 // Examine the yielded value to decide if a linalg.generic is neede or a
100 // linalg.fill is sufficient.
101 Value yieldedValue =
102 cast<tensor::YieldOp>(Val: padOp.getBody()->getTerminator()).getValue();
103 Attribute constYieldedValue;
104 // Is the yielded value a bbArg defined outside of the PadOp?
105 bool outsideBbArg =
106 isa<BlockArgument>(Val: yieldedValue) &&
107 cast<BlockArgument>(Val&: yieldedValue).getOwner()->getParentOp() !=
108 padOp.getOperation();
109 // Is the yielded value an OpResult defined outside of the PadOp?
110 bool outsideOpResult =
111 isa<OpResult>(Val: yieldedValue) &&
112 yieldedValue.getDefiningOp()->getParentOp() != padOp.getOperation();
113 bool invariantYieldedValue = outsideBbArg || outsideOpResult;
114 if (matchPattern(value: yieldedValue, pattern: m_Constant(bind_value: &constYieldedValue))) {
115 // Padding with a constant: Create linalg.fill.
116 Dialect *arithDialect =
117 rewriter.getContext()->getLoadedDialect<arith::ArithDialect>();
118 Value fillValue =
119 arithDialect
120 ->materializeConstant(builder&: rewriter, value: constYieldedValue,
121 type: yieldedValue.getType(), loc: yieldedValue.getLoc())
122 ->getResult(idx: 0);
123 auto fillOp = rewriter.create<linalg::FillOp>(location: loc, args: ValueRange(fillValue),
124 args: ValueRange(dest));
125 return fillOp;
126 }
127
128 if (invariantYieldedValue) {
129 // Padding with an invariant value.
130 auto fillOp = rewriter.create<linalg::FillOp>(location: loc, args: ValueRange(yieldedValue),
131 args: ValueRange(dest));
132 return fillOp;
133 }
134
135 // Create linalg.generic.
136 SmallVector<utils::IteratorType> iteratorTypes(resultType.getRank(),
137 utils::IteratorType::parallel);
138 SmallVector<AffineMap> indexingMaps(
139 1, rewriter.getMultiDimIdentityMap(rank: resultType.getRank()));
140 auto genericOp = rewriter.create<linalg::GenericOp>(
141 location: loc, args&: resultType, /*inputs=*/args: ValueRange(),
142 /*outputs=*/args: ValueRange{dest}, /*indexingMaps=*/
143 args&: indexingMaps, args&: iteratorTypes);
144 Block *body = rewriter.createBlock(parent: &genericOp->getRegion(index: 0), insertPt: {},
145 argTypes: resultType.getElementType(), locs: loc);
146 rewriter.setInsertionPointToStart(body);
147 SmallVector<Value> bbArgReplacements;
148 for (int64_t i = 0; i < resultType.getRank(); ++i)
149 bbArgReplacements.push_back(Elt: rewriter.create<linalg::IndexOp>(location: loc, args&: i));
150 rewriter.mergeBlocks(source: padOp.getBody(), dest: body, argValues: bbArgReplacements);
151
152 // Update terminator.
153 auto yieldOp = cast<tensor::YieldOp>(Val: body->getTerminator());
154 rewriter.replaceOpWithNewOp<linalg::YieldOp>(op: yieldOp, args: yieldOp.getValue());
155 return genericOp;
156}
157
158static SmallVector<Value> reifyOrComputeDynamicSizes(OpBuilder &b,
159 Value value) {
160 auto tensorType = cast<RankedTensorType>(Val: value.getType());
161 if (tensorType.hasStaticShape())
162 return {};
163
164 // Try to reify dynamic sizes.
165 ReifiedRankedShapedTypeDims reifiedShape;
166 if (isa<OpResult>(Val: value) &&
167 succeeded(Result: reifyResultShapes(b, op: value.getDefiningOp(), reifiedReturnShapes&: reifiedShape))) {
168 SmallVector<Value> dynSizes;
169 for (int64_t i = 0; i < tensorType.getRank(); ++i) {
170 if (tensorType.isDynamicDim(idx: i))
171 dynSizes.push_back(Elt: cast<Value>(
172 Val&: reifiedShape[cast<OpResult>(Val&: value).getResultNumber()][i]));
173 }
174 return dynSizes;
175 }
176
177 // Create tensor.dim ops.
178 SmallVector<Value> dynSizes;
179 for (int64_t i = 0; i < tensorType.getRank(); ++i) {
180 if (tensorType.isDynamicDim(idx: i))
181 dynSizes.push_back(
182 Elt: b.create<DimOp>(location: value.getLoc(), args&: value,
183 args: b.create<arith::ConstantIndexOp>(location: value.getLoc(), args&: i)));
184 }
185 return dynSizes;
186}
187
188static Value
189createAllocationForTensor(RewriterBase &rewriter, Location loc, Value value,
190 const linalg::BufferizeToAllocationOptions &options,
191 Attribute memorySpace = {}) {
192 OpBuilder::InsertionGuard g(rewriter);
193 auto tensorType = cast<RankedTensorType>(Val: value.getType());
194
195 // Create buffer allocation.
196 auto memrefType =
197 cast<MemRefType>(Val: bufferization::getMemRefTypeWithStaticIdentityLayout(
198 tensorType, memorySpace));
199 SmallVector<Value> dynamicSizes = reifyOrComputeDynamicSizes(b&: rewriter, value);
200
201 Value alloc;
202 if (options.allocOp ==
203 linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc) {
204 alloc = rewriter.create<memref::AllocOp>(location: loc, args&: memrefType, args&: dynamicSizes);
205 if (options.emitDealloc) {
206 // Place deallocation at the end of the block.
207 rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
208 rewriter.create<memref::DeallocOp>(location: loc, args&: alloc);
209 }
210 } else if (options.allocOp ==
211 linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca) {
212 alloc = rewriter.create<memref::AllocaOp>(location: loc, args&: memrefType, args&: dynamicSizes);
213 // No dealloc is needed.
214 }
215
216 return alloc;
217}
218
219Value linalg::bufferizeToAllocation(
220 RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options,
221 PadOp padOp, Attribute memorySpace, Operation *insertionPoint) {
222 // tensor.pad does not have a destination operand.
223 assert(!options.bufferizeDestinationOnly && "invalid options");
224
225 OpBuilder::InsertionGuard g(rewriter);
226 rewriter.setInsertionPoint(insertionPoint ? insertionPoint : padOp);
227 Location loc = padOp.getLoc();
228
229 // Create buffer allocation.
230 Value alloc = createAllocationForTensor(rewriter, loc, value: padOp.getResult(),
231 options, memorySpace);
232 rewriter.setInsertionPoint(padOp);
233
234 if (!padOp.hasZeroLowPad() || !padOp.hasZeroHighPad()) {
235 // Create linalg.fill or linalg.generic. Not needed if there is no padding.
236 Operation *fillOp =
237 movePaddingToFillOrGenericOp(rewriter, loc, padOp, dest: alloc);
238 rewriter.setInsertionPointAfter(fillOp);
239 }
240
241 // Create memcpy.
242 SmallVector<OpFoldResult> sizes =
243 getMixedSizes(builder&: rewriter, loc, value: padOp.getSource());
244 SmallVector<OpFoldResult> strides(padOp.getResultType().getRank(),
245 rewriter.getIndexAttr(value: 1));
246 Value subview = rewriter.create<memref::SubViewOp>(
247 location: loc, args&: alloc, /*offsets=*/args: padOp.getMixedLowPad(), args&: sizes, args&: strides);
248 createMemcpy(b&: rewriter, loc, tensorSource: padOp.getSource(), memrefDest: subview, options);
249
250 // Create bufferization.to_tensor with "restrict" and "writable". The returned
251 // tensor is a new buffer allocation, so it does not alias with any buffer.
252 Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
253 location: loc, args: padOp.getResult().getType(), args&: alloc, /*restrict=*/args: true,
254 /*writable=*/args: true);
255 rewriter.replaceOp(op: padOp, newValues: toTensorOp);
256 return alloc;
257}
258
259Value linalg::bufferizeToAllocation(
260 RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options,
261 vector::MaskOp maskOp, Attribute memorySpace, Operation *insertionPoint) {
262 assert(llvm::range_size(maskOp.getMaskBlock()->without_terminator()) == 1 &&
263 "expected single masked op");
264 OpBuilder::InsertionGuard g(rewriter);
265
266 // Should the bufferization options and state be function arguments?
267 bufferization::BufferizationOptions bufferizationOptions;
268 bufferization::BufferizationState bufferizationState;
269
270 Operation *yieldOp = maskOp.getMaskRegion().front().getTerminator();
271 assert(isa<vector::YieldOp>(yieldOp) && "expected yield op terminator");
272
273 // Bufferize maskable op. By default, place the buffer allocation right before
274 // the mask op.
275 Value alloc = bufferizeToAllocation(
276 rewriter, options, op: maskOp.getMaskableOp(), memorySpace,
277 /*insertionPoint=*/insertionPoint ? insertionPoint : maskOp);
278
279 if (options.bufferizeDestinationOnly)
280 return alloc;
281
282 // Bufferize terminator.
283 rewriter.setInsertionPoint(yieldOp);
284 if (failed(Result: cast<bufferization::BufferizableOpInterface>(Val: yieldOp).bufferize(
285 rewriter, options: bufferizationOptions, state&: bufferizationState)))
286 return nullptr;
287
288 // Erase dead to_tensor ops inside of the mask op. This is necessary because
289 // there only be one op (apart from the terminator) inside the mask op.
290 // TODO: Remove dead to_tensor ops more aggressively during bufferization.
291 SmallVector<Operation *> toTensorOps;
292 maskOp.walk(callback: [&](bufferization::ToTensorOp toTensorOp) {
293 if (toTensorOp->getUses().empty())
294 toTensorOps.push_back(Elt: toTensorOp.getOperation());
295 });
296 for (Operation *op : toTensorOps)
297 rewriter.eraseOp(op);
298
299 // Bufferize mask op.
300 SmallVector<OpOperand *> resultUses;
301 for (Value result : maskOp.getResults())
302 if (isa<TensorType>(Val: result.getType()))
303 for (OpOperand &use : result.getUses())
304 resultUses.push_back(Elt: &use);
305 rewriter.setInsertionPoint(maskOp);
306 if (failed(
307 Result: cast<bufferization::BufferizableOpInterface>(Val: maskOp.getOperation())
308 .bufferize(rewriter, options: bufferizationOptions, state&: bufferizationState)))
309 return nullptr;
310
311 // Set "restrict" attribute, indicating that no other tensor aliases with
312 // this tensor. That is because we just allocated a new buffer for the tensor.
313 for (OpOperand *resultUse : resultUses) {
314 auto toTensorOp =
315 resultUse->get().getDefiningOp<bufferization::ToTensorOp>();
316 assert(toTensorOp && "expected to_tensor op");
317 rewriter.modifyOpInPlace(root: toTensorOp, callable: [&]() {
318 toTensorOp.setRestrict(true);
319 toTensorOp.setWritable(true);
320 });
321 }
322
323 return alloc;
324}
325
326Value linalg::bufferizeToAllocation(
327 RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options,
328 bufferization::AllocTensorOp allocTensorOp, Attribute memorySpace,
329 Operation *insertionPoint) {
330 Location loc = allocTensorOp.getLoc();
331 OpBuilder::InsertionGuard g(rewriter);
332 rewriter.setInsertionPoint(insertionPoint ? insertionPoint : allocTensorOp);
333 bufferization::BufferizationOptions bufferizationOptions;
334
335 // Create buffer allocation.
336 Value alloc = createAllocationForTensor(
337 rewriter, loc, value: allocTensorOp.getResult(), options, memorySpace);
338
339 // Create bufferization.to_tensor with "restrict" and "writable". The returned
340 // tensor is a new buffer allocation, so it does not alias with any buffer.
341 Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
342 location: loc, args: allocTensorOp.getResult().getType(), args&: alloc, /*restrict=*/args: true,
343 /*writable=*/args: true);
344 rewriter.replaceOp(op: allocTensorOp, newValues: toTensorOp);
345 return alloc;
346}
347
348/// Lower tensor.from_elements to a sequence of chained tensor.insert.
349FailureOr<Operation *> mlir::linalg::rewriteInDestinationPassingStyle(
350 RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) {
351 Location loc = fromElementsOp.getLoc();
352 RankedTensorType tensorType =
353 cast<RankedTensorType>(Val: fromElementsOp.getType());
354 auto shape = tensorType.getShape();
355
356 // Create tensor.empty.
357 auto emptyOp = rewriter.create<EmptyOp>(location: loc, args&: tensorType, args: ValueRange());
358
359 // Case: tensor<elem_type>.
360 if (shape.empty()) {
361 Operation *res = rewriter.replaceOpWithNewOp<tensor::InsertOp>(
362 op: fromElementsOp, args: fromElementsOp.getElements().front(),
363 args: emptyOp.getResult(), args: ValueRange());
364 return res;
365 }
366
367 // Create constants for the range of possible indices [0, max{shape_i}).
368 auto maxDim = *llvm::max_element(Range&: shape);
369 SmallVector<Value, 2> constants;
370 constants.reserve(N: maxDim);
371 for (int i = 0; i < maxDim; ++i)
372 constants.push_back(Elt: rewriter.create<arith::ConstantIndexOp>(location: loc, args&: i));
373
374 // Traverse all elements and create tensor.insert ops.
375 auto elementIt = fromElementsOp.getElements().begin();
376 SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
377 Value result = createInserts(rewriter, loc, /*dim=*/0, destination: emptyOp.getResult(),
378 shape, constants, elementIt, indices);
379
380 // Replace tensor.from_elements.
381 rewriter.replaceOp(op: fromElementsOp, newValues: result);
382 return result.getDefiningOp();
383}
384
385/// Lower tensor.generate to linalg.generic.
386FailureOr<Operation *>
387mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
388 tensor::GenerateOp generateOp) {
389 // Only ops with exactly one block are supported.
390 if (!generateOp.getBody().hasOneBlock())
391 return failure();
392
393 Location loc = generateOp.getLoc();
394 RankedTensorType tensorType = cast<RankedTensorType>(Val: generateOp.getType());
395
396 // Create tensor.empty.
397 auto emptyOp =
398 rewriter.create<EmptyOp>(location: loc, args&: tensorType, args: generateOp.getDynamicExtents());
399
400 // Create linalg.generic.
401 SmallVector<utils::IteratorType> iteratorTypes(tensorType.getRank(),
402 utils::IteratorType::parallel);
403 SmallVector<AffineMap> indexingMaps(
404 1, rewriter.getMultiDimIdentityMap(rank: tensorType.getRank()));
405 auto genericOp = rewriter.create<linalg::GenericOp>(
406 location: loc, args&: tensorType, /*inputs=*/args: ValueRange(),
407 /*outputs=*/args: ValueRange{emptyOp.getResult()}, /*indexingMaps=*/
408 args&: indexingMaps, args&: iteratorTypes);
409 Block *body = rewriter.createBlock(parent: &genericOp->getRegion(index: 0), insertPt: {},
410 argTypes: tensorType.getElementType(), locs: loc);
411 rewriter.setInsertionPointToStart(body);
412 SmallVector<Value> bbArgReplacements;
413 for (int64_t i = 0; i < tensorType.getRank(); ++i)
414 bbArgReplacements.push_back(Elt: rewriter.create<linalg::IndexOp>(location: loc, args&: i));
415 rewriter.mergeBlocks(source: &generateOp.getBody().front(), dest: body, argValues: bbArgReplacements);
416
417 // Update terminator.
418 auto yieldOp = cast<tensor::YieldOp>(Val: body->getTerminator());
419 rewriter.replaceOpWithNewOp<linalg::YieldOp>(op: yieldOp, args: yieldOp.getValue());
420
421 // Replace tensor.generate.
422 rewriter.replaceOp(op: generateOp, newValues: genericOp->getResult(idx: 0));
423 return genericOp.getOperation();
424}
425
426/// Lower tensor.pad to linalg.generic + tensor.insert_slice.
427FailureOr<Operation *>
428mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
429 tensor::PadOp padOp) {
430 // Only ops with exactly one block are supported.
431 if (!padOp.getBodyRegion().hasOneBlock())
432 return failure();
433
434 // Create tensor.empty.
435 Location loc = padOp.getLoc();
436 RankedTensorType resultType = padOp.getResultType();
437 ReifiedRankedShapedTypeDims reifiedShape;
438 if (failed(Result: reifyResultShapes(b&: rewriter, op: padOp, reifiedReturnShapes&: reifiedShape)))
439 return rewriter.notifyMatchFailure(
440 arg&: padOp, msg: "failed to reify tensor.pad op result shape");
441 SmallVector<Value> dynamicSizes;
442 for (int64_t i = 0; i < resultType.getRank(); ++i)
443 if (resultType.isDynamicDim(idx: i))
444 dynamicSizes.push_back(Elt: cast<Value>(Val&: reifiedShape[0][i]));
445
446 // If the `padOp` has a nofold attribute and all paddings are known to be 0,
447 // explicitly insert a `linalg.copy`.
448 if (padOp.getNofoldAttr() &&
449 llvm::all_of(Range: padOp.getMixedLowPad(), P: isZeroInteger) &&
450 llvm::all_of(Range: padOp.getMixedHighPad(), P: isZeroInteger)) {
451 using bufferization::AllocTensorOp;
452 Value allocated =
453 rewriter.create<AllocTensorOp>(location: loc, args&: resultType, args&: dynamicSizes);
454 auto copyOp = rewriter.replaceOpWithNewOp<linalg::CopyOp>(
455 op: padOp, args: padOp.getSource(), args&: allocated);
456 return copyOp.getOperation();
457 }
458
459 Value empty = rewriter.create<EmptyOp>(location: loc, args&: resultType, args&: dynamicSizes);
460 // Create linalg.fill or linalg.generic.
461 Operation *fillOp = movePaddingToFillOrGenericOp(rewriter, loc, padOp, dest: empty);
462 rewriter.setInsertionPointAfter(fillOp);
463
464 // Create tensor::InsertSliceOp.
465 SmallVector<OpFoldResult> sliceSizes =
466 getMixedSizes(builder&: rewriter, loc, value: padOp.getSource());
467 SmallVector<OpFoldResult> sliceStrides(resultType.getRank(),
468 rewriter.getIndexAttr(value: 1));
469 auto insertSliceOp = rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
470 op: padOp, args: padOp.getSource(), args: fillOp->getResult(idx: 0),
471 /*offsets=*/args: padOp.getMixedLowPad(), args&: sliceSizes, args&: sliceStrides);
472 return insertSliceOp.getOperation();
473}
474
475Value linalg::bufferizeToAllocation(
476 RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options,
477 Operation *op, Attribute memorySpace, Operation *insertionPoint) {
478 using namespace bufferization;
479
480 // Call specialized overload for certain ops.
481 if (auto padOp = dyn_cast<tensor::PadOp>(Val: op))
482 return bufferizeToAllocation(rewriter, options, padOp, memorySpace);
483 if (auto maskOp = dyn_cast<vector::MaskOp>(Val: op))
484 return bufferizeToAllocation(rewriter, options, maskOp, memorySpace);
485 if (auto allocTensorOp = dyn_cast<bufferization::AllocTensorOp>(Val: op))
486 return bufferizeToAllocation(rewriter, options, allocTensorOp, memorySpace);
487
488 // Only bufferizable ops are supported.
489 auto bufferizableOp = dyn_cast<BufferizableOpInterface>(Val: op);
490 if (!bufferizableOp)
491 return nullptr;
492
493 // Should the bufferization options and states be function arguments?
494 BufferizationOptions bufferizationOptions;
495 AnalysisState analysisState(bufferizationOptions);
496 BufferizationState bufferizationState;
497
498#ifndef NDEBUG
499 if (!options.bufferizeDestinationOnly) {
500 // Ops with nested tensor ops are not supported yet. At the moment, this
501 // function just bufferizes the given op itself, but not its body.
502 op->walk([&](Operation *nestedOp) {
503 if (op == nestedOp)
504 return;
505 if (llvm::any_of(nestedOp->getOperands(),
506 [](Value v) { return isa<TensorType>(v.getType()); }))
507 llvm_unreachable("ops with nested tensor ops are not supported yet");
508 if (llvm::any_of(nestedOp->getResults(),
509 [](Value v) { return isa<TensorType>(v.getType()); }))
510 llvm_unreachable("ops with nested tensor ops are not supported yet");
511 });
512 }
513#endif // NDEBUG
514
515 // Gather tensor results.
516 SmallVector<OpResult> tensorResults;
517 for (OpResult result : op->getResults()) {
518 if (!isa<TensorType>(Val: result.getType()))
519 continue;
520 // Unranked tensors are not supported
521 if (!isa<RankedTensorType>(Val: result.getType()))
522 return nullptr;
523 // Ops that bufferize to an allocation are not supported.
524 if (bufferizableOp.bufferizesToAllocation(value: result))
525 return nullptr;
526 tensorResults.push_back(Elt: result);
527 }
528
529 // Gather all operands that should bufferize to a new allocation. I.e.,
530 // bufferize out-of-place.
531 SmallVector<OpOperand *> outOfPlaceOperands, resultUses;
532 auto addOutOfPlaceOperand = [&](OpOperand *operand) {
533 if (!llvm::is_contained(Range&: outOfPlaceOperands, Element: operand))
534 outOfPlaceOperands.push_back(Elt: operand);
535 };
536 for (OpResult result : tensorResults) {
537 AliasingOpOperandList aliasingOperands =
538 analysisState.getAliasingOpOperands(value: result);
539 for (const AliasingOpOperand &operand : aliasingOperands) {
540 addOutOfPlaceOperand(operand.opOperand);
541 for (OpOperand &resultUse : result.getUses())
542 resultUses.push_back(Elt: &resultUse);
543 }
544 }
545 for (OpOperand &operand : op->getOpOperands()) {
546 if (!analysisState.bufferizesToMemoryWrite(opOperand&: operand))
547 continue;
548 if (!isa<RankedTensorType>(Val: operand.get().getType()))
549 continue;
550 addOutOfPlaceOperand(&operand);
551 }
552 // TODO: Support multiple buffers.
553 if (outOfPlaceOperands.size() != 1)
554 return nullptr;
555
556 // Allocate buffers.
557 OpBuilder::InsertionGuard g(rewriter);
558 rewriter.setInsertionPoint(insertionPoint ? insertionPoint : op);
559 SmallVector<Value> allocs;
560 for (OpOperand *operand : outOfPlaceOperands) {
561 Value alloc = createAllocationForTensor(
562 rewriter, loc: op->getLoc(), value: operand->get(), options, memorySpace);
563 allocs.push_back(Elt: alloc);
564 if (!analysisState.findDefinitions(opOperand: operand).empty()) {
565 // Initialize buffer with a copy of the operand data. Not needed if the
566 // tensor is uninitialized.
567 createMemcpy(b&: rewriter, loc: op->getLoc(), tensorSource: operand->get(), memrefDest: alloc, options);
568 }
569 rewriter.modifyOpInPlace(root: op, callable: [&]() {
570 auto toTensorOp = rewriter.create<ToTensorOp>(
571 location: op->getLoc(), args: operand->get().getType(), args&: alloc);
572 operand->set(toTensorOp);
573 if (options.bufferizeDestinationOnly) {
574 rewriter.modifyOpInPlace(root: toTensorOp, callable: [&]() {
575 toTensorOp.setRestrict(true);
576 toTensorOp.setWritable(true);
577 });
578 }
579 });
580 }
581
582 if (options.bufferizeDestinationOnly)
583 return allocs.front();
584
585 // Bufferize the op.
586 rewriter.setInsertionPoint(op);
587 if (failed(Result: bufferizableOp.bufferize(rewriter, options: bufferizationOptions,
588 state&: bufferizationState)))
589 return nullptr;
590
591 // Set "restrict" attribute, indicating that no other tensor aliases with
592 // this tensor. That is because we just allocated a new buffer for the tensor.
593 for (OpOperand *resultUse : resultUses) {
594 auto toTensorOp = resultUse->get().getDefiningOp<ToTensorOp>();
595 assert(toTensorOp && "expected to_tensor op");
596 rewriter.modifyOpInPlace(root: toTensorOp, callable: [&]() {
597 toTensorOp.setRestrict(true);
598 toTensorOp.setWritable(true);
599 });
600 }
601 return allocs.front();
602}
603
604namespace {
605
606template <typename OpTy>
607LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
608 PatternRewriter &rewriter) {
609 return linalg::rewriteInDestinationPassingStyle(rewriter, op);
610}
611
612} // namespace
613
614void linalg::populateConvertToDestinationStylePatterns(
615 RewritePatternSet &patterns) {
616 patterns.add(implFn: rewriteOpInDestinationPassingStyle<tensor::FromElementsOp>);
617 patterns.add(implFn: rewriteOpInDestinationPassingStyle<tensor::GenerateOp>);
618 patterns.add(implFn: rewriteOpInDestinationPassingStyle<tensor::PadOp>);
619}
620

// Source: mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp