//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/RelayoutOpInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include <algorithm>
#include <optional>
#include <vector>

using namespace mlir;
using namespace mlir::tensor;

using llvm::divideCeilSigned;
using llvm::divideFloorSigned;
using llvm::mod;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *TensorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
    return op;
  if (complex::ConstantOp::isBuildableWith(value, type))
    return builder.create<complex::ConstantOp>(loc, type,
                                               llvm::cast<ArrayAttr>(value));
  return nullptr;
}

OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
                                  int64_t dim) {
  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  if (tensorType.isDynamicDim(dim))
    return builder.createOrFold<tensor::DimOp>(loc, value, dim);

  return builder.getIndexAttr(tensorType.getDimSize(dim));
}

SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
                                                Location loc, Value value) {
  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  SmallVector<OpFoldResult> result;
  for (int64_t i = 0; i < tensorType.getRank(); ++i)
    result.push_back(getMixedSize(builder, loc, value, i));
  return result;
}

FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
                                                OpResult opResult) {
  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
  assert(tensorType && "expected tensor type");

  // If the op has a destination, it implements DestinationStyleOpInterface and
  // we can query the destination operand from that interface.
  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
  if (destOp)
    return destOp.getTiedOpOperand(opResult)->get();

  // Otherwise, create a new destination tensor with the same shape.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(opResult.getDefiningOp());

  // Compute sizes.
  SmallVector<OpFoldResult> mixedSizes;
  if (!tensorType.hasStaticShape()) {
    // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
    ReifiedRankedShapedTypeDims reifiedShapes;
    if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
      return failure();
    mixedSizes = reifiedShapes[opResult.getResultNumber()];
  } else {
    // Static shape: Take static sizes directly.
    for (int64_t sz : tensorType.getShape())
      mixedSizes.push_back(b.getIndexAttr(sz));
  }

  // Create empty tensor.
  Value emptyTensor =
      b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
  return emptyTensor;
}

LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
                                              Operation *op,
                                              SmallVector<Value> &result) {
  for (OpResult opResult : op->getResults()) {
    if (llvm::isa<TensorType>(opResult.getType())) {
      FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
      if (failed(destination))
        return failure();
      result.push_back(*destination);
    }
  }
  return success();
}

bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
  if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
    if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
      return rtp1.getShape() == rtp2.getShape() &&
             rtp1.getElementType() == rtp2.getElementType();
    return false;
  }
  return tp1 == tp2; // default implementation
}

/// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
/// rank-extending tensor.insert_slice op.
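///
/// Illustrative example: with reducedShape = [6] and mixedSizes = [1, 6, 1],
/// the unit sizes at positions 0 and 2 have no counterpart in the reduced
/// shape, so bits 0 and 2 of the returned vector are set.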
static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
                                           ArrayRef<OpFoldResult> mixedSizes) {
  llvm::SmallBitVector droppedDims(mixedSizes.size());
  int64_t shapePos = reducedShape.size() - 1;

  for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
    size_t idx = mixedSizes.size() - size.index() - 1;
    // Rank-reduced dims must have a static unit dimension.
    bool isStaticUnitSize =
        isa<Attribute>(size.value()) &&
        llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;

    if (shapePos < 0) {
      // There are no more dims in the reduced shape. All remaining sizes must
      // be rank-reduced dims.
      assert(isStaticUnitSize && "expected unit dim");
      droppedDims.set(idx);
      continue;
    }

    // Dim is preserved if the size is not a static 1.
    if (!isStaticUnitSize) {
      --shapePos;
      continue;
    }

    // Dim is preserved if the reduced shape dim is also 1.
    if (reducedShape[shapePos] == 1) {
      --shapePos;
      continue;
    }

    // Otherwise: Dim is dropped.
    droppedDims.set(idx);
  }

  assert(shapePos < 0 && "dimension mismatch");
  return droppedDims;
}

/// Given a ranked tensor type and a range of values that defines its dynamic
/// dimension sizes, turn all dynamic sizes that have a constant value into
/// static dimension sizes.
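///
/// Illustrative example: for type = tensor<?x?xf32> and dynamicSizes =
/// [%c4, %d] (where %c4 is a constant 4), the result is tensor<4x?xf32> and
/// `foldedDynamicSizes` contains only %d.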
static RankedTensorType
foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
                            SmallVector<Value> &foldedDynamicSizes) {
  SmallVector<int64_t> staticShape(type.getShape());
  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
         "incorrect number of dynamic sizes");

  // Compute new static and dynamic sizes.
  unsigned ctr = 0;
  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
    if (type.isDynamicDim(i)) {
      Value dynamicSize = dynamicSizes[ctr++];
      std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
      if (cst.has_value()) {
        // Dynamic size must be non-negative.
        if (cst.value() < 0) {
          foldedDynamicSizes.push_back(dynamicSize);
          continue;
        }
        staticShape[i] = *cst;
      } else {
        foldedDynamicSizes.push_back(dynamicSize);
      }
    }
  }

  return RankedTensorType::get(staticShape, type.getElementType(),
                               type.getEncoding());
}

//===----------------------------------------------------------------------===//
// BitcastOp
//===----------------------------------------------------------------------===//

bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = dyn_cast<TensorType>(a);
  auto bT = dyn_cast<TensorType>(b);
  if (!aT || !bT)
    return false;

  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}

namespace {

/// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
/// operation.
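///
/// Illustrative example:
///
///   %1 = tensor.bitcast %0 : tensor<4xi32> to tensor<4xf32>
///   %2 = tensor.bitcast %1 : tensor<4xf32> to tensor<4xui32>
///
/// folds to
///
///   %2 = tensor.bitcast %0 : tensor<4xi32> to tensor<4xui32>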
struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
  using OpRewritePattern<BitcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
                                PatternRewriter &rewriter) const final {
    auto tensorBitcastOperand =
        tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
    if (!tensorBitcastOperand)
      return failure();

    auto resultType = cast<TensorType>(tensorBitcast.getType());
    rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
                                           tensorBitcastOperand.getOperand());
    return success();
  }
};

} // namespace

void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ChainedTensorBitcast>(context);
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "cast");
}

/// Returns true if `target` is a ranked tensor type that preserves static
/// information available in the `source` ranked tensor type.
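///
/// Illustrative example: the target tensor<?x16xf32> does not preserve the
/// static size of dimension 0 of the source tensor<8x16xf32>, so the function
/// returns false for that pair; swapping source and target returns true.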
bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
  auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
  auto targetType = llvm::dyn_cast<RankedTensorType>(target);

  // Requires RankedTensorType.
  if (!sourceType || !targetType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != targetType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != targetType.getRank())
    return false;

  // Requires same encoding.
  if (sourceType.getEncoding() != targetType.getEncoding())
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
    if (!ShapedType::isDynamic(std::get<0>(t)) &&
        ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }

  return true;
}
/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
/// consume the results of tensor.cast operations. Such foldable tensor.cast
/// operations are typically inserted as `slice` ops and are canonicalized to
/// preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked tensors with same element type and rank.
/// 2. the source type has more static information than the result type.
///
/// Example:
/// ```mlir
/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
/// %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
  if (!castOp)
    return false;

  // Can fold if the source of cast has at least as much static information as
  // its results.
  return preservesStaticInformation(castOp.getType(),
                                    castOp.getSource().getType());
}

/// Determines whether the tensor::CastOp casts to a more static version of the
/// source tensor. This is useful to fold into a producing op and implement
/// canonicalization patterns with the `tensor.cast` op as the root, but
/// producer being from different dialects. Returns true when all conditions are
/// met:
/// 1. source and result are ranked tensors with same element type and rank.
/// 2. the result type has more static information than the source.
///
/// Example:
/// ```mlir
/// %1 = producer ... : tensor<?x?xf32>
/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
/// ```
///
/// can be canonicalized to:
///
/// ```mlir
/// %2 = producer ... : tensor<8x16xf32>
/// ```
/// Not all ops might be canonicalizable this way, but for those that can be,
/// this method provides a check that it is worth doing the canonicalization.
bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
  if (!castOp)
    return false;
  return preservesStaticInformation(castOp.getSource().getType(),
                                    castOp.getType());
}
362
363bool mlir::tensor::hasFoldableTensorCastOperand(Operation *op) {
364 return llvm::any_of(Range: op->getOpOperands(), P: [&](OpOperand &opOperand) {
365 if (llvm::isa<BlockArgument>(Val: opOperand.get()))
366 return false;
367 auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
368 return castOp && canFoldIntoConsumerOp(castOp);
369 });
370}
371
372SmallVector<Value> mlir::tensor::getUpdatedOperandsAfterCastOpFolding(
373 DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
374 SmallVector<Value> newOperands;
375 newOperands.reserve(N: op->getNumOperands());
376
377 assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!");
378
379 // Assumes that the result has dpsInits followed by nonDpsInits.
380 int64_t dpsInitIdx = 0;
381 for (OpOperand &opOperand : op->getOpOperands()) {
382 auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
383 bool fold = canFoldIntoConsumerOp(tensorCastOp);
384 newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
385 if (op.isDpsInit(&opOperand) &&
386 !llvm::isa<MemRefType>(newOperands.back().getType()))
387 newResTy[dpsInitIdx++] = newOperands.back().getType();
388 }
389 return newOperands;
390}
391
392/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
393/// that can be folded.
394LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
395 bool folded = false;
396 for (OpOperand &operand : op->getOpOperands()) {
397 auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
398 if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
399 operand.set(castOp.getOperand());
400 folded = true;
401 }
402 }
403 return success(IsSuccess: folded);
404}
405
406bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
407 if (inputs.size() != 1 || outputs.size() != 1)
408 return false;
409 Type a = inputs.front(), b = outputs.front();
410 auto aT = llvm::dyn_cast<TensorType>(a);
411 auto bT = llvm::dyn_cast<TensorType>(b);
412 if (!aT || !bT)
413 return false;
414
415 if (aT.getElementType() != bT.getElementType())
416 return false;
417
418 return succeeded(verifyCompatibleShape(aT, bT));
419}
420
421/// Compute a TensorType that has the joined shape knowledge of the two
422/// given TensorTypes. The element types need to match.
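///
/// Illustrative example: joining tensor<?x8xf32> with tensor<4x?xf32> yields
/// tensor<4x8xf32>; joining tensor<4x8xf32> with tensor<5x8xf32> yields the
/// null type because the static sizes conflict.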
static TensorType joinShapes(TensorType one, TensorType two) {
  assert(one.getElementType() == two.getElementType());

  if (!one.hasRank())
    return two;
  if (!two.hasRank())
    return one;

  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  SmallVector<int64_t, 4> join;
  join.reserve(rank);
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
  return RankedTensorType::get(join, one.getElementType());
}

namespace {

/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
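///
/// Illustrative example:
///
///   %1 = tensor.cast %0 : tensor<4x4xf32> to tensor<?x?xf32>
///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<?x4xf32>
///
/// folds to a single cast from %0 to tensor<?x4xf32>, because dropping the
/// intermediate type does not discard any runtime shape check.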
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
    auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
    auto resultType = llvm::cast<TensorType>(tensorCast.getType());

    // We can remove the intermediate cast if joining all three produces the
    // same result as just joining the source and result shapes.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);

    // The join might not exist if the cast sequence would fail at runtime.
    if (!firstJoin)
      return failure();

    // The newJoin always exists if the above join exists; it might just
    // contain less information. If so, we cannot drop the intermediate cast,
    // as doing so would remove runtime checks.
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};
/// Fold tensor.cast into tensor.extract_slice producer.
/// Example:
/// ```
/// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
///   tensor<128x512xf32> to tensor<?x512xf32>
/// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
/// ```
/// ->
/// ```
/// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
///   tensor<128x512xf32> to tensor<16x512xf32>
/// ```
struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto extractOperand =
        tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();

    // Cannot fold cast to unranked tensor.
    auto rankedResultType =
        llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
    if (!rankedResultType)
      return failure();

    if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
        rankedResultType.getShape() ==
            llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
                .getShape())
      return failure();

    SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
    auto dimMask = computeRankReductionMask(
        extractOperand.getStaticSizes(), extractOperand.getType().getShape());
    size_t dimIndex = 0;
    for (size_t i = 0, e = sizes.size(); i < e; i++) {
      if (dimMask && dimMask->count(i))
        continue;
      int64_t dim = rankedResultType.getShape()[dimIndex++];
      if (ShapedType::isDynamic(dim))
        continue;
      sizes[i] = rewriter.getIndexAttr(dim);
    }

    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
        tensorCast, rankedResultType, extractOperand.getSource(),
        extractOperand.getMixedOffsets(), sizes,
        extractOperand.getMixedStrides());
    return success();
  }
};

} // namespace

void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
}

//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//

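/// Infer the result type of a concatenation along `dim`. Illustrative
/// example: concatenating tensor<3x?xf32> and tensor<5x4xf32> along
/// dimension 0 yields tensor<8x4xf32>; if any input has a dynamic size along
/// `dim`, the concatenated size is dynamic as well.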
RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
  auto tensorTypes =
      llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
        return llvm::cast<RankedTensorType>(type);
      }));
  int64_t concatRank = tensorTypes[0].getRank();

  // The concatenation dim must be in the range [0, rank).
  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");

  SmallVector<int64_t> sizes(concatRank);
  for (int64_t i = 0, e = concatRank; i < e; ++i) {
    if (i == dim)
      continue;
    SaturatedInteger size;
    for (auto tensorType : tensorTypes)
      size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
    sizes[i] = size.asInteger();
  }
  auto concatSize = SaturatedInteger::wrap(0);
  for (auto tensorType : tensorTypes)
    concatSize =
        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
  sizes[dim] = concatSize.asInteger();
  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
}

void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
                     ValueRange inputs) {
  FailureOr<RankedTensorType> resultType =
      inferResultType(dim, inputs.getTypes());
  assert(succeeded(resultType) && "failed to infer concatenation result type");
  build(builder, result, *resultType, dim, inputs);
}

LogicalResult ConcatOp::verify() {
  if (getInputs().size() < 1)
    return emitOpError("requires at least one input");

  SmallVector<RankedTensorType> inputTypes;
  for (auto input : getInputs())
    inputTypes.push_back(cast<RankedTensorType>(input.getType()));

  RankedTensorType resultType = getResultType();
  int64_t resultRank = getRank();
  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
        return type.getRank() != resultRank;
      }))
    return emitOpError("rank of concatenated inputs must match result rank");

  Type resultElementType = resultType.getElementType();
  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
        return type.getElementType() != resultElementType;
      }))
    return emitOpError("inputs and result element type must match");

  int64_t dim = getDim();
  if (dim >= resultRank)
    return emitOpError("concatenation dim must be less than the tensor rank");

  SmallVector<int64_t> sizes(resultRank);
  for (int64_t i = 0, e = resultRank; i < e; ++i) {
    if (i == dim)
      continue;
    SaturatedInteger size;
    for (auto tensorType : inputTypes) {
      FailureOr<SaturatedInteger> maybeSize =
          size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
      if (failed(maybeSize))
        return emitOpError("static concatenation size mismatch along ")
               << "non-concatenated dimension " << i;
      size = *maybeSize;
    }
    sizes[i] = size.asInteger();
  }
  auto concatSize = SaturatedInteger::wrap(0);
  for (auto tensorType : inputTypes)
    concatSize =
        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
  sizes[dim] = concatSize.asInteger();
  auto inferredResultType =
      RankedTensorType::get(sizes, inputTypes[0].getElementType());

  for (auto [inferredSize, actualSize] :
       llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
    bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
                      ShapedType::isDynamic(actualSize);
    if (!hasDynamic && inferredSize != actualSize)
      return emitOpError("result type ")
             << resultType << " does not match inferred shape "
             << inferredResultType << " static sizes";
  }

  return success();
}

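/// Decompose the concatenation into an empty destination tensor plus a chain
/// of slice insertions. Illustrative sketch for a static 1-D case:
///
///   %r = tensor.concat dim(0) %a, %b : (tensor<2xf32>, tensor<3xf32>)
///            -> tensor<5xf32>
///
/// becomes, roughly:
///
///   %empty = tensor.empty() : tensor<5xf32>
///   %s0 = tensor.insert_slice %a into %empty[0] [2] [1]
///   %s1 = tensor.insert_slice %b into %s0[2] [3] [1]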
FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
  size_t numInputs = getInputs().size();
  uint64_t concatDim = getDim();

  SmallVector<SmallVector<OpFoldResult>> inputShapes;
  inputShapes.reserve(numInputs);
  SmallVector<OpFoldResult> concatOffsets;
  concatOffsets.reserve(numInputs);
  SmallVector<OpFoldResult> outputShape;

  AffineExpr addExpr =
      builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
  OpFoldResult zero = builder.getIndexAttr(0);
  Location loc = getLoc();
  for (auto [index, input] : llvm::enumerate(getInputs())) {
    SmallVector<OpFoldResult> inputShape =
        tensor::getMixedSizes(builder, input.getLoc(), input);
    if (index == 0) {
      outputShape = inputShape;
      concatOffsets.push_back(zero);
    } else {
      concatOffsets.push_back(outputShape[concatDim]);
      outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
          builder, loc, addExpr,
          {outputShape[concatDim], inputShape[concatDim]});
    }
    inputShapes.emplace_back(std::move(inputShape));
  }

  Value replacement = builder.create<tensor::EmptyOp>(
      loc, outputShape, getType().getElementType());

  int64_t rank = getType().getRank();
  OpFoldResult one = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> strides(rank, one);
  SmallVector<OpFoldResult> offsets(rank, zero);
  for (auto [index, input] : llvm::enumerate(getInputs())) {
    offsets[concatDim] = concatOffsets[index];
    auto insertSlice = builder.create<tensor::InsertSliceOp>(
        loc, input, replacement, offsets, inputShapes[index], strides);
    replacement = insertSlice.getResult();
  }
  if (replacement.getType() != getType()) {
    replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
  }
  return SmallVector<Value>{replacement};
}

LogicalResult
ConcatOp::reifyResultShapes(OpBuilder &builder,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  ValueRange inputs = getInputs();
  int64_t dim = getDim();
  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());

  Value init = inputs[0];
  int64_t rank = getType().getRank();

  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));

  // Pre-populate the result sizes with as much static information as possible
  // from the given result type, as well as the inferred result type, otherwise
  // use the dim sizes from the first input.
  for (int64_t i = 0; i < rank; ++i) {
    if (i == dim)
      continue;
    if (!getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    } else if (!inferredResultType.isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
          builder, getLoc(),
          builder.getIndexAttr(inferredResultType.getDimSize(i)));
    } else {
      reifiedReturnShapes[0][i] =
          builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
    }
  }

  if (getType().isDynamicDim(dim)) {
    // Take the sum of the input sizes along the concatenated dim.
    AffineExpr sum = builder.getAffineDimExpr(0);
    SmallVector<OpFoldResult> sizes = {
        builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
    for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
      sum = sum + builder.getAffineDimExpr(idx + 1);
      sizes.push_back(
          builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
    }
    reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
        builder, getLoc(),
        affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
  } else {
    // If the result shape is static along the concatenated dim, use the static
    // shape.
    reifiedReturnShapes[0][dim] =
        builder.getIndexAttr(getType().getDimSize(dim));
  }
  return success();
}

void ConcatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "concat");
}

OpFoldResult ConcatOp::fold(FoldAdaptor) {
  ValueRange inputs = getInputs();
  if (inputs.size() == 1 && inputs[0].getType() == getResultType())
    return inputs[0];
  return {};
}

namespace {
/// Fold a concat op with a single input to a cast.
struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    if (concatOp.getInputs().size() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
                                        concatOp.getInputs()[0]);
    return success();
  }
};

/// Propagate static shapes into the operands of a `tensor.concat`.
///
/// `tensor.concat` requires every operand to match on all dimensions except the
/// concatenation dimension. If one operand is already static in those
/// dimensions, the other operands may safely be refined to that same static
/// shape.
///
/// Example:
///
/// ```mlir
/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
///   tensor<?x12xi32>
/// ```
/// ->
/// ```mlir
/// %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
/// %2 = tensor.concat dim(0) %0, %cast :
///   (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
/// ```
struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());

    // Find operands for which a more static shape can be inferred.
    LogicalResult matched = failure();
    // Inferred operand shapes are identical in every dimension except the
    // concatenation dimension.
    SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
    for (auto [operandIdx, operandType] :
         llvm::enumerate(concatOp->getOperandTypes())) {
      // Compute inferred type for operand.
      inferredOperandShape[dim] =
          cast<RankedTensorType>(operandType).getDimSize(dim);
      auto inferredOperandType = RankedTensorType::get(
          inferredOperandShape, inferredResultType.getElementType());

      // Check if inferred type is more static.
      if (!preservesStaticInformation(inferredOperandType, operandType)) {
        matched = success();

        // Use refined operand type and create cast from original operand.
        auto castOp =
            rewriter.create<CastOp>(concatOp->getLoc(), inferredOperandType,
                                    concatOp.getOperand(operandIdx));
        rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
          concatOp->setOperand(operandIdx, castOp->getResult(0));
        });
      }
    }

    return matched;
  }
};

/// Ensure `tensor.concat`'s result type is at least as static as can be
/// inferred from its operand types.
///
/// Example:
/// ```mlir
/// %2 = tensor.concat dim(0) %0, %1 : (tensor<?x12xi32>, tensor<?x12xi32>)
///   -> tensor<?x?xi32>
/// ```
/// ->
/// ```mlir
/// %concat = tensor.concat dim(0) %0, %1 :
///   (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
/// %2 = tensor.cast %concat : tensor<?x12xi32> to tensor<?x?xi32>
/// ```
struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());

    // The result type should be at least as static as inferred result type.
    if (preservesStaticInformation(inferredResultType,
                                   concatOp.getResultType())) {
      return failure();
    }

    auto newConcatOp = rewriter.create<ConcatOp>(
        concatOp->getLoc(), inferredResultType, dim, concatOp->getOperands());
    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
                                        newConcatOp);

    return success();
  }
};
} // namespace

void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results
      .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
          context);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "dim");
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
}

std::optional<int64_t> DimOp::getConstantIndex() {
  return getConstantIntValue(getIndex());
}

Speculation::Speculatability DimOp::getSpeculatability() {
  auto constantIndex = getConstantIndex();
  if (!constantIndex)
    return Speculation::NotSpeculatable;

  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
  if (!rankedSourceType)
    return Speculation::NotSpeculatable;

  if (rankedSourceType.getRank() <= constantIndex)
    return Speculation::NotSpeculatable;

  return Speculation::Speculatable;
}

void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
                                          SetIntLatticeFn setResultRange) {
  setResultRange(getResult(),
                 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
}

OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  // All forms of folding require a known index.
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!index)
    return {};

  // Folding for unranked types (UnrankedTensorType) is not supported.
  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
  if (!tensorType)
    return {};

  // Out of bound indices produce undefined behavior but are still valid IR.
  // Don't choke on them.
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= tensorType.getRank())
    return {};

  // Fold if the shape extent along the given index is known.
  if (!tensorType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
  }

  Operation *definingOp = getSource().getDefiningOp();

  // Fold dim to the operand of tensor.generate.
  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        llvm::cast<RankedTensorType>(fromElements.getResult().getType());
    // The case where the type encodes the size of the dimension is handled
    // above.
    assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));

    // Find the operand of the fromElements that corresponds to this index.
    auto dynExtents = fromElements.getDynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (ShapedType::isDynamic(dim))
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
    // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
    // `resolve-shaped-type-result-dims` pass.
    if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
        sliceOp.isDynamicSize(unsignedIndex)) {
      return {sliceOp.getDynamicSize(unsignedIndex)};
    }
  }

  // dim(cast) -> dim
  if (succeeded(foldTensorCast(*this)))
    return getResult();

  return {};
}

namespace {
/// Fold dim of a cast into the dim of the source of the tensor cast.
struct DimOfCastOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
    return success();
  }
};

/// Fold dim of a destination passing style op into the dim of the corresponding
/// init.
struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto source = dimOp.getSource();
    auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
    if (!destOp)
      return failure();

    auto resultIndex = cast<OpResult>(source).getResultNumber();
    auto *initOperand = destOp.getDpsInitOperand(resultIndex);

    rewriter.modifyOpInPlace(
        dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
    return success();
  }
};

/// Fold dim of a tensor reshape operation to an extract into the reshape's
/// shape operand.
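///
/// Illustrative example:
///
///   %r = tensor.reshape %src(%shape)
///       : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
///   %d = tensor.dim %r, %c1 : tensor<?x?xf32>
///
/// becomes an extract from the shape operand:
///
///   %d = tensor.extract %shape[%c1] : tensor<2xindex>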
struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return failure();

    // Since tensors are immutable we don't need to worry about where to place
    // the extract call
    rewriter.setInsertionPointAfter(dim);
    Location loc = dim.getLoc();
    Value extract =
        rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
    if (extract.getType() != dim.getType())
      extract =
          rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
    rewriter.replaceOp(dim, extract);
    return success();
  }
};
} // namespace

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
}

//===----------------------------------------------------------------------===//
// EmptyOp
//===----------------------------------------------------------------------===//

void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<int64_t> staticShape, Type elementType,
                    Attribute encoding) {
  assert(none_of(staticShape, ShapedType::isDynamic) &&
         "expected only static sizes");
  build(builder, result, staticShape, elementType, ValueRange{}, encoding);
}

void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<int64_t> staticShape, Type elementType,
                    ValueRange dynamicSizes, Attribute encoding) {
  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
  build(builder, result, tensorType, dynamicSizes);
}

void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<OpFoldResult> sizes, Type elementType,
                    Attribute encoding) {
  SmallVector<int64_t> staticShape;
  SmallVector<Value> dynamicSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
}

LogicalResult EmptyOp::verify() {
  if (getType().getNumDynamicDims() != getDynamicSizes().size())
    return emitOpError("incorrect number of dynamic sizes, has ")
           << getDynamicSizes().size() << ", expected "
           << getType().getNumDynamicDims();
  return success();
}

LogicalResult
EmptyOp::reifyResultShapes(OpBuilder &builder,
                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  unsigned ctr = 0;
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
    } else {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    }
  }
  return success();
}

Value EmptyOp::getDynamicSize(unsigned idx) {
  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
  unsigned ctr = 0;
  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
    if (getType().isDynamicDim(i))
      ++ctr;
  return getDynamicSizes()[ctr];
}

SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
  SmallVector<OpFoldResult> result;
  unsigned ctr = 0;
  OpBuilder b(getContext());
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      result.push_back(getDynamicSizes()[ctr++]);
    } else {
      result.push_back(b.getIndexAttr(getType().getShape()[i]));
    }
  }
  return result;
}

namespace {
/// Change the type of the result of a `tensor.empty` by making the result
/// type statically sized along dimensions that in the original operation were
/// defined as dynamic, but the size was defined using a `constant` op. For
/// example
///
/// %c5 = arith.constant 5: index
/// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
///
/// to
///
/// %0 = tensor.empty(%arg0) : tensor<?x5xf32>
struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
  using OpRewritePattern<EmptyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(EmptyOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> foldedDynamicSizes;
    RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
        op.getType(), op.getDynamicSizes(), foldedDynamicSizes);

    // Stop here if no dynamic size was promoted to static.
    if (foldedTensorType == op.getType())
      return failure();

    auto newOp = rewriter.create<EmptyOp>(op.getLoc(), foldedTensorType,
                                          foldedDynamicSizes);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

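/// Fold `tensor.dim` of a `tensor.empty` to the corresponding dynamic size
/// operand. Illustrative example: querying dimension 1 of
/// `tensor.empty(%d0, %d1) : tensor<?x?xf32>` folds to `%d1`.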
struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
    if (!emptyTensorOp || !maybeConstantIndex)
      return failure();
    auto emptyTensorType = emptyTensorOp.getType();
    if (*maybeConstantIndex < 0 ||
        *maybeConstantIndex >= emptyTensorType.getRank() ||
        !emptyTensorType.isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(dimOp,
                       emptyTensorOp.getDynamicSize(*maybeConstantIndex));
    return success();
  }
};

/// Canonicalize
///
/// ```mlir
/// %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
/// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
/// ```
///
/// into
///
/// ```mlir
/// %0 = tensor.empty(%d1) : tensor<4x?xf32>
/// ```
///
/// This assumes the input program is correct in terms of its shape. So it is
/// safe to assume that `%d0` is in fact 4.
struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp castOp,
                                PatternRewriter &rewriter) const override {
    if (!canFoldIntoProducerOp(castOp))
      return failure();
    auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
    if (!producer)
      return failure();

    auto resultType =
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    ArrayRef<int64_t> resultShape = resultType.getShape();
    SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
    SmallVector<OpFoldResult> newMixedSizes;
    newMixedSizes.reserve(currMixedSizes.size());
    assert(resultShape.size() == currMixedSizes.size() &&
           "mismatch in result shape and sizes of empty op");
    for (auto it : llvm::zip(resultShape, currMixedSizes)) {
      int64_t newDim = std::get<0>(it);
      OpFoldResult currDim = std::get<1>(it);
      // Case 1: The empty tensor dim is static. Check that the tensor cast
      // result dim matches.
      if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
        if (ShapedType::isDynamic(newDim) ||
            newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
          // Something is off, the cast result shape cannot be more dynamic
          // than the empty tensor result shape (enforced by
          // `canFoldIntoProducer`). Abort for now.
          return rewriter.notifyMatchFailure(
              producer, "mismatch in static value of shape of empty tensor "
                        "result and cast result");
        }
        newMixedSizes.push_back(attr);
        continue;
      }

      // Case 2 : The tensor cast shape is static, but empty tensor result
      // shape is dynamic.
      if (!ShapedType::isDynamic(newDim)) {
        newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
        continue;
      }

      // Case 3 : The tensor cast shape is dynamic and empty tensor result
      // shape is dynamic. Use the dynamic value from the empty tensor op.
      newMixedSizes.push_back(currDim);
    }

    // TODO: Do not drop tensor encoding.
    rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
                                         resultType.getElementType());
    return success();
  }
};

} // namespace

void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
              ReplaceEmptyTensorStaticShapeDims>(context);
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

namespace {

/// Canonicalizes the pattern of the form
///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();
    if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
      return failure();
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extract, tensorCast.getSource(), extract.getIndices());
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
///   tensor<12xf64>
/// %extracted_element = tensor.extract %val[%c10] :
///   tensor<12xf64>
///
/// to
///
/// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const final {
    auto collapseOp =
        extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return failure();
    if (!collapseOp.getSrcType().hasStaticShape())
      return failure();

    auto sourceSizes = collapseOp.getSrcType().getShape();

    SmallVector<Value> indices(extractOp.getIndices().begin(),
                               extractOp.getIndices().end());
    SmallVector<Value> sourceIndices;
    for (auto [index, group] :
         llvm::zip(indices, collapseOp.getReassociationIndices())) {
      assert(!group.empty() && "association indices groups cannot be empty");
      auto groupSize = group.size();

      if (groupSize == 1) {
        sourceIndices.push_back(index);
        continue;
      }

      SmallVector<int64_t> basis =
          llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
      auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
          extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
      llvm::append_range(sourceIndices, delinearize.getResults());
    }
    if (collapseOp.getReassociationIndices().empty()) {
      auto zeroAffineMap = rewriter.getConstantAffineMap(0);
      int64_t srcRank =
          cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(), zeroAffineMap,
          ArrayRef<OpFoldResult>{});
      for (int64_t i = 0; i < srcRank; i++) {
        sourceIndices.push_back(
            getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extractOp, collapseOp.getSrc(), sourceIndices);
    return success();
  }
};

} // namespace

void ExtractOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted");
}

LogicalResult ExtractOp::verify() {
  // Verify the # indices match if we have a ranked type.
  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices for extract_element");
  return success();
}

/// If we have an ExtractOp consuming an InsertOp with the same
/// indices, we can return the InsertOp's scalar directly.
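///
/// Illustrative example: `tensor.extract %t[%i, %j]` folds to `%s` when
/// `%t = tensor.insert %s into %dest[%i, %j]`.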
// TODO: This only checks the immediate producer; extend to go up the
// insert/extract chain if the slices are disjoint.
static Value foldExtractAfterInsert(ExtractOp extractOp) {
  auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();

  auto isSame = [](Value a, Value b) {
    return getAsOpFoldResult(a) == getAsOpFoldResult(b);
  };
  if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
      llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
    return insertOp.getScalar();

  return {};
}

OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  if (Attribute tensor = adaptor.getTensor()) {
    // If this is a splat elements attribute, simply return the value.
    // All of the elements of a splat attribute are the same.
    if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
      return splatTensor.getSplatValue<Attribute>();

    // If this is a dense resource elements attribute, return.
    if (isa<DenseResourceElementsAttr>(tensor))
      return {};
  }

  // Collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute indice : adaptor.getIndices()) {
    if (!indice || !llvm::isa<IntegerAttr>(indice))
      return {};
    indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
  }

  // Fold extract(from_elements(...)).
  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
    auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
    auto rank = tensorType.getRank();
    assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
           "rank mismatch");
    int flatIndex = 0;
    int stride = 1;
    for (int i = rank - 1; i >= 0; --i) {
      flatIndex += indices[i] * stride;
      stride *= tensorType.getDimSize(i);
    }
    // Prevent out of bounds accesses. This can happen in invalid code that
    // will never execute.
    if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
        flatIndex < 0)
      return {};
    return fromElementsOp.getElements()[flatIndex];
  }

  // If this is an elements attribute, query the value at the given indices.
  if (Attribute tensor = adaptor.getTensor()) {
    auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
    if (elementsAttr && elementsAttr.isValidIndex(indices))
      return elementsAttr.getValues<Attribute>()[indices];
  }

  if (Value result = foldExtractAfterInsert(*this))
    return result;

  return {};
}

void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractFromTensorCast>(context);
}

void mlir::tensor::populateFoldCollapseExtractPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ExtractFromCollapseShape>(patterns.getContext());
}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//

void FromElementsOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "from_elements");
}

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           ValueRange elements) {
  assert(!elements.empty() && "expected at least one element");
  Type resultType = RankedTensorType::get(
      {static_cast<int64_t>(elements.size())}, elements.front().getType());
  build(builder, result, resultType, elements);
}

OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
  if (!llvm::is_contained(adaptor.getElements(), nullptr))
    return DenseElementsAttr::get(getType(), adaptor.getElements());
  return {};
}

namespace {

// Pushes the index_casts that occur before extractions to after the extract.
// This minimizes type conversion in some cases and enables the extract
// canonicalizer. This changes:
//
// %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
// %extract = tensor.extract %cast[%index] : tensor<1xindex>
//
// to the following:
//
// %extract = tensor.extract %tensor[%index] : tensor<1xi32>
// %cast = arith.index_cast %extract : i32 to index
//
// so that the cast now applies to the extracted scalar rather than the whole
// tensor.
//
// Consider expanding this to a template and handle all tensor cast
// operations.
struct ExtractElementFromIndexCast
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    Location loc = extract.getLoc();
    auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
    if (!indexCast)
      return failure();

    Type elementTy = getElementTypeOrSelf(indexCast.getIn());

    auto newExtract = rewriter.create<tensor::ExtractOp>(
        loc, elementTy, indexCast.getIn(), extract.getIndices());

    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
                                                    newExtract);

    return success();
  }
};

} // namespace

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<ExtractElementFromIndexCast>(context);
}

//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//

void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
}

/// Return the inferred result type for a gatherOp where:
/// - sourceType is the type of the source tensor gathered from
/// - indicesType is the type of the indices used to gather
/// - gatherDims are the dims along which the gather occurs.
/// Return a full rank or rank-reduced variant of the type depending on
/// the value of rankReduced.
///
/// The leading dimensions of the index tensor give the result tensor its
/// leading dimensions.
/// The trailing dimensions of the result tensor are obtained from the source
/// tensor by setting the dimensions specified in gather_dims to `1` (if
/// rankReduced is false), or skipping them (otherwise).
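///
/// Illustrative example: with sourceType = tensor<4x5x6xf32>,
/// indicesType = tensor<7x2xindex> and gatherDims = [0, 2], the inferred type
/// is tensor<7x1x5x1xf32>, or tensor<7x5xf32> when rankReduced is true.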
1542RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1543 RankedTensorType indicesType,
1544 ArrayRef<int64_t> gatherDims,
1545 bool rankReduced) {
1546 SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
1547 resultShape.reserve(resultShape.size() + sourceType.getRank());
1548 for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1549 if (llvm::binary_search(gatherDims, idx)) {
1550 if (!rankReduced)
1551 resultShape.push_back(1);
1552 continue;
1553 }
1554 resultShape.push_back(sourceType.getDimSize(idx));
1555 }
1556 return RankedTensorType::Builder(sourceType).setShape(resultShape);
1557}
1558
1559static LogicalResult
1560verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
1561 ArrayRef<int64_t> indices, int64_t rank,
1562 StringRef gatherOrScatter, StringRef sourceOrDest) {
  if (dims.empty())
    return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";

  int64_t numGatherDims = dims.size();
  if (numGatherDims > rank)
    return op->emitOpError(gatherOrScatter)
           << "_dims overflow " << sourceOrDest << " rank";
  if (indices.empty() || indices.back() != numGatherDims)
    return op->emitOpError(gatherOrScatter)
           << "_dims length must match the size of last dimension of indices";
  for (int64_t val : dims) {
    if (val < 0)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be non-negative";
    if (val >= rank)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be smaller than " << sourceOrDest << " rank";
  }
  for (int64_t i = 1; i < numGatherDims; ++i) {
    if (dims[i - 1] >= dims[i])
      return op->emitOpError(gatherOrScatter)
             << "_dims values must be strictly increasing";
  }
1586 return success();
1587}
1588
1589LogicalResult GatherOp::verify() {
1590 int64_t sourceRank = getSourceType().getRank();
1591 ArrayRef<int64_t> gatherDims = getGatherDims();
1592 if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
1593 getIndicesType().getShape(), sourceRank,
1594 "gather", "source")))
1595 return failure();
1596
1597 RankedTensorType expectedResultType = GatherOp::inferResultType(
1598 getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
1599 RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1600 getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
1601 if (getResultType() != expectedResultType &&
1602 getResultType() != expectedRankReducedResultType) {
1603 return emitOpError("result type "
1604 "mismatch: "
1605 "expected ")
1606 << expectedResultType << " or its rank-reduced variant "
1607 << expectedRankReducedResultType << " (got: " << getResultType()
1608 << ")";
1609 }
1610
1611 return success();
1612}
1613
1614OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1615 if (OpFoldResult reshapedSource = reshapeConstantSource(
1616 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1617 getResult().getType()))
1618 return reshapedSource;
1619 return {};
1620}
1621
1622//===----------------------------------------------------------------------===//
1623// InsertOp
1624//===----------------------------------------------------------------------===//
1625
1626void InsertOp::getAsmResultNames(
1627 function_ref<void(Value, StringRef)> setNameFn) {
1628 setNameFn(getResult(), "inserted");
1629}
1630
1631LogicalResult InsertOp::verify() {
1632 // Verify the # indices match if we have a ranked type.
1633 auto destType = llvm::cast<RankedTensorType>(getDest().getType());
1634 if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1635 return emitOpError("incorrect number of indices");
1636 return success();
1637}
1638
1639OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1640 Attribute scalar = adaptor.getScalar();
1641 Attribute dest = adaptor.getDest();
1642 if (scalar && dest)
1643 if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1644 if (scalar == splatDest.getSplatValue<Attribute>())
1645 return dest;
1646 return {};
1647}
1648
1649//===----------------------------------------------------------------------===//
1650// GenerateOp
1651//===----------------------------------------------------------------------===//
1652
1653void GenerateOp::getAsmResultNames(
1654 function_ref<void(Value, StringRef)> setNameFn) {
1655 setNameFn(getResult(), "generated");
1656}
1657
1658LogicalResult GenerateOp::reifyResultShapes(
1659 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1660 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
1661 int idx = 0;
1662 for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1663 if (getType().isDynamicDim(dim)) {
1664 reifiedReturnShapes[0][dim] = getOperand(idx++);
1665 } else {
1666 reifiedReturnShapes[0][dim] =
1667 builder.getIndexAttr(getType().getDimSize(dim));
1668 }
1669 }
1670 return success();
1671}
1672
1673LogicalResult GenerateOp::verify() {
1674 // Ensure that the tensor type has as many dynamic dimensions as are
1675 // specified by the operands.
1676 RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
1677 if (getNumOperands() != resultType.getNumDynamicDims())
1678 return emitError("must have as many index operands as dynamic extents "
1679 "in the result type");
1680 return success();
1681}
1682
1683LogicalResult GenerateOp::verifyRegions() {
1684 RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
1685 // Ensure that region arguments span the index space.
1686 if (!llvm::all_of(getBody().getArgumentTypes(),
1687 [](Type ty) { return ty.isIndex(); }))
1688 return emitError("all body arguments must be index");
1689 if (getBody().getNumArguments() != resultTy.getRank())
1690 return emitError("must have one body argument per input dimension");
1691
1692 // Ensure that the region yields an element of the right type.
1693 auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1694
1695 if (yieldOp.getValue().getType() != resultTy.getElementType())
1696 return emitOpError(
1697 "body must be terminated with a `yield` operation of the tensor "
1698 "element type");
1699
1700 return success();
1701}
1702
1703void GenerateOp::build(
1704 OpBuilder &b, OperationState &result, Type resultTy,
1705 ValueRange dynamicExtents,
1706 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1707 build(b, result, resultTy, dynamicExtents);
1708
1709 // Build and populate body.
1710 OpBuilder::InsertionGuard guard(b);
1711 Region *bodyRegion = result.regions.front().get();
1712 auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1713 SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
1714 SmallVector<Location, 2> argumentLocs(rank, result.location);
1715 Block *bodyBlock =
1716 b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
1717 bodyBuilder(b, result.location, bodyBlock->getArguments());
1718}
1719
1720namespace {
1721
/// Canonicalizes tensor.generate operations whose dynamic extent operands are
/// constants into the equivalent operation with those extents folded into a
/// more static result type. A tensor.cast back to the original type is
/// inserted so that the resulting IR remains well-typed.
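///
/// Illustrative sketch (values and body assumed for exposition):
///
///   %c10 = arith.constant 10 : index
///   %0 = tensor.generate %c10 {
///   ^bb0(%i : index):
///     ...
///   } : tensor<?xindex>
///
/// becomes:
///
///   %0 = tensor.generate {
///   ^bb0(%i : index):
///     ...
///   } : tensor<10xindex>
///   %1 = tensor.cast %0 : tensor<10xindex> to tensor<?xindex>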
1726struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
1727 using OpRewritePattern<GenerateOp>::OpRewritePattern;
1728
1729 LogicalResult matchAndRewrite(GenerateOp generateOp,
1730 PatternRewriter &rewriter) const final {
1731 SmallVector<Value> foldedDynamicSizes;
1732 RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1733 generateOp.getType(), generateOp.getDynamicExtents(),
1734 foldedDynamicSizes);
1735
1736 // Stop here if no dynamic size was promoted to static.
1737 if (foldedTensorType == generateOp.getType())
1738 return failure();
1739
1740 auto loc = generateOp.getLoc();
1741 auto newOp =
1742 rewriter.create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes);
1743 rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
1744 newOp.getBody().begin());
1745 rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
1746 generateOp.getType(), newOp);
1747 return success();
1748 }
1749};
1750
1751/// Canonicalizes the pattern of the form
1752///
1753/// %tensor = tensor.generate %x {
1754/// ^bb0(%arg0: index):
1755/// <computation>
1756/// yield %1 : index
1757/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
1759///
1760/// to just <computation> with %arg0 replaced by %c0. We only do this if the
1761/// tensor.generate operation has no side-effects.
1762struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
1763 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1764
1765 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1766 PatternRewriter &rewriter) const final {
1767 auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1768 if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
1769 return failure();
1770
1771 IRMapping mapping;
1772 Block *body = &tensorFromElements.getBody().front();
1773 mapping.map(body->getArguments(), extract.getIndices());
1774 for (auto &op : body->without_terminator())
1775 rewriter.clone(op, mapping);
1776
1777 auto yield = cast<YieldOp>(body->getTerminator());
1778
1779 rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
1780 return success();
1781 }
1782};
1783
1784} // namespace
1785
1786void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1787 MLIRContext *context) {
1788 // TODO: Move extract pattern to tensor::ExtractOp.
1789 results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1790}
1791
1792//===----------------------------------------------------------------------===//
1793// RankOp
1794//===----------------------------------------------------------------------===//
1795
1796void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1797 setNameFn(getResult(), "rank");
1798}
1799
1800OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1801 // Constant fold rank when the rank of the operand is known.
1802 auto type = getOperand().getType();
1803 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1804 if (shapedType && shapedType.hasRank())
1805 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1806 return IntegerAttr();
1807}
1808
1809//===----------------------------------------------------------------------===//
1810// ReshapeOp
1811//===----------------------------------------------------------------------===//
1812
1813void ReshapeOp::getAsmResultNames(
1814 function_ref<void(Value, StringRef)> setNameFn) {
1815 setNameFn(getResult(), "reshape");
1816}
1817
1818static int64_t getNumElements(ShapedType type) {
1819 int64_t numElements = 1;
1820 for (auto dim : type.getShape())
1821 numElements *= dim;
1822 return numElements;
1823}
1824
1825LogicalResult ReshapeOp::verify() {
1826 TensorType operandType = llvm::cast<TensorType>(getSource().getType());
1827 TensorType resultType = llvm::cast<TensorType>(getResult().getType());
1828
1829 if (operandType.getElementType() != resultType.getElementType())
1830 return emitOpError("element types of source and destination tensor "
1831 "types should be the same");
1832
1833 int64_t shapeSize =
1834 llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
1835 auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1836 auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1837
1838 if (resultRankedType) {
1839 if (operandRankedType && resultRankedType.hasStaticShape() &&
1840 operandRankedType.hasStaticShape()) {
1841 if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
1842 return emitOpError("source and destination tensor should have the "
1843 "same number of elements");
1844 }
1845 if (ShapedType::isDynamic(shapeSize))
1846 return emitOpError("cannot use shape operand with dynamic length to "
1847 "reshape to statically-ranked tensor type");
1848 if (shapeSize != resultRankedType.getRank())
1849 return emitOpError(
1850 "length of shape operand differs from the result's tensor rank");
1851 }
1852 return success();
1853}
1854
1855OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1856 if (OpFoldResult reshapedSource = reshapeConstantSource(
1857 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1858 getResult().getType()))
1859 return reshapedSource;
1860
1861 // If the producer of operand 'source' is another 'tensor.reshape' op, use the
1862 // producer's input instead as the original tensor to reshape. This could
1863 // render such producer dead code.
1864 if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
1865 getSourceMutable().assign(reshapeOpProducer.getSource());
1866 return getResult();
1867 }
1868
1869 auto source = getSource();
1870 auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
1871 auto resultTy = dyn_cast<RankedTensorType>(getType());
1872 if (!sourceTy || !resultTy || sourceTy != resultTy)
1873 return {};
1874
1875 // If the source and result are both 1D tensors and have the same type, the
1876 // reshape has no effect, even if the tensor is dynamically shaped.
1877 if (sourceTy.getRank() == 1)
1878 return source;
1879
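  // The reshape is likewise a no-op when the shape operand comes from a
  // tensor.from_elements whose elements all match the source dimensions,
  // either as constants equal to the static sizes or as tensor.dim of the
  // source at the same index. Illustrative sketch (types and constants
  // assumed):
  //
  //   %d0 = tensor.dim %src, %c0 : tensor<?x4xf32>
  //   %shape = tensor.from_elements %d0, %c4 : tensor<2xindex>
  //   %r = tensor.reshape %src(%shape)
  //        : (tensor<?x4xf32>, tensor<2xindex>) -> tensor<?x4xf32>
  //
  // folds to %src.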
1880 if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
1881 auto elements = fromElements.getElements();
1882 bool dynamicNoop =
1883 sourceTy.getRank() == static_cast<int64_t>(elements.size());
1884 for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
1885 auto element = elements[id];
1886
1887 if (auto cst = getConstantIntValue(element)) {
1888 dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
1889 continue;
1890 }
1891
1892 if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
1893 dynamicNoop &= dimOp.getSource() == source;
1894
1895 auto cst = getConstantIntValue(dimOp.getIndex());
1896 dynamicNoop &=
1897 cst.has_value() && cst.value() == static_cast<int64_t>(id);
1898 continue;
1899 }
1900
1901 dynamicNoop = false;
1902 break;
1903 }
1904
1905 if (dynamicNoop)
1906 return source;
1907 }
1908
1909 return {};
1910}
1911
1912//===----------------------------------------------------------------------===//
1913// Reassociative reshape ops
1914//===----------------------------------------------------------------------===//
1915
1916void CollapseShapeOp::getAsmResultNames(
1917 function_ref<void(Value, StringRef)> setNameFn) {
1918 setNameFn(getResult(), "collapsed");
1919}
1920
1921void ExpandShapeOp::getAsmResultNames(
1922 function_ref<void(Value, StringRef)> setNameFn) {
1923 setNameFn(getResult(), "expanded");
1924}
1925
1926int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1927 assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1928 "invalid resultDim");
1929 for (const auto &it : llvm::enumerate(getReassociationIndices()))
1930 if (llvm::is_contained(it.value(), resultDim))
1931 return it.index();
1932 llvm_unreachable("could not find reassociation group");
1933}
1934
1935FailureOr<SmallVector<OpFoldResult>>
1936ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
1937 RankedTensorType expandedType,
1938 ArrayRef<ReassociationIndices> reassociation,
1939 ArrayRef<OpFoldResult> inputShape) {
1940 std::optional<SmallVector<OpFoldResult>> outputShape =
1941 inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
1942 inputShape);
1943 if (!outputShape)
1944 return failure();
1945 return *outputShape;
1946}
1947
1948SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
1949 return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
1950}
1951
1952void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1953 Type resultType, Value src,
1954 ArrayRef<ReassociationIndices> reassociation,
1955 ArrayRef<OpFoldResult> outputShape) {
1956 auto [staticOutputShape, dynamicOutputShape] =
1957 decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
1958 build(builder, result, cast<RankedTensorType>(resultType), src,
1959 getReassociationIndicesAttribute(builder, reassociation),
1960 dynamicOutputShape, staticOutputShape);
1961}
1962
1963void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1964 Type resultType, Value src,
1965 ArrayRef<ReassociationIndices> reassociation) {
1966 SmallVector<OpFoldResult> inputShape =
1967 getMixedSizes(builder, result.location, src);
1968 auto tensorResultTy = cast<RankedTensorType>(resultType);
1969 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
1970 builder, result.location, tensorResultTy, reassociation, inputShape);
1971 SmallVector<OpFoldResult> outputShapeOrEmpty;
1972 if (succeeded(outputShape)) {
1973 outputShapeOrEmpty = *outputShape;
1974 }
1975 build(builder, result, tensorResultTy, src, reassociation,
1976 outputShapeOrEmpty);
1977}
1978
1979SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1980 return getSymbolLessAffineMaps(getReassociationExprs());
1981}
1982SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1983 return convertReassociationIndicesToExprs(getContext(),
1984 getReassociationIndices());
1985}
1986
1987SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1988 return getSymbolLessAffineMaps(getReassociationExprs());
1989}
1990SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1991 return convertReassociationIndicesToExprs(getContext(),
1992 getReassociationIndices());
1993}
1994
1995RankedTensorType CollapseShapeOp::inferCollapsedType(
1996 RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
1997 return inferCollapsedType(
1998 type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1999 type.getContext(), reassociation)));
2000}
2001
2002/// Compute the RankedTensorType obtained by applying `reassociation` to
2003/// `type`.
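/// Illustrative example: collapsing tensor<3x4x5xf32> with reassociation
/// [[0, 1], [2]] yields tensor<12x5xf32>; if any dimension in a group is
/// dynamic, the collapsed dimension is dynamic as well (e.g. tensor<?x4x5xf32>
/// yields tensor<?x5xf32>).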
2004RankedTensorType
2005CollapseShapeOp::inferCollapsedType(RankedTensorType type,
2006 ArrayRef<AffineMap> reassociation) {
2007 auto shape = type.getShape();
2008 SmallVector<int64_t, 4> newShape;
2009 newShape.reserve(reassociation.size());
2010
2011 // Use the fact that reassociation is valid to simplify the logic: only use
2012 // each map's rank.
2013 assert(isReassociationValid(reassociation) && "invalid reassociation");
2014 unsigned currentDim = 0;
2015 for (AffineMap m : reassociation) {
2016 unsigned dim = m.getNumResults();
2017 auto band = shape.slice(currentDim, dim);
2018 int64_t size = 1;
2019 if (llvm::is_contained(band, ShapedType::kDynamic))
2020 size = ShapedType::kDynamic;
2021 else
2022 for (unsigned d = 0; d < dim; ++d)
2023 size *= shape[currentDim + d];
2024 newShape.push_back(size);
2025 currentDim += dim;
2026 }
2027
2028 return RankedTensorType::get(newShape, type.getElementType());
2029}
2030
2031void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2032 ArrayRef<ReassociationIndices> reassociation,
2033 ArrayRef<NamedAttribute> attrs) {
2034 auto resultType = inferCollapsedType(
2035 llvm::cast<RankedTensorType>(src.getType()),
2036 getSymbolLessAffineMaps(
2037 convertReassociationIndicesToExprs(b.getContext(), reassociation)));
2038 result.addAttribute(getReassociationAttrStrName(),
2039 getReassociationIndicesAttribute(b, reassociation));
2040 build(b, result, resultType, src, attrs);
2041}
2042
2043template <typename TensorReshapeOp, bool isExpansion = std::is_same<
2044 TensorReshapeOp, ExpandShapeOp>::value>
2045static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
2046 RankedTensorType expandedType,
2047 RankedTensorType collapsedType) {
2048 if (failed(
2049 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
2050 return failure();
2051
2052 auto maps = op.getReassociationMaps();
2053 RankedTensorType expectedType =
2054 CollapseShapeOp::inferCollapsedType(expandedType, maps);
2055 if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
2056 return op.emitOpError("expected collapsed type to be ")
2057 << expectedType << ", but got " << collapsedType;
2058 return success();
2059}
2060
2061LogicalResult ExpandShapeOp::verify() {
2062 auto srcType = getSrcType();
2063 auto resultType = getResultType();
2064
2065 if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2066 return emitOpError("expected number of static shape dims to be equal to "
2067 "the output rank (")
2068 << resultType.getRank() << ") but found "
2069 << getStaticOutputShape().size() << " inputs instead";
2070
2071 if ((int64_t)getOutputShape().size() !=
2072 llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2073 return emitOpError("mismatch in dynamic dims in output_shape and "
2074 "static_output_shape: static_output_shape has ")
2075 << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2076 << " dynamic dims while output_shape has " << getOutputShape().size()
2077 << " values";
2078
2079 return verifyTensorReshapeOp(*this, resultType, srcType);
2080}
2081
2082LogicalResult CollapseShapeOp::verify() {
2083 return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
2084}
2085
2086namespace {
2087/// Reshape of a splat constant can be replaced with a constant of the result
2088/// type.
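///
/// Illustrative example (shapes assumed): the splat constant in
///
///   %cst = arith.constant dense<1.0> : tensor<2x4xf32>
///   %0 = tensor.collapse_shape %cst [[0, 1]]
///       : tensor<2x4xf32> into tensor<8xf32>
///
/// can be replaced with arith.constant dense<1.0> : tensor<8xf32>.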
2089template <typename TensorReshapeOp>
2090struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
2091 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2092 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2093 PatternRewriter &rewriter) const override {
2094 DenseElementsAttr attr;
    if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
2096 return failure();
2097 if (!attr || !attr.isSplat())
2098 return failure();
2099 DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
2100 reshapeOp.getResultType(), attr.getRawData());
2101 rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
2102 return success();
2103 }
2104};
2105
2106// Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
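//
// Illustrative sketch (shapes assumed):
//
//   %s = tensor.splat %v : tensor<6xf32>
//   %0 = tensor.expand_shape %s [[0, 1]] output_shape [2, 3]
//       : tensor<6xf32> into tensor<2x3xf32>
//
// becomes tensor.splat %v : tensor<2x3xf32>.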
2107template <typename TensorReshapeOp>
2108class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
2109public:
2110 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2111
2112 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2113 PatternRewriter &rewriter) const override {
2114 auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
2115 if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
2116 return failure();
2117
2118 rewriter.replaceOpWithNewOp<tensor::SplatOp>(
2119 reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
2120 return success();
2121 }
2122};
2123
2124/// Reshape of a FromElements can be replaced with a FromElements of the
2125/// result type
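///
/// Illustrative example (shapes assumed):
///
///   %0 = tensor.from_elements %a, %b, %c, %d : tensor<4xf32>
///   %1 = tensor.expand_shape %0 [[0, 1]] output_shape [2, 2]
///       : tensor<4xf32> into tensor<2x2xf32>
///
/// becomes tensor.from_elements %a, %b, %c, %d : tensor<2x2xf32>.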
2126template <typename TensorReshapeOp>
2127struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
2128 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2129 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2130 PatternRewriter &rewriter) const override {
2131 auto fromElements =
2132 reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
2133 if (!fromElements)
2134 return failure();
2135
2136 auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
2137
2138 if (!shapedTy.hasStaticShape())
2139 return failure();
2140
2141 rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
2142 fromElements.getElements());
2143 return success();
2144 }
2145};
2146
2147// Fold CastOp into CollapseShapeOp when adding static information.
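//
// Illustrative example (shapes assumed):
//
//   %0 = tensor.cast %src : tensor<4x4xf32> to tensor<?x?xf32>
//   %1 = tensor.collapse_shape %0 [[0, 1]]
//       : tensor<?x?xf32> into tensor<?xf32>
//
// becomes:
//
//   %1 = tensor.collapse_shape %src [[0, 1]]
//       : tensor<4x4xf32> into tensor<16xf32>
//   %2 = tensor.cast %1 : tensor<16xf32> to tensor<?xf32>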
2148struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
2149 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2150
2151 LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
2152 PatternRewriter &rewriter) const override {
2153 auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
2154 if (!tensor::canFoldIntoConsumerOp(castOp))
2155 return failure();
2156
2157 RankedTensorType srcType =
2158 llvm::cast<RankedTensorType>(castOp.getSource().getType());
2159 RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
2160 srcType, collapseShapeOp.getReassociationMaps());
2161
2162 if (newResultType == collapseShapeOp.getResultType()) {
2163 rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
2164 collapseShapeOp.getSrcMutable().assign(castOp.getSource());
2165 });
2166 } else {
2167 auto newOp = rewriter.create<CollapseShapeOp>(
2168 collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
2169 collapseShapeOp.getReassociation());
2170 rewriter.replaceOpWithNewOp<tensor::CastOp>(
2171 collapseShapeOp, collapseShapeOp.getResultType(), newOp);
2172 }
2173 return success();
2174 }
2175};
2176
2177/// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
2178/// matching constant output_shape operands of the expand. This makes the
2179/// `tensor.expand_shape` more static and creates a consumer cast that can be
2180/// propagated further.
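///
/// Illustrative sketch (shapes assumed): with %c2 and %c4 constant indices,
///
///   %cast = tensor.cast %src : tensor<8xf32> to tensor<?xf32>
///   %e = tensor.expand_shape %cast [[0, 1]] output_shape [%c2, %c4]
///       : tensor<?xf32> into tensor<?x?xf32>
///
/// is rewritten so that the expand operates on a statically shaped input and
/// produces tensor<2x4xf32>, followed by a tensor.cast back to the original
/// tensor<?x?xf32> result type.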
2181struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
2182 using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
2183
2184 LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
2185 PatternRewriter &rewriter) const override {
2186 auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
2187 if (!canFoldIntoConsumerOp(castOp))
2188 return failure();
2189
2190 ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
2191 SmallVector<ReassociationIndices, 4> reassoc =
2192 expandOp.getReassociationIndices();
2193
2194 SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
2195 SmallVector<Value> dynamicOutputShape;
2196 auto outputIt = expandOp.getOutputShape().begin();
2197
2198 for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
2199 for (uint64_t outDim : innerReassoc) {
2200 if (!ShapedType::isDynamic(newOutputShape[outDim]))
2201 continue;
2202
2203 // If the cast's src type is dynamic, don't infer any of the
2204 // corresponding expanded dimensions. `tensor.expand_shape` requires at
2205 // least one of the expanded dimensions to be dynamic if the input is
2206 // dynamic.
2207 Value val = *outputIt;
2208 ++outputIt;
2209 if (ShapedType::isDynamic(castSrcShape[inputDim])) {
2210 dynamicOutputShape.push_back(val);
2211 continue;
2212 }
2213
2214 APInt cst;
2215 if (matchPattern(val, m_ConstantInt(&cst))) {
2216 newOutputShape[outDim] = cst.getSExtValue();
2217 } else {
2218 dynamicOutputShape.push_back(val);
2219 }
2220 }
2221 }
2222
2223 // Couldn't match any values, nothing to change
2224 if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
2225 return failure();
2226
2227 // Calculate the input shape from the output
2228 SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
2229 for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
2230 for (auto outDim : reassoc[inDim]) {
2231 auto ofr = newOutputShape[outDim];
2232 if (ShapedType::isDynamic(ofr)) {
2233 newInputShape[inDim] = ShapedType::kDynamic;
2234 break;
2235 }
2236 newInputShape[inDim] *= ofr;
2237 }
2238 }
2239
2240 SmallVector<OpFoldResult> outputOfr =
        getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
2242 auto inputType = RankedTensorType::get(
2243 newInputShape, expandOp.getSrcType().getElementType());
2244 auto outputType = RankedTensorType::get(
2245 newOutputShape, expandOp.getSrcType().getElementType());
2246 auto inputCast = rewriter.create<CastOp>(expandOp.getLoc(), inputType,
2247 expandOp.getSrc());
2248 auto newExpand = rewriter.create<ExpandShapeOp>(
2249 expandOp.getLoc(), outputType, inputCast.getResult(),
2250 expandOp.getReassociationIndices(), outputOfr);
2251 rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
2252 newExpand.getResult());
2253 return success();
2254 }
2255};
2256} // namespace
2257
2258void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2259 MLIRContext *context) {
2260 results.add<
2261 ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2262 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
2263 ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
2264 FoldReshapeWithSplat<ExpandShapeOp>,
2265 FoldReshapeWithFromElements<ExpandShapeOp>>(context);
2266}
2267
2268void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2269 MLIRContext *context) {
2270 results.add<
2271 ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2272 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2273 tensor::DimOp, RankedTensorType>,
2274 FoldReshapeWithConstant<CollapseShapeOp>,
2275 FoldReshapeWithSplat<CollapseShapeOp>,
2276 FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
2277 context);
2278}
2279
2280OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2281 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2282 adaptor.getOperands());
2283}
2284
2285OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2286 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2287 adaptor.getOperands());
2288}
2289
2290//===----------------------------------------------------------------------===//
2291// ExtractSliceOp
2292//===----------------------------------------------------------------------===//
2293
2294void ExtractSliceOp::getAsmResultNames(
2295 function_ref<void(Value, StringRef)> setNameFn) {
2296 setNameFn(getResult(), "extracted_slice");
2297}
2298
2299/// An extract_slice result type can be inferred, when it is not
2300/// rank-reduced, from the source type and the static representation of
2301/// offsets, sizes and strides. Special sentinels encode the dynamic case.
2302RankedTensorType ExtractSliceOp::inferResultType(
2303 RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
2304 ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
  // An extract_slice op may specify only a leading subset of offsets/sizes/
  // strides, in which case we complete with offset=0, sizes from the source
  // tensor type, and strides=1.
2308 assert(static_cast<int64_t>(staticSizes.size()) ==
2309 sourceTensorType.getRank() &&
2310 "unexpected staticSizes not equal to rank of source");
2311 return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2312 sourceTensorType.getEncoding());
2313}
2314
2315RankedTensorType ExtractSliceOp::inferResultType(
2316 RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
2317 ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
2318 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2319 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2320 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2321 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2322 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2323 return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
2324 staticSizes, staticStrides);
2325}
2326
2327/// If the rank is reduced (i.e. the desiredResultRank is smaller than the
2328/// number of sizes), drop as many size 1 as needed to produce an inferred
2329/// type with the desired rank.
2330///
2331/// Note that there may be multiple ways to compute this rank-reduced type:
2332/// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
2333///
/// To disambiguate, this function always drops the leftmost occurrences of
/// size-1 dimensions first.
2335RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2336 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2337 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2338 ArrayRef<int64_t> strides) {
2339 // Type inferred in the absence of rank-reducing behavior.
2340 auto inferredType = llvm::cast<RankedTensorType>(
2341 inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2342 int rankDiff = inferredType.getRank() - desiredResultRank;
2343 if (rankDiff > 0) {
2344 auto shape = inferredType.getShape();
2345 llvm::SmallBitVector dimsToProject =
2346 getPositionsOfShapeOne(rankDiff, shape);
2347 SmallVector<int64_t> projectedShape;
2348 // Best effort rank-reducing: drop 1s in order.
2349 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2350 if (!dimsToProject.test(pos))
2351 projectedShape.push_back(shape[pos]);
2352 inferredType =
2353 RankedTensorType::get(projectedShape, inferredType.getElementType());
2354 }
2355 return inferredType;
2356}
2357
2358RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2359 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2360 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2361 ArrayRef<OpFoldResult> strides) {
2362 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2363 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2364 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2365 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2366 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2367 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2368 desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
2369 staticStrides);
2370}
2371
2372/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
2373/// result type. If the type passed is nullptr, it is inferred.
2374void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2375 RankedTensorType resultType, Value source,
2376 ArrayRef<OpFoldResult> offsets,
2377 ArrayRef<OpFoldResult> sizes,
2378 ArrayRef<OpFoldResult> strides,
2379 ArrayRef<NamedAttribute> attrs) {
2380 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2381 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2382 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2383 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2384 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2385 auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
2386 // Structuring implementation this way avoids duplication between builders.
2387 if (!resultType) {
2388 resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
2389 sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
2390 }
2391 result.addAttributes(attrs);
2392 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2393 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2394 b.getDenseI64ArrayAttr(staticSizes),
2395 b.getDenseI64ArrayAttr(staticStrides));
2396}
2397
2398/// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
2399/// result type.
2400void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2401 ArrayRef<OpFoldResult> offsets,
2402 ArrayRef<OpFoldResult> sizes,
2403 ArrayRef<OpFoldResult> strides,
2404 ArrayRef<NamedAttribute> attrs) {
2405 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2406}
2407
2408/// Build an ExtractSliceOp with mixed static and dynamic entries packed into
2409/// a Range vector.
2410void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2411 ArrayRef<Range> ranges,
2412 ArrayRef<NamedAttribute> attrs) {
2413 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2414 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2415}
2416
2417/// Build an ExtractSliceOp with dynamic entries and custom result type. If
2418/// the type passed is nullptr, it is inferred.
2419void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2420 RankedTensorType resultType, Value source,
2421 ValueRange offsets, ValueRange sizes,
2422 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2423 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2424 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2425 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2426 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2427 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2428 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2429 build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2430}
2431
2432/// Build an ExtractSliceOp with dynamic entries and inferred result type.
2433void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2434 ValueRange offsets, ValueRange sizes,
2435 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2436 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2437}
2438
2439static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
2440 Operation *op,
2441 RankedTensorType expectedType) {
2442 switch (result) {
2443 case SliceVerificationResult::Success:
2444 return success();
2445 case SliceVerificationResult::RankTooLarge:
    return op->emitError("expected rank to be smaller or equal to ")
           << "the other rank. ";
  case SliceVerificationResult::SizeMismatch:
    return op->emitError("expected type to be ")
           << expectedType << " or a rank-reduced version. (size mismatch) ";
  case SliceVerificationResult::ElemTypeMismatch:
    return op->emitError("expected element type to be ")
           << expectedType.getElementType();
2454 default:
2455 llvm_unreachable("unexpected extract_slice op verification result");
2456 }
2457}
2458
2459/// Verifier for ExtractSliceOp.
2460LogicalResult ExtractSliceOp::verify() {
2461 RankedTensorType sourceType = getSourceType();
2462
2463 // Verify result type against inferred type.
2464 RankedTensorType expectedType = ExtractSliceOp::inferResultType(
2465 sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());
2466 SliceVerificationResult result = isRankReducedType(expectedType, getType());
2467 if (result != SliceVerificationResult::Success)
2468 return produceSliceErrorMsg(result, *this, expectedType);
2469
2470 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2471 // to the source tensor.
2472 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2473 sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
2474 getStaticStrides(), /*generateErrorMessage=*/true);
2475 if (!boundsResult.isValid)
2476 return getOperation()->emitError(boundsResult.errorMessage);
2477
2478 return success();
2479}
2480
2481llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
2482 return ::getDroppedDims(getType().getShape(), getMixedSizes());
2483}
2484
2485FailureOr<Value>
2486ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
2487 ArrayRef<int64_t> desiredShape) {
2488 auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
2489 assert(sourceTensorType && "not a ranked tensor type");
2490 auto sourceShape = sourceTensorType.getShape();
2491 if (sourceShape.equals(desiredShape))
2492 return value;
2493 auto maybeRankReductionMask =
2494 mlir::computeRankReductionMask(sourceShape, desiredShape);
2495 if (!maybeRankReductionMask)
2496 return failure();
2497 return createCanonicalRankReducingExtractSliceOp(
2498 b, loc, value,
2499 RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2500}
2501
2502LogicalResult ExtractSliceOp::reifyResultShapes(
2503 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2504 reifiedReturnShapes.resize(1);
2505 reifiedReturnShapes[0].reserve(getType().getRank());
2506 SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2507 llvm::SmallBitVector droppedDims = getDroppedDims();
2508 for (const auto &size : enumerate(mixedSizes)) {
2509 if (droppedDims.test(size.index()))
2510 continue;
2511 reifiedReturnShapes[0].push_back(size.value());
2512 }
2513 return success();
2514}
2515
2516namespace {
2517/// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
/// This essentially pushes the tensor.cast past its consuming slice when
/// `canFoldIntoConsumerOp` is true.
2520///
2521/// Example:
2522/// ```
2523/// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
2524/// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
2525/// tensor<3x4xf32>
2526/// ```
2527/// is rewritten into:
2528/// ```
/// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
///      tensor<3x4xf32>
/// %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
2531/// ```
2532class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
2533public:
2534 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2535
2536 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2537 PatternRewriter &rewriter) const override {
2538 // Any constant operand, just return to let the constant folder kick in.
2539 if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
2541 }))
2542 return failure();
2543
2544 auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2545 if (!castOp)
2546 return failure();
2547
2548 if (!canFoldIntoConsumerOp(castOp))
2549 return failure();
2550
2551 // Pattern does not apply if the produced op would not verify.
2552 SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
2553 cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
2554 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2555 sliceOp.getStaticStrides());
2556 if (!sliceResult.isValid)
2557 return failure();
2558
2559 // Create folded extract.
2560 Location loc = sliceOp.getLoc();
2561 Value newResult = rewriter.create<ExtractSliceOp>(
2562 loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
2563 sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
2564 sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
2565 rewriter.replaceOp(sliceOp, newResult);
2566 return success();
2567 }
2568};
2569
/// Slice elements from `values` into `outValues`. For each dimension,
/// `counts` holds the number of linearized elements spanned by one index step
/// in that dimension (i.e. the suffix product of the source shape). The
/// output values can be used to construct a DenseElementsAttr.
2573template <typename IterTy, typename ElemTy>
2574static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2575 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2576 ArrayRef<int64_t> strides,
2577 llvm::SmallVectorImpl<ElemTy> *outValues) {
2578 assert(offsets.size() == sizes.size());
2579 assert(offsets.size() == strides.size());
2580 if (offsets.empty())
2581 return;
2582
2583 int64_t offset = offsets.front();
2584 int64_t size = sizes.front();
2585 int64_t stride = strides.front();
2586 if (offsets.size() == 1) {
2587 for (int64_t i = 0; i < size; ++i, offset += stride)
2588 outValues->push_back(*(values + offset));
2589
2590 return;
2591 }
2592
2593 for (int64_t i = 0; i < size; ++i, offset += stride) {
2594 auto begin = values + offset * counts.front();
2595 sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2596 offsets.drop_front(), sizes.drop_front(),
2597 strides.drop_front(), outValues);
2598 }
2599}
2600
2601/// Fold arith.constant and tensor.extract_slice into arith.constant. The
2602/// folded operation might introduce more constant data; Users can control
2603/// their heuristics by the control function.
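///
/// Illustrative example (subject to the control function; values assumed):
///
///   %cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
///   %0 = tensor.extract_slice %cst[1] [2] [1] : tensor<4xi32> to tensor<2xi32>
///
/// folds to arith.constant dense<[2, 3]> : tensor<2xi32>.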
2604class ConstantOpExtractSliceFolder final
2605 : public OpRewritePattern<ExtractSliceOp> {
2606public:
2607 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2608
2609 ConstantOpExtractSliceFolder(MLIRContext *context,
2610 ControlConstantExtractSliceFusionFn controlFn)
2611 : OpRewritePattern<ExtractSliceOp>(context),
2612 controlFn(std::move(controlFn)) {}
2613
2614 LogicalResult matchAndRewrite(ExtractSliceOp op,
2615 PatternRewriter &rewriter) const override {
2616 DenseElementsAttr attr;
    if (!matchPattern(op.getSource(), m_Constant(&attr)))
2618 return failure();
2619
2620 // A constant splat is handled by fold().
2621 if (attr.isSplat())
2622 return failure();
2623
2624 // Dynamic result shape is not supported.
2625 auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2626 auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
2627 if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2628 return failure();
2629
2630 // Customized control over the folding.
2631 if (!controlFn(op))
2632 return failure();
2633
2634 int64_t count = sourceType.getNumElements();
2635 if (count == 0)
2636 return failure();
2637
2638 // Check if there are any dynamic parts, which are not supported.
2639 auto offsets = op.getStaticOffsets();
2640 if (llvm::is_contained(offsets, ShapedType::kDynamic))
2641 return failure();
2642 auto sizes = op.getStaticSizes();
2643 if (llvm::is_contained(sizes, ShapedType::kDynamic))
2644 return failure();
2645 auto strides = op.getStaticStrides();
2646 if (llvm::is_contained(strides, ShapedType::kDynamic))
2647 return failure();
2648
2649 // Compute the stride for each dimension.
2650 SmallVector<int64_t> counts;
2651 ArrayRef<int64_t> shape = sourceType.getShape();
    counts.reserve(shape.size());
2653 for (int64_t v : shape) {
2654 count = count / v;
2655 counts.push_back(count);
2656 }
2657
2658 // New attribute constructed by the sliced values.
2659 DenseElementsAttr newAttr;
2660
2661 if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
2662 SmallVector<APInt> outValues;
      outValues.reserve(sourceType.getNumElements());
2664 sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
2665 elems.begin(), counts, offsets, sizes, strides, &outValues);
2666 newAttr = DenseElementsAttr::get(resultType, outValues);
2667 } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
2668 SmallVector<APFloat> outValues;
      outValues.reserve(sourceType.getNumElements());
2670 sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
2671 elems.begin(), counts, offsets, sizes, strides, &outValues);
2672 newAttr = DenseElementsAttr::get(resultType, outValues);
2673 }
2674
2675 if (newAttr) {
2676 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
2677 return success();
2678 }
2679
2680 return failure();
2681 }
2682
2683private:
2684 /// This additionally controls whether the fold happens or not. Users can
2685 /// impose their heuristics in the function.
2686 ControlConstantExtractSliceFusionFn controlFn;
2687};
2688
2689} // namespace
2690
2691void mlir::tensor::populateFoldConstantExtractSlicePatterns(
2692 RewritePatternSet &patterns,
2693 const ControlConstantExtractSliceFusionFn &controlFn) {
2694 patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
2695}
2696
2697/// Return the canonical type of the result of an extract_slice op.
2698struct SliceReturnTypeCanonicalizer {
2699 RankedTensorType operator()(ExtractSliceOp op,
2700 ArrayRef<OpFoldResult> mixedOffsets,
2701 ArrayRef<OpFoldResult> mixedSizes,
2702 ArrayRef<OpFoldResult> mixedStrides) {
2703 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2704 op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
2705 mixedStrides);
2706 }
2707};
2708
2709/// A canonicalizer wrapper to replace ExtractSliceOps.
2710struct SliceCanonicalizer {
2711 void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
2712 ExtractSliceOp newOp) {
2713 Value replacement = newOp.getResult();
2714 if (replacement.getType() != op.getType())
2715 replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
2716 replacement);
2717 rewriter.replaceOp(op, replacement);
2718 }
2719};
2720
2721void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2722 MLIRContext *context) {
2723 results.add<
2724 OpWithOffsetSizesAndStridesConstantArgumentFolder<
2725 ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2726 ExtractSliceOpCastFolder>(context);
2727}
2728
/// Return success if `op` describes the identity slice of `shapedType`: all
/// offsets are 0, all strides are 1, and each size equals the corresponding
/// dimension of the type.
2730static LogicalResult
2731foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2732 ShapedType shapedType) {
2733 OpBuilder b(op.getContext());
2734 for (OpFoldResult ofr : op.getMixedOffsets())
2735 if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2736 return failure();
2737 // Rank-reducing noops only need to inspect the leading dimensions:
2738 // llvm::zip is appropriate.
2739 auto shape = shapedType.getShape();
2740 for (auto it : llvm::zip(op.getMixedSizes(), shape))
2741 if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
2742 return failure();
2743 for (OpFoldResult ofr : op.getMixedStrides())
2744 if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2745 return failure();
2746 return success();
2747}
2748
2749/// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2750/// slice, we can return the InsertSliceOp's source directly.
2751// TODO: This only checks the immediate producer; extend to go up the
2752// insert/extract chain if the slices are disjoint.
2753static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2754 auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2755
2756 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2757 if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2758 insertOp.isSameAs(extractOp, isSame))
2759 return insertOp.getSource();
2760
2761 return {};
2762}
2763
2764OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2765 if (OpFoldResult reshapedSource = reshapeConstantSource(
2766 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2767 getResult().getType()))
2768 return reshapedSource;
2769 if (getSourceType() == getType() &&
2770 succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2771 return this->getSource();
2772 if (Value slice = foldExtractAfterInsertSlice(*this))
2773 return slice;
2774
2775 return OpFoldResult();
2776}
2777
2778Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
2779 OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2780 auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
2781 unsigned rank = rankedTensorType.getRank();
2782 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
2784 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
2785 return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2786 offsets, sizes, strides);
2787}
2788
2789//===----------------------------------------------------------------------===//
2790// InsertSliceOp
2791//===----------------------------------------------------------------------===//
2792
2793void InsertSliceOp::getAsmResultNames(
2794 function_ref<void(Value, StringRef)> setNameFn) {
2795 setNameFn(getResult(), "inserted_slice");
2796}
2797
// Build an InsertSliceOp with mixed static and dynamic entries.
2799void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2800 Value dest, ArrayRef<OpFoldResult> offsets,
2801 ArrayRef<OpFoldResult> sizes,
2802 ArrayRef<OpFoldResult> strides,
2803 ArrayRef<NamedAttribute> attrs) {
2804 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2805 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2806 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2807 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2808 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2809 result.addAttributes(attrs);
2810 build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
2811 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2812 b.getDenseI64ArrayAttr(staticSizes),
2813 b.getDenseI64ArrayAttr(staticStrides));
2814}
2815
2816/// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2817/// Range vector.
2818void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2819 Value dest, ArrayRef<Range> ranges,
2820 ArrayRef<NamedAttribute> attrs) {
2821 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2822 build(b, result, source, dest, offsets, sizes, strides, attrs);
2823}
2824
// Build an InsertSliceOp with dynamic entries.
2826void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2827 Value dest, ValueRange offsets, ValueRange sizes,
2828 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2829 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2830 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2831 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2832 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2833 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2834 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2835 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2836}
2837
2838/// Rank-reducing type verification for both InsertSliceOp and
2839/// ParallelInsertSliceOp.
2840static SliceVerificationResult verifyInsertSliceOp(
2841 RankedTensorType srcType, RankedTensorType dstType,
2842 ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
2843 ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
2844 // insert_slice is the inverse of extract_slice, use the same type
2845 // inference.
2846 RankedTensorType expected = ExtractSliceOp::inferResultType(
2847 dstType, staticOffsets, staticSizes, staticStrides);
2848 if (expectedType)
2849 *expectedType = expected;
2850 return isRankReducedType(expected, srcType);
2851}
2852
2853/// Verifier for InsertSliceOp.
2854LogicalResult InsertSliceOp::verify() {
2855 // Verify result type against inferred type.
2856 RankedTensorType expectedType;
2857 SliceVerificationResult result =
2858 verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
2859 getStaticSizes(), getStaticStrides(), &expectedType);
2860 if (result != SliceVerificationResult::Success)
2861 return produceSliceErrorMsg(result, *this, expectedType);
2862
2863 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2864 // to the destination tensor.
2865 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2866 getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
2867 getStaticStrides(), /*generateErrorMessage=*/true);
2868 if (!boundsResult.isValid)
2869 return getOperation()->emitError(boundsResult.errorMessage);
2870
2871 return success();
2872}
2873
2874/// If we have two consecutive InsertSliceOp writing to the same slice, we
2875/// can mutate the second InsertSliceOp's destination to the first one's.
2876///
2877/// Example:
2878///
2879/// ```mlir
2880/// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2881/// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2882/// ```
2883///
2884/// folds into:
2885///
2886/// ```mlir
2887/// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2888/// ```
2889///
2890/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2891static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2892 auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2893
2894 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2895 if (!prevInsertOp ||
2896 prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2897 !prevInsertOp.isSameAs(insertOp, isSame))
2898 return failure();
2899
2900 insertOp.getDestMutable().assign(prevInsertOp.getDest());
2901 return success();
2902}
2903
2904/// Folds round-trip extract/insert slice op pairs.
2905/// Example:
2906/// ```mlir
2907/// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2908/// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2909/// ```
2910/// can be folded into %val.
2911static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2912 auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2913
2914 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2915 if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2916 !extractOp.isSameAs(insertOp, isSame))
2917 return nullptr;
2918
2919 return extractOp.getSource();
2920}
2921
2922OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2923 if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2924 getSourceType() == getType() &&
2925 succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
2926 return this->getSource();
2927 if (succeeded(foldInsertAfterInsertSlice(*this)))
2928 return getResult();
2929 if (auto result = foldInsertAfterExtractSlice(*this))
2930 return result;
2931 if (llvm::any_of(getMixedSizes(), isZeroInteger))
2932 return getDest();
2933 return OpFoldResult();
2934}
2935
2936LogicalResult InsertSliceOp::reifyResultShapes(
2937 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2938 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
2939 reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
2940 return success();
2941}
2942
2943namespace {
2944/// Pattern to rewrite an insert_slice op with constant arguments.
2945///
2946/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
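/// Illustrative example (values and types are hypothetical): assuming %c2 is
/// produced by an arith.constant op,
///
/// ```mlir
/// %r = tensor.insert_slice %src into %dst[0, %c2] [2, 2] [1, 1]
///     : tensor<2x2xf32> into tensor<8x8xf32>
/// ```
///
/// is rewritten so that the constant offset becomes a static attribute:
///
/// ```mlir
/// %r = tensor.insert_slice %src into %dst[0, 2] [2, 2] [1, 1]
///     : tensor<2x2xf32> into tensor<8x8xf32>
/// ```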
2947template <typename InsertOpTy>
2948class InsertSliceOpConstantArgumentFolder final
2949 : public OpRewritePattern<InsertOpTy> {
2950public:
2951 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2952
2953 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2954 PatternRewriter &rewriter) const override {
2955 SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
2956 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
2957 SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
2958
2959 // No constant operands were folded, just return.
2960 if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
2961 failed(foldDynamicOffsetSizeList(mixedSizes)) &&
2962 failed(foldDynamicStrideList(mixedStrides)))
2963 return failure();
2964
2965 // Pattern does not apply if the produced op would not verify.
2966 SliceBoundsVerificationResult sliceResult =
2967 verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
2968 mixedOffsets, mixedSizes, mixedStrides);
2969 if (!sliceResult.isValid)
2970 return failure();
2971
2972 // Create the new op in canonical form.
2973 auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
2974 insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2975 mixedOffsets, mixedSizes, mixedStrides);
2976 Value toInsert = insertSliceOp.getSource();
2977 if (sourceType != insertSliceOp.getSourceType()) {
2978 OpBuilder::InsertionGuard g(rewriter);
2979 // The only difference between InsertSliceOp and ParallelInsertSliceOp
2980 // is that the insertion point is just before the ParallelCombiningOp in
2981 // the parallel case.
2982 if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2983 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
2984 toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2985 sourceType, toInsert);
2986 }
2987 rewriter.replaceOpWithNewOp<InsertOpTy>(
2988 insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2989 mixedSizes, mixedStrides);
2990 return success();
2991 }
2992};
2993
2994/// Fold tensor_casts with insert_slice operations. If the source or
2995/// destination tensor is a tensor_cast that removes static type information,
2996/// the cast is folded into the insert_slice operation. E.g.:
2997///
2998/// ```mlir
2999/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
3000/// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
3001/// ```
3002///
3003/// folds into:
3004///
3005/// ```mlir
3006/// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
3007/// ```
3008///
3009/// Note: When folding a cast on the destination tensor, the result of the
3010/// insert_slice operation is cast to ensure that the type of the result does
3011/// not change.
3012///
3013/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
3014template <typename InsertOpTy>
3015struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
3016 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3017
3018 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3019 PatternRewriter &rewriter) const override {
3020 if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
3021 return matchPattern(operand, matchConstantIndex());
3022 }))
3023 return failure();
3024
3025 auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
3026 auto castOp = v.getDefiningOp<tensor::CastOp>();
3027 if (!castOp || !canFoldIntoConsumerOp(castOp))
3028 return std::nullopt;
3029 return castOp.getSource();
3030 };
3031 std::optional<Value> sourceCastSource =
3032 getSourceOfCastOp(insertSliceOp.getSource());
3033 std::optional<Value> destCastSource =
3034 getSourceOfCastOp(insertSliceOp.getDest());
3035 if (!sourceCastSource && !destCastSource)
3036 return failure();
3037
3038 auto src =
3039 (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
3040 auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
3041 auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
3042 auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
3043 if (!srcType || !dstType)
3044 return failure();
3045
3046 // The tensor.cast source could have additional static information not seen
3047 // in the insert slice op static sizes, so we ignore dynamic dims when
3048 // computing the rank reduction mask.
3049 SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
3050 auto rankReductionMask = computeRankReductionMask(
3051 staticSizes, srcType.getShape(), /*matchDynamic=*/true);
3052 if (!rankReductionMask.has_value())
3053 return failure();
3054 // Replace dimensions in the insert slice op with corresponding static dims
3055 // from the cast source type. If the insert slice sizes have static dims
3056 // that are not static in the tensor.cast source (i.e., when the cast op
3057 // casts a dynamic dim to static), the dim should not be replaced, and the
3058 // pattern will fail later in `verifyInsertSliceOp`.
3059 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
3060 int64_t rankReducedIdx = 0;
3061 for (auto [idx, size] : enumerate(staticSizes)) {
3062 if (!rankReductionMask.value().contains(idx) &&
3063 !srcType.isDynamicDim(rankReducedIdx)) {
3064 mixedSizes[idx] = getAsIndexOpFoldResult(
3065 rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
3066 size = srcType.getDimSize(rankReducedIdx++);
3067 }
3068 }
3069
3070 // Pattern does not apply if the produced op would not verify.
3071 if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
3072 staticSizes, insertSliceOp.getStaticStrides()) !=
3073 SliceVerificationResult::Success)
3074 return failure();
3075 SliceBoundsVerificationResult sliceResult =
3076 verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
3077 mixedSizes, insertSliceOp.getMixedStrides());
3078 if (!sliceResult.isValid)
3079 return failure();
3080
3081 Operation *replacement = rewriter.create<InsertOpTy>(
3082 insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
3083 mixedSizes, insertSliceOp.getMixedStrides());
3084
3085 // In the parallel case there is no result and so nothing to cast.
3086 bool isParallelInsert =
3087 std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
3088 if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
3089 replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
3090 insertSliceOp.getDestType(),
3091 replacement->getResult(0));
3092 }
3093 rewriter.replaceOp(insertSliceOp, replacement->getResults());
3094 return success();
3095 }
3096};
3097
3098/// If additional static type information can be deduced from an insert_slice's
3099/// size operands, insert an explicit cast of the op's source operand. This
3100/// enables other canonicalization patterns that are matching for tensor_cast
3101/// ops such as `ForOpTensorCastFolder` in SCF.
3102///
3103/// Example:
3104///
3105/// ```mlir
3106/// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
3107/// : tensor<?x?xf32> into ...
3108/// ```
3109///
3110/// folds into:
3111///
3112/// ```mlir
3113/// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
3114/// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
3115/// : tensor<64x64xf32> into ...
3116/// ```
3117///
3118/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
3119template <typename InsertOpTy>
3120struct InsertSliceOpSourceCastInserter final
3121 : public OpRewritePattern<InsertOpTy> {
3122 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3123
3124 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3125 PatternRewriter &rewriter) const override {
3126 RankedTensorType srcType = insertSliceOp.getSourceType();
3127 if (srcType.getRank() != insertSliceOp.getDestType().getRank())
3128 return failure();
3129 SmallVector<int64_t> newSrcShape(srcType.getShape());
3130 for (int64_t i = 0; i < srcType.getRank(); ++i) {
3131 if (std::optional<int64_t> constInt =
3132 getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
3133 // Bail on invalid IR.
3134 if (*constInt < 0)
3135 return failure();
3136 newSrcShape[i] = *constInt;
3137 }
3138 }
3139 if (!hasValidSizesOffsets(newSrcShape))
3140 return failure();
3141
3142 RankedTensorType newSrcType = RankedTensorType::get(
3143 newSrcShape, srcType.getElementType(), srcType.getEncoding());
3144 if (srcType == newSrcType ||
3145 !preservesStaticInformation(srcType, newSrcType) ||
3146 !tensor::CastOp::areCastCompatible(srcType, newSrcType))
3147 return failure();
3148
3149 // newSrcType is:
3150 // 1) Different from srcType.
3151 // 2) "More static" than srcType.
3152 // 3) Cast-compatible with srcType.
3153 // Insert the cast.
3154 OpBuilder::InsertionGuard g(rewriter);
3155 // The only difference between InsertSliceOp and ParallelInsertSliceOp is
3156 // that the insertion point is just before the ParallelCombiningOp in the
3157 // parallel case.
3158 if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
3159 rewriter.setInsertionPoint(insertSliceOp->getParentOp());
3160 Value cast = rewriter.create<tensor::CastOp>(
3161 insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
3162 rewriter.replaceOpWithNewOp<InsertOpTy>(
3163 insertSliceOp, cast, insertSliceOp.getDest(),
3164 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
3165 insertSliceOp.getMixedStrides());
3166 return success();
3167 }
3168};
3169} // namespace
3170
3171llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
3172 return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3173}
3174
3175void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3176 MLIRContext *context) {
3177 results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
3178 InsertSliceOpCastFolder<InsertSliceOp>,
3179 InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
3180}
3181
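/// Create an InsertSliceOp that writes `tensor` into all of `dest` using zero
/// offsets, unit strides, and sizes equal to the sizes of `dest`; the source
/// may be rank-reduced. Illustrative example (shapes are hypothetical):
///
/// ```mlir
/// %r = tensor.insert_slice %src into %dest[0, 0, 0] [1, 4, 1] [1, 1, 1]
///     : tensor<4xf32> into tensor<1x4x1xf32>
/// ```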
3182Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
3183 Location loc,
3184 Value tensor,
3185 Value dest) {
3186 auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
3187 unsigned rank = rankedTensorType.getRank();
3188 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3189 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
3190 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3191 return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
3192 sizes, strides);
3193}
3194
3195//===----------------------------------------------------------------------===//
3196// PadOp
3197//===----------------------------------------------------------------------===//
3198
3199void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3200 setNameFn(getResult(), "padded");
3201}
3202
3203// TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
3204// supports optional types.
3205void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
3206 Type typeToInfer, Type typeToInferFrom) {}
3207
3208ParseResult
3209parseInferType(OpAsmParser &parser,
3210 std::optional<OpAsmParser::UnresolvedOperand> optOperand,
3211 Type &typeToInfer, Type typeToInferFrom) {
3212 if (optOperand)
3213 typeToInfer = typeToInferFrom;
3214 return success();
3215}
3216
3217LogicalResult PadOp::verify() {
3218 auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
3219 auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
3220 auto expectedType =
3221 PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
3222 if (!expectedType) {
3223 return emitError("failed to infer expectedType from sourceType ")
3224 << sourceType << ", specified resultType is " << resultType;
3225 }
3226 if (resultType.getRank() != expectedType.getRank()) {
3227 return emitError("specified type ")
3228 << resultType << " does not match the inferred type "
3229 << expectedType;
3230 }
3231 for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
3232 if (resultType.getDimSize(i) == expectedType.getDimSize(i))
3233 continue;
3234 if (expectedType.isDynamicDim(i))
3235 continue;
3236 return emitError("specified type ")
3237 << resultType << " does not match the inferred type "
3238 << expectedType;
3239 }
3240
3241 return success();
3242}
3243
3244LogicalResult PadOp::verifyRegions() {
3245 auto &region = getRegion();
3246 unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
3247 Block &block = region.front();
3248 if (block.getNumArguments() != rank)
3249 return emitError("expected the block to have ") << rank << " arguments";
3250
3251 // Note: the number and type of yield values are checked in the YieldOp.
3252 for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
3253 if (!en.value().isIndex())
3254 return emitOpError("expected block argument ")
3255 << (en.index() + 1) << " to be an index";
3256 }
3257
3258 // Ensure that the region yields an element of the right type.
3259 auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
3260 if (yieldOp.getValue().getType() !=
3261 llvm::cast<ShapedType>(getType()).getElementType())
3262 return emitOpError("expected yield type to match shape element type");
3263
3264 return success();
3265}
3266
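/// The inferred shape is computed dimension-wise as
/// `sourceDim + staticLow + staticHigh`; any dynamic ingredient yields a
/// dynamic dimension (or the corresponding `resultShape` entry, if provided).
/// Illustrative example (shapes are hypothetical): a tensor<4x?xf32> source
/// with staticLow = [1, 2] and staticHigh = [3, 2] infers tensor<8x?xf32>.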
3267RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
3268 ArrayRef<int64_t> staticLow,
3269 ArrayRef<int64_t> staticHigh,
3270 ArrayRef<int64_t> resultShape) {
3271 unsigned rank = sourceType.getRank();
3272 if (staticLow.size() != rank)
3273 return RankedTensorType();
3274 if (staticHigh.size() != rank)
3275 return RankedTensorType();
3276 if (!resultShape.empty() && resultShape.size() != rank)
3277 return RankedTensorType();
3278
3279 SmallVector<int64_t, 4> inferredShape;
3280 for (auto i : llvm::seq<unsigned>(0, rank)) {
3281 if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
3282 staticHigh[i] == ShapedType::kDynamic) {
3283 inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
3284 : resultShape[i]);
3285 } else {
3286 int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
3287 assert((resultShape.empty() || size == resultShape[i] ||
3288 resultShape[i] == ShapedType::kDynamic) &&
3289 "mismatch between inferred shape and result shape");
3290 inferredShape.push_back(size);
3291 }
3292 }
3293
3294 return RankedTensorType::get(inferredShape, sourceType.getElementType());
3295}
3296
3297void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3298 Value source, ArrayRef<int64_t> staticLow,
3299 ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
3300 bool nofold, ArrayRef<NamedAttribute> attrs) {
3301 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3302 if (!resultType)
3303 resultType = inferResultType(sourceType, staticLow, staticHigh);
3304 result.addAttributes(attrs);
3305 build(b, result, resultType, source, low, high,
3306 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3307 nofold ? b.getUnitAttr() : UnitAttr());
3308}
3309
3310void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3311 Value source, ValueRange low, ValueRange high, bool nofold,
3312 ArrayRef<NamedAttribute> attrs) {
3313 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3314 unsigned rank = sourceType.getRank();
3315 SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
3316 build(b, result, resultType, source, staticVector, staticVector, low, high,
3317 nofold, attrs);
3318}
3319
3320void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3321 Value source, ArrayRef<OpFoldResult> low,
3322 ArrayRef<OpFoldResult> high, bool nofold,
3323 ArrayRef<NamedAttribute> attrs) {
3324 auto sourceType = llvm::cast<RankedTensorType>(source.getType());
3325 SmallVector<Value, 4> dynamicLow, dynamicHigh;
3326 SmallVector<int64_t, 4> staticLow, staticHigh;
3327 // staticLow and staticHigh have full information of the padding config.
3328 // Each entry grows staticLow and staticHigh by one value. If the entry is
3329 // dynamic (i.e. not a constant), dynamicLow and dynamicHigh grow by one
3330 // value as well.
3331 dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
3332 dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
3333 if (!resultType) {
3334 resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3335 }
3336 assert(llvm::isa<RankedTensorType>(resultType));
3337 result.addAttributes(attrs);
3338 build(b, result, resultType, source, dynamicLow, dynamicHigh,
3339 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
3340 nofold ? b.getUnitAttr() : UnitAttr());
3341}
3342
3343void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3344 Value source, ArrayRef<OpFoldResult> low,
3345 ArrayRef<OpFoldResult> high, Value constantPadValue,
3346 bool nofold, ArrayRef<NamedAttribute> attrs) {
3347 build(b, result, resultType, source, low, high, nofold, attrs);
3348
3349 // Add a region and a block to yield the pad value.
3350 Region *region = result.regions[0].get();
3351 int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
3352 SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
3353 SmallVector<Location> blockArgLocs(sourceRank, result.location);
3354
3355 // `builder.createBlock` changes the insertion point within the block. Create
3356 // a guard to reset the insertion point of the builder after it is destroyed.
3357 OpBuilder::InsertionGuard guard(b);
3358 b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
3359 b.create<tensor::YieldOp>(result.location, constantPadValue);
3360}
3361
3362llvm::SmallBitVector PadOp::getPaddedDims() {
3363 llvm::SmallBitVector paddedDims(getSourceType().getRank());
3364 auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3365 for (const auto &en : enumerate(paddingWidths))
3366 if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
3367 paddedDims.set(en.index());
3368 };
3369 extractPaddedDims(getMixedLowPad());
3370 extractPaddedDims(getMixedHighPad());
3371 return paddedDims;
3372}
3373
3374namespace {
3375// Folds tensor.pad when the padding consists of static zeros and the nofold
3376// attribute does not prevent folding.
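//
// Illustrative example (types are hypothetical): with all-zero padding,
//
// ```mlir
// %0 = tensor.pad %t low[0, 0] high[0, 0] { ... }
//     : tensor<?x8xf32> to tensor<4x8xf32>
// ```
//
// is replaced by `tensor.cast %t : tensor<?x8xf32> to tensor<4x8xf32>`.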
3377struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
3378 using OpRewritePattern<PadOp>::OpRewritePattern;
3379
3380 LogicalResult matchAndRewrite(PadOp padTensorOp,
3381 PatternRewriter &rewriter) const override {
3382 if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3383 return failure();
3384 if (padTensorOp.getNofold())
3385 return failure();
3386 rewriter.replaceOpWithNewOp<tensor::CastOp>(
3387 padTensorOp, padTensorOp.getResult().getType(),
3388 padTensorOp.getSource());
3389 return success();
3390 }
3391};
3392
3393// Fold CastOp into PadOp when adding static information.
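//
// Illustrative example (types are hypothetical):
//
// ```mlir
// %c = tensor.cast %t : tensor<8x16xf32> to tensor<?x?xf32>
// %p = tensor.pad %c low[0, 0] high[2, 2] { ... }
//     : tensor<?x?xf32> to tensor<?x?xf32>
// ```
//
// becomes a pad of the original source with a static result type, followed by
// a cast back to the old result type:
//
// ```mlir
// %0 = tensor.pad %t low[0, 0] high[2, 2] { ... }
//     : tensor<8x16xf32> to tensor<10x18xf32>
// %p = tensor.cast %0 : tensor<10x18xf32> to tensor<?x?xf32>
// ```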
3394struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
3395 using OpRewritePattern<PadOp>::OpRewritePattern;
3396
3397 LogicalResult matchAndRewrite(PadOp padTensorOp,
3398 PatternRewriter &rewriter) const override {
3399 auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3400 if (!tensor::canFoldIntoConsumerOp(castOp))
3401 return failure();
3402
3403 auto newResultType = PadOp::inferResultType(
3404 llvm::cast<RankedTensorType>(castOp.getSource().getType()),
3405 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3406 padTensorOp.getResultType().getShape());
3407
3408 if (newResultType == padTensorOp.getResultType()) {
3409 rewriter.modifyOpInPlace(padTensorOp, [&]() {
3410 padTensorOp.getSourceMutable().assign(castOp.getSource());
3411 });
3412 } else {
3413 auto newOp = rewriter.create<PadOp>(
3414 padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
3415 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3416 padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
3417 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3418 IRMapping mapper;
3419 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3420
3421 rewriter.replaceOpWithNewOp<tensor::CastOp>(
3422 padTensorOp, padTensorOp.getResultType(), newOp);
3423 }
3424 return success();
3425 }
3426};
3427
3428// Fold CastOp using the result of PadOp back into the latter if it adds
3429// static information.
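//
// Illustrative example (types are hypothetical): when the pad result has a
// single use that is a cast to a more static type,
//
// ```mlir
// %p = tensor.pad %t low[0, 0] high[%h0, %h1] { ... }
//     : tensor<8x16xf32> to tensor<?x?xf32>
// %c = tensor.cast %p : tensor<?x?xf32> to tensor<10x18xf32>
// ```
//
// is rewritten to pad directly to the more static type:
//
// ```mlir
// %c = tensor.pad %t low[0, 0] high[%h0, %h1] { ... }
//     : tensor<8x16xf32> to tensor<10x18xf32>
// ```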
3430struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
3431 using OpRewritePattern<PadOp>::OpRewritePattern;
3432
3433 LogicalResult matchAndRewrite(PadOp padTensorOp,
3434 PatternRewriter &rewriter) const override {
3435 if (!padTensorOp.getResult().hasOneUse())
3436 return failure();
3437 auto tensorCastOp =
3438 dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
3439 if (!tensorCastOp)
3440 return failure();
3441 if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
3442 tensorCastOp.getDest().getType()))
3443 return failure();
3444
3445 auto replacementOp = rewriter.create<PadOp>(
3446 padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
3447 padTensorOp.getSource(), padTensorOp.getStaticLow(),
3448 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3449 padTensorOp.getHigh(), padTensorOp.getNofold(),
3450 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3451 replacementOp.getRegion().takeBody(padTensorOp.getRegion());
3452
3453 rewriter.replaceOp(padTensorOp, replacementOp.getResult());
3454 rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
3455 return success();
3456 }
3457};
3458
3459/// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
3460/// different dimensions. The pattern applies if the following preconditions
3461/// hold:
3462/// 1) the tensor::ExtractSliceOps are not rank-reducing,
3463/// 2) the tensor::ExtractSliceOps have only unit-strides,
3464/// 3) the tensor::PadOps perform only high-padding,
3465/// 4) the tensor::PadOps have the same constant padding value,
3466/// 5) the tensor::PadOps do not have common padding dimensions,
3467/// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
3468/// zero-offset for every dimension.
3469/// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for the
3470///    padded source dimensions.
3472///
3473/// Example:
3474///
3475/// ```mlir
3476/// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
3477/// : tensor<64x64xf32> to tensor<?x64xf32>
3478/// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
3479/// } : tensor<?x64xf32> to tensor<8x64xf32>
3480/// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
3481/// : tensor<8x64xf32> to tensor<8x?xf32>
3482/// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
3483/// } : tensor<8x?xf32> to tensor<8x4xf32>
3484/// ```
3485///
3486/// folds into:
3487///
3488/// ```mlir
3489/// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
3490/// : tensor<64x64xf32> to tensor<?x?xf32>
3491/// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
3492/// } : tensor<?x?xf32> to tensor<8x4xf32>
3493/// ```
3494struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
3495 using OpRewritePattern<PadOp>::OpRewritePattern;
3496
3497 LogicalResult matchAndRewrite(PadOp padOp,
3498 PatternRewriter &rewriter) const override {
3499 auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3500 if (!innerSliceOp)
3501 return failure();
3502 auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3503 if (!outerPadOp || outerPadOp.getNofold())
3504 return failure();
3505 auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3506 if (!outerSliceOp)
3507 return failure();
3508
3509 // 1) Fail if the chain is rank-reducing.
3510 int64_t rank = padOp.getSourceType().getRank();
3511 if (outerSliceOp.getSourceType().getRank() != rank) {
3512 return rewriter.notifyMatchFailure(padOp,
3513 "cannot fold rank-reducing chain");
3514 }
3515
3516 // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
3517 if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3518 return rewriter.notifyMatchFailure(
3519 padOp, "cannot fold non-unit stride ExtractSliceOps");
3520 }
3521
3522 // 3) Fail if the tensor::PadOps have non-zero low padding.
3523 if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3524 return rewriter.notifyMatchFailure(padOp,
3525 "cannot fold PadOps with low padding");
3526 }
3527
3528 // 4) Fail if the tensor::PadOps padding values do not match.
3529 Attribute innerAttr, outerAttr;
3530 Value innerValue = padOp.getConstantPaddingValue();
3531 Value outerValue = outerPadOp.getConstantPaddingValue();
3532 if (!innerValue || !outerValue ||
3533 !matchPattern(innerValue, m_Constant(&innerAttr)) ||
3534 !matchPattern(outerValue, m_Constant(&outerAttr)) ||
3535 innerAttr != outerAttr) {
3536 return rewriter.notifyMatchFailure(
3537 padOp, "cannot fold PadOps with different padding values");
3538 }
3539
3540 // 5) Fail if a dimension is padded by both tensor::PadOps.
3541 llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3542 llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3543 if (innerDims.anyCommon(outerDims)) {
3544 return rewriter.notifyMatchFailure(
3545 padOp, "cannot fold PadOps with common padding dimensions");
3546 }
3547
3548 // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
3549 // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3550 // for every dimension, and use the offset of the other pair. Fail if no
3551 // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3552 // exists.
3553 SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
3554 for (auto en : enumerate(newOffsets)) {
3555 OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3556 OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3557 if (!innerDims.test(en.index()) &&
3558 (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
3559 en.value() = outerOffset;
3560 continue;
3561 }
3562 if (!outerDims.test(en.index()) &&
3563 (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
3564 en.value() = innerOffset;
3565 continue;
3566 }
3567 return rewriter.notifyMatchFailure(
3568 padOp, "cannot find zero-offset and zero-padding pair");
3569 }
3570
3571 // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
3572 // of the outer tensor::ExtractSliceOp for the dimensions padded by the
3573 // outer tensor::PadOp and fail if the size of the inner
3574 // tensor::ExtractSliceOp does not match the size of the padded dimension.
3575 // Otherwise, take the size of the inner tensor::ExtractSliceOp.
3576 SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
3577 for (auto en : enumerate(newSizes)) {
3578 if (!outerDims.test(en.index()))
3579 continue;
3580 OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3581 int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3582 assert(!ShapedType::isDynamic(sourceSize) &&
3583 "expected padded dimension to have a static size");
3584 if (getConstantIntValue(sliceSize) != sourceSize) {
3585 return rewriter.notifyMatchFailure(
3586 padOp, "cannot fold since the inner ExtractSliceOp size does not "
3587 "match the size of the outer padding");
3588 }
3589 en.value() = outerSliceOp.getMixedSizes()[en.index()];
3590 }
3591
3592 // Combine the high paddings of the two tensor::PadOps.
3593 SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
3594 for (auto en : enumerate(newHighPad)) {
3595 if (innerDims.test(en.index()))
3596 newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3597 if (outerDims.test(en.index()))
3598 newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3599 }
3600
3601 // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
3602 // the two paddings in one step.
3603 auto newSliceOp = rewriter.create<ExtractSliceOp>(
3604 padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
3605 innerSliceOp.getMixedStrides());
3606 auto newPadOp = rewriter.create<PadOp>(
3607 padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3608 padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3609 getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
3610 rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3611 newPadOp.getRegion().begin());
3612 rewriter.replaceOp(padOp, newPadOp.getResult());
3613 return success();
3614 }
3615};
3616
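// Folds constant low/high padding operands of a tensor.pad into its static
// padding attributes and computes a more static result type, casting back to
// the original result type. Illustrative example (values and types are
// hypothetical): assuming %c3 is an arith.constant of value 3,
//
// ```mlir
// %0 = tensor.pad %t low[0, 2] high[%c3, 0] { ... }
//     : tensor<4x4xf32> to tensor<?x6xf32>
// ```
//
// is rewritten to:
//
// ```mlir
// %1 = tensor.pad %t low[0, 2] high[3, 0] { ... }
//     : tensor<4x4xf32> to tensor<7x6xf32>
// %0 = tensor.cast %1 : tensor<7x6xf32> to tensor<?x6xf32>
// ```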
3617struct FoldStaticPadding : public OpRewritePattern<PadOp> {
3618 using OpRewritePattern<PadOp>::OpRewritePattern;
3619
3620 LogicalResult matchAndRewrite(PadOp padTensorOp,
3621 PatternRewriter &rewriter) const override {
3622 Value input = padTensorOp.getSource();
3623 if (!llvm::isa<RankedTensorType>(input.getType()))
3624 return failure();
3625 auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
3626 auto inputRank = inputDims.size();
3627
3628 auto oldResultType =
3629 dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3630 if (!oldResultType)
3631 return failure();
3632
3633 auto outputDims = oldResultType.getShape();
3634
3635 // Extract the static info from the high and low operands.
3636 SmallVector<int64_t> constOperandsLow;
3637 SmallVector<Value> newLows;
3638 for (auto operand : padTensorOp.getLow()) {
3639 APSInt intOp;
3640 if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3641 constOperandsLow.push_back(ShapedType::kDynamic);
3642 newLows.push_back(operand);
3643 continue;
3644 }
3645 constOperandsLow.push_back(intOp.getExtValue());
3646 }
3647 SmallVector<int64_t> constOperandsHigh;
3648 SmallVector<Value> newHighs;
3649 for (auto operand : padTensorOp.getHigh()) {
3650 APSInt intOp;
3651 if (!matchPattern(operand, m_ConstantInt(&intOp))) {
3652 constOperandsHigh.push_back(ShapedType::kDynamic);
3653 newHighs.push_back(operand);
3654 continue;
3655 }
3656 constOperandsHigh.push_back(intOp.getExtValue());
3657 }
3658
3659 SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
3660 SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
3661
3662 // Verify the op is well-formed.
3663 if (inputDims.size() != outputDims.size() ||
3664 inputDims.size() != constLow.size() ||
3665 inputDims.size() != constHigh.size())
3666 return failure();
3667
3668 auto lowCount = 0;
3669 auto highCount = 0;
3670 for (size_t i = 0; i < inputRank; i++) {
3671 if (constLow[i] == ShapedType::kDynamic)
3672 constLow[i] = constOperandsLow[lowCount++];
3673 if (constHigh[i] == ShapedType::kDynamic)
3674 constHigh[i] = constOperandsHigh[highCount++];
3675 }
3676
3677 auto staticLow = ArrayRef<int64_t>(constLow);
3678 auto staticHigh = ArrayRef<int64_t>(constHigh);
3679
3680 // Calculate the output sizes with the static information.
3681 SmallVector<int64_t> newOutDims;
3682 for (size_t i = 0; i < inputRank; i++) {
3683 if (outputDims[i] == ShapedType::kDynamic) {
3684 newOutDims.push_back(
3685 (staticLow[i] == ShapedType::kDynamic ||
3686 staticHigh[i] == ShapedType::kDynamic ||
3687 inputDims[i] == ShapedType::kDynamic
3688 ? ShapedType::kDynamic
3689 : inputDims[i] + staticLow[i] + staticHigh[i]));
3690 } else {
3691 newOutDims.push_back(outputDims[i]);
3692 }
3693 }
3694
3695 if (SmallVector<int64_t>(outputDims) == newOutDims ||
3696 llvm::all_of(newOutDims,
3697 [&](int64_t x) { return x == ShapedType::kDynamic; }))
3698 return failure();
3699
3700 // Rewrite the op using the new static type.
3701 auto newResultType = RankedTensorType::get(
3702 newOutDims, padTensorOp.getType().getElementType());
3703 auto newOp = rewriter.create<PadOp>(
3704 padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
3705 newLows, newHighs, padTensorOp.getNofold(),
3706 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3707
3708 IRMapping mapper;
3709 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3710 rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
3711 newOp);
3712
3713 return success();
3714 }
3715};
3716
3717/// Folds a chain of `tensor.pad` ops with the same constant padding value.
3718///
3719/// Example:
3720///
3721/// ```mlir
3722/// %1 = tensor.pad %0 low[0, 1] high[0, 2] {
3723/// tensor.yield %val
3724/// } : tensor<1x2xf32> to tensor<1x5xf32>
3725/// %res = tensor.pad %1 low[0, 2] high[3, 0] {
3726/// tensor.yield %val
3727/// } : tensor<1x5xf32> to tensor<4x7xf32>
3728/// ```
3729///
3730/// folds into:
3731///
3732/// ```mlir
3733/// %res = tensor.pad %0 low[0, 3] high[3, 2] {
3734/// tensor.yield %val
3735/// } : tensor<1x2xf32> to tensor<4x7xf32>
3736/// ```
3737struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
3738 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
3739
3740 LogicalResult matchAndRewrite(tensor::PadOp padOp,
3741 PatternRewriter &rewriter) const override {
3742 if (padOp.getNofold()) {
3743 return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
3744 }
3745
3746 auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
3747 if (!producerPad || producerPad.getNofold()) {
3748 return rewriter.notifyMatchFailure(
3749 padOp, "producer is not a foldable tensor.pad op");
3750 }
3751
3752 // Fail if the tensor::PadOps padding values do not match.
3753 Value consumerPadValue = padOp.getConstantPaddingValue();
3754 Value producerPadValue = producerPad.getConstantPaddingValue();
3755 if (!consumerPadValue || !producerPadValue ||
3756 consumerPadValue != producerPadValue) {
3757 return rewriter.notifyMatchFailure(
3758 padOp,
3759 "cannot fold PadOps with different or non-constant padding values");
3760 }
3761
3762 Location loc = padOp.getLoc();
3763 AffineExpr d0, d1;
3764 bindDims(rewriter.getContext(), d0, d1);
3765
3766 // Combine the low/high paddings of the two tensor::PadOps.
3767 auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
3768 ArrayRef<OpFoldResult> producerPaddings) {
3769 SmallVector<OpFoldResult> sumPaddings;
3770 for (auto [consumerIndex, producerIndex] :
3771 llvm::zip_equal(consumerPaddings, producerPaddings)) {
3772 sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
3773 rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
3774 }
3775 return sumPaddings;
3776 };
3777
3778 SmallVector<OpFoldResult> newHighPad =
3779 addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
3780 SmallVector<OpFoldResult> newLowPad =
3781 addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
3782
3783 auto newPadOp = rewriter.create<tensor::PadOp>(
3784 padOp.getLoc(), padOp.getResultType(), producerPad.getSource(),
3785 newLowPad, newHighPad, padOp.getNofold(),
3786 getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
3787 rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3788 newPadOp.getRegion().begin());
3789 rewriter.replaceOp(padOp, newPadOp.getResult());
3790 return success();
3791 }
3792};
3793
3794} // namespace
3795
3796void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3797 MLIRContext *context) {
3798 results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3799 FoldOrthogonalPaddings, FoldStaticPadding,
3800 FoldConsecutiveConstantPadding>(context);
3801}
3802
3803/// Return the padding value of the PadOp if it is constant. In this context,
3804/// "constant" means an actual constant or "defined outside of the block".
3805///
3806/// Values are considered constant in three cases:
3807/// - A ConstantLike value.
3808/// - A basic block argument from a different block.
3809/// - A value defined outside of the block.
3810///
3811/// If the padding value is not constant, an empty Value is returned.
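/// Illustrative example (values are hypothetical): for
///
/// ```mlir
/// %cst = arith.constant 0.0 : f32
/// %p = tensor.pad %t low[1, 1] high[1, 1] {
/// ^bb0(%i: index, %j: index):
///   tensor.yield %cst : f32
/// } : tensor<2x2xf32> to tensor<4x4xf32>
/// ```
///
/// the returned padding value is %cst.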
3812Value PadOp::getConstantPaddingValue() {
3813 auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3814 if (!yieldOp)
3815 return {};
3816 Value padValue = yieldOp.getValue();
3817 // Check if yield value is a constant.
3818 if (matchPattern(padValue, m_Constant()))
3819 return padValue;
3820 // Check if yield value is defined inside the PadOp block.
3821 if (padValue.getParentBlock() == &getRegion().front())
3822 return {};
3823 // Else: Yield value defined outside of the PadOp block.
3824 return padValue;
3825}
3826
3827OpFoldResult PadOp::fold(FoldAdaptor) {
3828 if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3829 !getNofold())
3830 return getSource();
3831 return {};
3832}
3833
3834//===----------------------------------------------------------------------===//
3835// ParallelInsertSliceOp
3836//===----------------------------------------------------------------------===//
3837
3838OpResult ParallelInsertSliceOp::getTiedOpResult() {
3839 ParallelCombiningOpInterface parallelCombiningParent =
3840 getParallelCombiningParent();
3841 for (const auto &it :
3842 llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3843 Operation &nextOp = it.value();
3844 if (&nextOp == getOperation())
3845 return parallelCombiningParent.getParentResult(it.index());
3846 }
3847 llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3848}
3849
3850// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3851void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3852 Value source, Value dest,
3853 ArrayRef<OpFoldResult> offsets,
3854 ArrayRef<OpFoldResult> sizes,
3855 ArrayRef<OpFoldResult> strides,
3856 ArrayRef<NamedAttribute> attrs) {
3857 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3858 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3859 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3860 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3861 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3862 result.addAttributes(attrs);
3863 build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3864 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3865 b.getDenseI64ArrayAttr(staticSizes),
3866 b.getDenseI64ArrayAttr(staticStrides));
3867}
3868
3869/// Build a ParallelInsertSliceOp with mixed static and dynamic entries
3870/// packed into a Range vector.
3871void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3872 Value source, Value dest,
3873 ArrayRef<Range> ranges,
3874 ArrayRef<NamedAttribute> attrs) {
3875 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
3876 build(b, result, source, dest, offsets, sizes, strides, attrs);
3877}
3878
3879// Build a ParallelInsertSliceOp with dynamic entries.
3880void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3881 Value source, Value dest, ValueRange offsets,
3882 ValueRange sizes, ValueRange strides,
3883 ArrayRef<NamedAttribute> attrs) {
3884 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3885 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
3886 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3887 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
3888 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3889 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
3890 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3891}
3892
3893LogicalResult ParallelInsertSliceOp::verify() {
3894 if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
3895 return this->emitError("expected ParallelCombiningOpInterface parent, got:")
3896 << *(getOperation()->getParentOp());
3897
3898 // Verify result type against inferred type.
3899 RankedTensorType expectedType;
3900 SliceVerificationResult result =
3901 verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
3902 getStaticSizes(), getStaticStrides(), &expectedType);
3903 if (result != SliceVerificationResult::Success)
3904 return produceSliceErrorMsg(result, *this, expectedType);
3905
3906 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3907 // to the destination tensor.
3908 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
3909 getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
3910 getStaticStrides(), /*generateErrorMessage=*/true);
3911 if (!boundsResult.isValid)
3912 return getOperation()->emitError(boundsResult.errorMessage);
3913
3914 return success();
3915}
3916
3917void ParallelInsertSliceOp::getCanonicalizationPatterns(
3918 RewritePatternSet &results, MLIRContext *context) {
3919 results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3920 InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3921 InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3922}
3923
3924llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3925 return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
3926}
3927
3928//===----------------------------------------------------------------------===//
3929// ScatterOp
3930//===----------------------------------------------------------------------===//
3931
3932void ScatterOp::getAsmResultNames(
3933 function_ref<void(Value, StringRef)> setNameFn) {
3934 setNameFn(getResult(), "scatter");
3935}
3936
3937LogicalResult ScatterOp::verify() {
3938 int64_t destRank = getDestType().getRank();
3939 ArrayRef<int64_t> scatterDims = getScatterDims();
3940 if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
3941 getIndicesType().getShape(), destRank,
3942 "scatter", "dest")))
3943 return failure();
3944
3945 if (!getUnique())
3946 return emitOpError("requires 'unique' attribute to be set");
3947 // TODO: we could also check statically that there are fewer leading index
3948 // tensor dims than the dest dims. If this is not the case, the unique
3949 // attribute cannot be true.
3950
3951 // Use the GatherOp::inferResultType on the `dest` type and verify the
3952 // expected type matches the source type.
3953 RankedTensorType expectedSourceType = GatherOp::inferResultType(
3954 getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
3955 RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
3956 getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
3957 if (getSourceType() != expectedSourceType &&
3958 getSourceType() != expectedRankReducedSourceType) {
3959 return emitOpError("source type "
3960 "mismatch: "
3961 "expected ")
3962 << expectedSourceType << " or its rank-reduced variant "
3963 << expectedRankReducedSourceType << " (got: " << getSourceType()
3964 << ")";
3965 }
3966
3967 return success();
3968}
3969
3970//===----------------------------------------------------------------------===//
3971// SplatOp
3972//===----------------------------------------------------------------------===//
3973
3974void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3975 Type aggregateType, ValueRange dynamicSizes) {
3976 build(builder, result, aggregateType, element, dynamicSizes);
3977}
3978
3979void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3980 ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
3981 auto aggregateType = RankedTensorType::get(staticShape, element.getType());
3982 build(builder, result, aggregateType, element, dynamicSizes);
3983}
3984
3985void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
3986 ArrayRef<OpFoldResult> sizes) {
3987 SmallVector<int64_t> staticShape;
3988 SmallVector<Value> dynamicSizes;
3989 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
3990 build(builder, result, element, staticShape, dynamicSizes);
3991}
3992
3993void SplatOp::getAsmResultNames(
3994 function_ref<void(Value, StringRef)> setNameFn) {
3995 setNameFn(getResult(), "splat");
3996}
3997
3998LogicalResult SplatOp::verify() {
3999 if (getType().getNumDynamicDims() != getDynamicSizes().size())
4000 return emitOpError("incorrect number of dynamic sizes, has ")
4001 << getDynamicSizes().size() << ", expected "
4002 << getType().getNumDynamicDims();
4003 return success();
4004}
4005
4006LogicalResult
4007SplatOp::reifyResultShapes(OpBuilder &builder,
4008 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4009 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
4010 unsigned ctr = 0;
4011 for (int64_t i = 0; i < getType().getRank(); ++i) {
4012 if (getType().isDynamicDim(i)) {
4013 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
4014 } else {
4015 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
4016 }
4017 }
4018 return success();
4019}
4020
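/// Folding a splat of a constant scalar produces a dense splat attribute.
/// Illustrative example (values are hypothetical):
///
/// ```mlir
/// %cst = arith.constant 1.0 : f32
/// %t = tensor.splat %cst : tensor<2x2xf32>
/// ```
///
/// folds %t to the splat attribute `dense<1.0> : tensor<2x2xf32>`, which the
/// dialect can then materialize as a constant.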
4021OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
4022 auto constOperand = adaptor.getInput();
4023 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
4024 return {};
4025
4026 // Do not fold if the splat is not statically shaped
4027 if (!getType().hasStaticShape())
4028 return {};
4029
4030 // SplatElementsAttr::get treats single value for second arg as being a
4031 // splat.
4032 return SplatElementsAttr::get(getType(), {constOperand});
4033}
4034
4035//===----------------------------------------------------------------------===//
4036// Common Canonicalizers and Folders.
4037//===----------------------------------------------------------------------===//
4038bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
4039 // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
4040 // 2. Exclude DPS ops that are also LoopLike from this interface as they
4041 // might need special handling of attached regions.
4042 if (isa<InsertSliceOp>(op.getOperation()) ||
4043 isa<LoopLikeOpInterface>(op.getOperation()))
4044 return false;
4045
4046 return hasFoldableTensorCastOperand(op);
4047}
4048
4049/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
4050/// the `tensor.cast` has a source that is more static than the consuming op.
4051///
4052/// Example:
4053/// ```mlir
4054/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4055/// %2 = consumer %1 ... : tensor<?x?xf32> ...
4056/// ```
4057///
4058/// folds into:
4059///
4060/// ```mlir
4061/// %2 = consumer %0 ... : tensor<8x16xf32> ...
4062/// ```
4063/// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
4064/// can add the pattern to their canonicalizers.
4065struct FoldTensorCastProducerOp
4066 : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
4067 using OpInterfaceRewritePattern<
4068 DestinationStyleOpInterface>::OpInterfaceRewritePattern;
4069
4070 LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
4071 PatternRewriter &rewriter) const override {
4072
4073 // Reject PackOp/UnpackOp (i.e. RelayoutOps) - there are dedicated patterns
4074 // for that instead.
4075 if (!foldTensorCastPrecondition(op) ||
4076 isa<linalg::RelayoutOpInterface>(*op))
4077 return failure();
4078
4079 SmallVector<Type> newResultTypes(op->getResultTypes());
4080 SmallVector<Value> newOperands =
4081 getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
4082
4083 // Clone op
4084 auto newOp = clone(rewriter, op, newResultTypes, newOperands);
4085
4086 SmallVector<Value, 4> replacements;
4087 replacements.reserve(newOp->getNumResults());
4088 for (auto [oldResult, newResult] :
4089 llvm::zip(op->getResults(), newOp->getResults())) {
4090 if (newResult.getType() != oldResult.getType()) {
4091 replacements.push_back(rewriter.create<tensor::CastOp>(
4092 op->getLoc(), oldResult.getType(), newResult));
4093 } else {
4094 replacements.push_back(newResult);
4095 }
4096 }
4097 rewriter.replaceOp(op, replacements);
4098
4099 return success();
4100 }
4101};
4102
4103//===----------------------------------------------------------------------===//
4104// TensorDialect
4105//===----------------------------------------------------------------------===//
4106
4107void TensorDialect::getCanonicalizationPatterns(
4108 RewritePatternSet &results) const {
4109 results.add<FoldTensorCastProducerOp>(getContext());
4110}
4111
4112//===----------------------------------------------------------------------===//
4113// TableGen'd op method definitions
4114//===----------------------------------------------------------------------===//
4115
4116#define GET_OP_CLASSES
4117#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
4118
