1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Affine/IR/AffineOps.h"
10#include "mlir/Dialect/Arith/IR/Arith.h"
11#include "mlir/Dialect/Arith/Utils/Utils.h"
12#include "mlir/Dialect/Complex/IR/Complex.h"
13#include "mlir/Dialect/Linalg/IR/RelayoutOpInterface.h"
14#include "mlir/Dialect/Tensor/IR/Tensor.h"
15#include "mlir/Dialect/Utils/IndexingUtils.h"
16#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
17#include "mlir/Dialect/Utils/StaticValueUtils.h"
18#include "mlir/IR/Builders.h"
19#include "mlir/IR/BuiltinAttributeInterfaces.h"
20#include "mlir/IR/BuiltinTypeInterfaces.h"
21#include "mlir/IR/BuiltinTypes.h"
22#include "mlir/IR/IRMapping.h"
23#include "mlir/IR/Matchers.h"
24#include "mlir/IR/OpDefinition.h"
25#include "mlir/IR/PatternMatch.h"
26#include "mlir/IR/TypeUtilities.h"
27#include "mlir/Interfaces/DestinationStyleOpInterface.h"
28#include "mlir/Interfaces/InferIntRangeInterface.h"
29#include "mlir/Interfaces/LoopLikeInterface.h"
30#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
31#include "mlir/Interfaces/ViewLikeInterface.h"
32#include "mlir/Support/LLVM.h"
33#include "llvm/ADT/DenseSet.h"
34#include "llvm/ADT/STLExtras.h"
35#include "llvm/ADT/SmallBitVector.h"
36#include "llvm/ADT/StringRef.h"
37#include "llvm/Support/Casting.h"
38#include "llvm/Support/MathExtras.h"
39#include <optional>
40
41using namespace mlir;
42using namespace mlir::tensor;
43
44using llvm::divideCeilSigned;
45using llvm::divideFloorSigned;
46using llvm::mod;
47
48/// Materialize a single constant operation from a given attribute value with
49/// the desired resultant type.
50Operation *TensorDialect::materializeConstant(OpBuilder &builder,
51 Attribute value, Type type,
52 Location loc) {
53 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
54 return op;
55 if (complex::ConstantOp::isBuildableWith(value, type))
56 return builder.create<complex::ConstantOp>(location: loc, args&: type,
57 args: llvm::cast<ArrayAttr>(Val&: value));
58 return nullptr;
59}
60
61OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
62 int64_t dim) {
63 auto tensorType = llvm::cast<RankedTensorType>(Val: value.getType());
64 if (tensorType.isDynamicDim(idx: dim))
65 return builder.createOrFold<tensor::DimOp>(location: loc, args&: value, args&: dim);
66
67 return builder.getIndexAttr(value: tensorType.getDimSize(idx: dim));
68}
69
70SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
71 Location loc, Value value) {
72 auto tensorType = llvm::cast<RankedTensorType>(Val: value.getType());
73 SmallVector<OpFoldResult> result;
74 for (int64_t i = 0; i < tensorType.getRank(); ++i)
75 result.push_back(Elt: getMixedSize(builder, loc, value, dim: i));
76 return result;
77}
78
/// Return a destination tensor for `opResult`: the tied init operand when the
/// defining op is destination-style, otherwise a freshly created tensor.empty
/// of the same shape inserted right before the defining op. Fails only if
/// dynamic result sizes cannot be reified.
FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
                                                OpResult opResult) {
  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
  assert(tensorType && "expected tensor type");

  // If the op has a destination, it implements DestinationStyleOpInterface and
  // we can query the destination operand from that interface.
  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
  if (destOp)
    return destOp.getTiedOpOperand(opResult)->get();

  // Otherwise, create a new destination tensor with the same shape.
  OpBuilder::InsertionGuard g(b);
  // Insert before the defining op so the new tensor.empty dominates it.
  b.setInsertionPoint(opResult.getDefiningOp());

  // Compute sizes.
  SmallVector<OpFoldResult> mixedSizes;
  if (!tensorType.hasStaticShape()) {
    // Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
    ReifiedRankedShapedTypeDims reifiedShapes;
    if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
      return failure();
    mixedSizes = reifiedShapes[opResult.getResultNumber()];
  } else {
    // Static shape: Take static sizes directly.
    for (int64_t sz : tensorType.getShape())
      mixedSizes.push_back(b.getIndexAttr(sz));
  }

  // Create empty tensor.
  Value emptyTensor =
      b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
  return emptyTensor;
}
113
114LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
115 Operation *op,
116 SmallVector<Value> &result) {
117 for (OpResult opResult : op->getResults()) {
118 if (llvm::isa<TensorType>(Val: opResult.getType())) {
119 FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
120 if (failed(Result: destination))
121 return failure();
122 result.push_back(Elt: *destination);
123 }
124 }
125 return success();
126}
127
128bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
129 if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(Val&: tp1)) {
130 if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(Val&: tp2))
131 return rtp1.getShape() == rtp2.getShape() &&
132 rtp1.getElementType() == rtp2.getElementType();
133 return false;
134 }
135 return tp1 == tp2; // default implementation
136}
137
/// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
/// rank-extending tensor.insert_slice op.
///
/// `mixedSizes` are the slice sizes (one per source dim); `reducedShape` is the
/// rank-reduced type's shape. The returned bit vector has one bit per entry of
/// `mixedSizes`, set for each dimension that does not appear in the reduced
/// shape.
static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
                                           ArrayRef<OpFoldResult> mixedSizes) {
  llvm::SmallBitVector droppedDims(mixedSizes.size());
  // Cursor into `reducedShape`, consumed from the back.
  int64_t shapePos = reducedShape.size() - 1;

  // Walk the sizes back-to-front, greedily matching them against the reduced
  // shape; unit sizes that cannot be matched are the dropped dims.
  for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
    // Index of this size in the original (non-reversed) `mixedSizes`.
    size_t idx = mixedSizes.size() - size.index() - 1;
    // Rank-reduced dims must have a static unit dimension.
    bool isStaticUnitSize =
        isa<Attribute>(size.value()) &&
        llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;

    if (shapePos < 0) {
      // There are no more dims in the reduced shape. All remaining sizes must
      // be rank-reduced dims.
      assert(isStaticUnitSize && "expected unit dim");
      droppedDims.set(idx);
      continue;
    }

    // Dim is preserved if the size is not a static 1.
    if (!isStaticUnitSize) {
      --shapePos;
      continue;
    }

    // Dim is preserved if the reduced shape dim is also 1.
    if (reducedShape[shapePos] == 1) {
      --shapePos;
      continue;
    }

    // Otherwise: Dim is dropped.
    droppedDims.set(idx);
  }

  // Every reduced-shape dim must have been matched by some size.
  assert(shapePos < 0 && "dimension mismatch");
  return droppedDims;
}
179
/// Given a ranked tensor type and a range of values that defines its dynamic
/// dimension sizes, turn all dynamic sizes that have a constant value into
/// static dimension sizes.
///
/// Returns the refined type; `foldedDynamicSizes` receives the dynamic size
/// values that remain dynamic (i.e. were not folded into the shape), in order.
static RankedTensorType
foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
                            SmallVector<Value> &foldedDynamicSizes) {
  SmallVector<int64_t> staticShape(type.getShape());
  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
         "incorrect number of dynamic sizes");

  // Compute new static and dynamic sizes.
  // `ctr` walks `dynamicSizes` in lockstep with the dynamic dims of `type`.
  unsigned ctr = 0;
  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
    if (type.isDynamicDim(i)) {
      Value dynamicSize = dynamicSizes[ctr++];
      std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
      if (cst.has_value()) {
        // Dynamic size must be non-negative.
        // A constant negative size is invalid; keep it dynamic so verification
        // can diagnose it rather than baking it into the type.
        if (cst.value() < 0) {
          foldedDynamicSizes.push_back(dynamicSize);
          continue;
        }
        staticShape[i] = *cst;
      } else {
        // Size is not a known constant: it stays dynamic.
        foldedDynamicSizes.push_back(dynamicSize);
      }
    }
  }

  // Preserve the element type and encoding of the original type.
  return RankedTensorType::get(staticShape, type.getElementType(),
                               type.getEncoding());
}
212
213//===----------------------------------------------------------------------===//
214// BitcastOp
215//===----------------------------------------------------------------------===//
216
217bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
218 if (inputs.size() != 1 || outputs.size() != 1)
219 return false;
220 Type a = inputs.front(), b = outputs.front();
221 auto aT = dyn_cast<TensorType>(Val&: a);
222 auto bT = dyn_cast<TensorType>(Val&: b);
223 if (!aT || !bT)
224 return false;
225
226 if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
227 return false;
228
229 return succeeded(Result: verifyCompatibleShape(type1: aT, type2: bT));
230}
231
232namespace {
233
234/// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
235/// operation.
236struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
237 using OpRewritePattern<BitcastOp>::OpRewritePattern;
238
239 LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
240 PatternRewriter &rewriter) const final {
241 auto tensorBitcastOperand =
242 tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
243 if (!tensorBitcastOperand)
244 return failure();
245
246 auto resultType = cast<TensorType>(Val: tensorBitcast.getType());
247 rewriter.replaceOpWithNewOp<BitcastOp>(op: tensorBitcast, args&: resultType,
248 args: tensorBitcastOperand.getOperand());
249 return success();
250 }
251};
252
253} // namespace
254
void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Collapse bitcast-of-bitcast chains into a single bitcast.
  results.add<ChainedTensorBitcast>(context);
}
259
260//===----------------------------------------------------------------------===//
261// CastOp
262//===----------------------------------------------------------------------===//
263
/// Suggest "cast" as the printed SSA name prefix for tensor.cast results.
void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "cast");
}
267
268/// Returns true if `target` is a ranked tensor type that preserves static
269/// information available in the `source` ranked tensor type.
270bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
271 auto sourceType = llvm::dyn_cast<RankedTensorType>(Val&: source);
272 auto targetType = llvm::dyn_cast<RankedTensorType>(Val&: target);
273
274 // Requires RankedTensorType.
275 if (!sourceType || !targetType)
276 return false;
277
278 // Requires same elemental type.
279 if (sourceType.getElementType() != targetType.getElementType())
280 return false;
281
282 // Requires same rank.
283 if (sourceType.getRank() != targetType.getRank())
284 return false;
285
286 // Requires same encoding.
287 if (sourceType.getEncoding() != targetType.getEncoding())
288 return false;
289
290 // If cast is towards more static sizes along any dimension, don't fold.
291 for (auto t : llvm::zip(t: sourceType.getShape(), u: targetType.getShape())) {
292 if (ShapedType::isStatic(dValue: std::get<0>(t&: t)) &&
293 ShapedType::isDynamic(dValue: std::get<1>(t&: t)))
294 return false;
295 }
296
297 return true;
298}
299
300/// Determines whether tensor::CastOp casts to a more dynamic version of the
301/// source tensor. This is useful to fold a tensor.cast into a consuming op and
302/// implement canonicalization patterns for ops in different dialects that may
303/// consume the results of tensor.cast operations. Such foldable tensor.cast
304/// operations are typically inserted as `slice` ops and are canonicalized,
305/// to preserve the type compatibility of their uses.
306///
307/// Returns true when all conditions are met:
308/// 1. source and result are ranked tensors with same element type and rank.
309/// 2. the tensor type has more static information than the result
310///
311/// Example:
312/// ```mlir
313/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
314/// %2 = consumer %1 ... : tensor<?x?xf32> ...
315/// ```
316///
317/// folds into:
318///
319/// ```mlir
320/// %2 = consumer %0 ... : tensor<8x16xf32> ...
321/// ```
322bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
323 if (!castOp)
324 return false;
325
326 // Can fold if the source of cast has at least as much static information as
327 // its results.
328 return preservesStaticInformation(source: castOp.getType(),
329 target: castOp.getSource().getType());
330}
331
332/// Determines whether the tensor::CastOp casts to a more static version of the
333/// source tensor. This is useful to fold into a producing op and implement
334/// canonicalization patterns with the `tensor.cast` op as the root, but
335/// producer being from different dialects. Returns true when all conditions are
336/// met:
337/// 1. source and result and ranked tensors with same element type and rank.
338/// 2. the result type has more static information than the source.
339///
340/// Example:
341/// ```mlir
342/// %1 = producer ... : tensor<?x?xf32>
343/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
344/// ```
345///
346/// can be canonicalized to :
347///
348/// ```mlir
349/// %2 = producer ... : tensor<8x16xf32>
350/// ```
351/// Not all ops might be canonicalizable this way, but for those that can be,
352/// this method provides a check that it is worth doing the canonicalization.
353bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
354 if (!castOp)
355 return false;
356 return preservesStaticInformation(source: castOp.getSource().getType(),
357 target: castOp.getType());
358}
359
360bool mlir::tensor::hasFoldableTensorCastOperand(Operation *op) {
361 return llvm::any_of(Range: op->getOpOperands(), P: [&](OpOperand &opOperand) {
362 if (llvm::isa<BlockArgument>(Val: opOperand.get()))
363 return false;
364 auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
365 return castOp && canFoldIntoConsumerOp(castOp);
366 });
367}
368
/// Compute the operand list of `op` after folding every foldable tensor.cast
/// operand (replacing it by the cast's source), and update `newResTy` in place
/// with the refined types of the tensor-typed DPS init operands, which
/// determine the result types of destination-style ops.
SmallVector<Value> mlir::tensor::getUpdatedOperandsAfterCastOpFolding(
    DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
  SmallVector<Value> newOperands;
  newOperands.reserve(op->getNumOperands());

  // Callers must only invoke this when there is something to fold.
  assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!");

  // Assumes that the result has dpsInits followed by nonDpsInits.
  int64_t dpsInitIdx = 0;
  for (OpOperand &opOperand : op->getOpOperands()) {
    // getDefiningOp returns null for block arguments and non-cast producers,
    // in which case `fold` is false and the operand is kept as-is.
    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
    bool fold = canFoldIntoConsumerOp(tensorCastOp);
    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
    // Tensor-typed DPS inits are tied to results; mirror the (possibly
    // refined) operand type into the corresponding result slot. MemRef inits
    // have no tied result and are skipped.
    if (op.isDpsInit(&opOperand) &&
        !llvm::isa<MemRefType>(newOperands.back().getType()))
      newResTy[dpsInitIdx++] = newOperands.back().getType();
  }
  return newOperands;
}
388
389/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
390/// that can be folded.
391LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
392 bool folded = false;
393 for (OpOperand &operand : op->getOpOperands()) {
394 auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
395 if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
396 operand.set(castOp.getOperand());
397 folded = true;
398 }
399 }
400 return success(IsSuccess: folded);
401}
402
403bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
404 if (inputs.size() != 1 || outputs.size() != 1)
405 return false;
406 Type a = inputs.front(), b = outputs.front();
407 auto aT = llvm::dyn_cast<TensorType>(Val&: a);
408 auto bT = llvm::dyn_cast<TensorType>(Val&: b);
409 if (!aT || !bT)
410 return false;
411
412 if (aT.getElementType() != bT.getElementType())
413 return false;
414
415 return succeeded(Result: verifyCompatibleShape(type1: aT, type2: bT));
416}
417
418/// Compute a TensorType that has the joined shape knowledge of the two
419/// given TensorTypes. The element types need to match.
420static TensorType joinShapes(TensorType one, TensorType two) {
421 assert(one.getElementType() == two.getElementType());
422
423 if (!one.hasRank())
424 return two;
425 if (!two.hasRank())
426 return one;
427
428 int64_t rank = one.getRank();
429 if (rank != two.getRank())
430 return {};
431
432 SmallVector<int64_t, 4> join;
433 join.reserve(N: rank);
434 for (int64_t i = 0; i < rank; ++i) {
435 if (one.isDynamicDim(idx: i)) {
436 join.push_back(Elt: two.getDimSize(idx: i));
437 continue;
438 }
439 if (two.isDynamicDim(idx: i)) {
440 join.push_back(Elt: one.getDimSize(idx: i));
441 continue;
442 }
443 if (one.getDimSize(idx: i) != two.getDimSize(idx: i))
444 return {};
445 join.push_back(Elt: one.getDimSize(idx: i));
446 }
447 return RankedTensorType::get(shape: join, elementType: one.getElementType());
448}
449
450namespace {
451
/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    // Only fires when the operand itself comes from a tensor.cast.
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
    auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
    auto resultType = llvm::cast<TensorType>(tensorCast.getType());

    // We can remove the intermediate cast if joining all three produces the
    // same result as just joining the source and result shapes.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);

    // The join might not exist if the cast sequence would fail at runtime.
    if (!firstJoin)
      return failure();

    // The newJoin always exists if the above join exists, it might just contain
    // less information. If so, we cannot drop the intermediate cast, as doing
    // so would remove runtime checks.
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    // Safe to cast directly from the original source to the final result type.
    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};
490
/// Fold tensor.cast into tesor.extract_slice producer.
/// Example:
/// ```
///  %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
///    tensor<128x512xf32> to tensor<?x512xf32>
///  %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
/// ```
/// ->
/// ```
/// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
///   tensor<128x512xf32> to tensor<16x512xf32>
/// ```
struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto extractOperand =
        tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();

    // Cannot fold cast to unranked tensor.
    auto rankedResultType =
        llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
    if (!rankedResultType)
      return failure();

    // Bail out unless the producer is an extract_slice, the cast adds static
    // information, and the cast actually changes the shape.
    if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
        rankedResultType.getShape() ==
            llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
                .getShape())
      return failure();

    // Transfer the cast's static sizes onto the slice sizes. `dimMask` marks
    // the source dims dropped by rank reduction; those have no counterpart in
    // the result shape and keep their original size.
    SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
    auto dimMask = computeRankReductionMask(
        extractOperand.getStaticSizes(), extractOperand.getType().getShape());
    // `dimIndex` walks the (rank-reduced) result shape in lockstep with the
    // non-dropped slice dims.
    size_t dimIndex = 0;
    for (size_t i = 0, e = sizes.size(); i < e; i++) {
      if (dimMask && dimMask->count(i))
        continue;
      int64_t dim = rankedResultType.getShape()[dimIndex++];
      if (ShapedType::isDynamic(dim))
        continue;
      sizes[i] = rewriter.getIndexAttr(dim);
    }

    // Rebuild the extract_slice with the refined sizes and the cast's type.
    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
        tensorCast, rankedResultType, extractOperand.getSource(),
        extractOperand.getMixedOffsets(), sizes,
        extractOperand.getMixedStrides());
    return success();
  }
};
543
544} // namespace
545
void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  // Collapse cast chains and fold casts into extract_slice producers.
  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
}
550
551//===----------------------------------------------------------------------===//
552// ConcatOp
553//===----------------------------------------------------------------------===//
554
/// Infer the concat result type: every non-concat dimension takes the most
/// static size agreed on by all inputs, and the concat dimension is the
/// saturating sum of the input sizes (dynamic if any input is dynamic there).
/// All input types must be ranked tensors of equal rank.
RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
  auto tensorTypes =
      llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
        return llvm::cast<RankedTensorType>(type);
      }));
  int64_t concatRank = tensorTypes[0].getRank();

  // The concatenation dim must be in the range [0, rank).
  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");

  SmallVector<int64_t> sizes(concatRank);
  for (int64_t i = 0, e = concatRank; i < e; ++i) {
    if (i == dim)
      continue;
    // Desaturate: a static size from any input refines a dynamic one. The
    // deref asserts that no two inputs disagree on a static size (callers are
    // expected to pass verifiable input types).
    SaturatedInteger size;
    for (auto tensorType : tensorTypes)
      size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
    sizes[i] = size.asInteger();
  }
  // Saturating sum along the concat dim: any dynamic input makes it dynamic.
  auto concatSize = SaturatedInteger::wrap(0);
  for (auto tensorType : tensorTypes)
    concatSize =
        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
  sizes[dim] = concatSize.asInteger();
  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
}
582
/// Build a concat op whose result type is inferred from the operand types.
void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
                     ValueRange inputs) {
  FailureOr<RankedTensorType> resultType =
      inferResultType(dim, inputs.getTypes());
  assert(succeeded(resultType) && "failed to infer concatenation result type");
  build(odsBuilder&: builder, odsState&: result, result: *resultType, dim, inputs);
}
590
591LogicalResult ConcatOp::verify() {
592 if (getInputs().size() < 1)
593 return emitOpError(message: "requires at least one input");
594
595 SmallVector<RankedTensorType> inputTypes;
596 for (auto input : getInputs())
597 inputTypes.push_back(Elt: cast<RankedTensorType>(Val: input.getType()));
598
599 RankedTensorType resultType = getResultType();
600 int64_t resultRank = getRank();
601 if (llvm::any_of(Range&: inputTypes, P: [resultRank](RankedTensorType type) {
602 return type.getRank() != resultRank;
603 }))
604 return emitOpError(message: "rank of concatenated inputs must match result rank");
605
606 Type resultElementType = resultType.getElementType();
607 if (llvm::any_of(Range&: inputTypes, P: [&](RankedTensorType type) {
608 return type.getElementType() != resultElementType;
609 }))
610 return emitOpError(message: "inputs and result element type must match");
611
612 int64_t dim = getDim();
613 if (dim >= resultRank)
614 return emitOpError(message: "concatenation dim must be less than the tensor rank");
615
616 SmallVector<int64_t> sizes(resultRank);
617 for (int64_t i = 0, e = resultRank; i < e; ++i) {
618 if (i == dim)
619 continue;
620 SaturatedInteger size;
621 for (auto tensorType : inputTypes) {
622 FailureOr<SaturatedInteger> maybeSize =
623 size.desaturate(other: SaturatedInteger::wrap(v: tensorType.getDimSize(idx: i)));
624 if (failed(Result: maybeSize))
625 return emitOpError(message: "static concatenation size mismatch along ")
626 << "non-concatenated dimension " << i;
627 size = *maybeSize;
628 }
629 sizes[i] = size.asInteger();
630 }
631 auto concatSize = SaturatedInteger::wrap(v: 0);
632 for (auto tensorType : inputTypes)
633 concatSize =
634 concatSize + SaturatedInteger::wrap(v: tensorType.getDimSize(idx: dim));
635 sizes[dim] = concatSize.asInteger();
636 auto inferredResultType =
637 RankedTensorType::get(shape: sizes, elementType: inputTypes[0].getElementType());
638
639 for (auto [inferredSize, actualSize] :
640 llvm::zip_equal(t: inferredResultType.getShape(), u: resultType.getShape())) {
641 bool hasDynamic = ShapedType::isDynamic(dValue: inferredSize) ||
642 ShapedType::isDynamic(dValue: actualSize);
643 if (!hasDynamic && inferredSize != actualSize)
644 return emitOpError(message: "result type ")
645 << resultType << "does not match inferred shape "
646 << inferredResultType << " static sizes";
647 }
648
649 return success();
650}
651
/// Decompose the concat into a tensor.empty of the full output shape followed
/// by one tensor.insert_slice per input, each written at its running offset
/// along the concat dimension.
FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
  size_t numInputs = getInputs().size();
  uint64_t concatDim = getDim();

  SmallVector<SmallVector<OpFoldResult>> inputShapes;
  inputShapes.reserve(numInputs);
  SmallVector<OpFoldResult> concatOffsets;
  concatOffsets.reserve(numInputs);
  SmallVector<OpFoldResult> outputShape;

  // s0 + s1: used to fold the running sum of sizes along the concat dim.
  AffineExpr addExpr =
      builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
  OpFoldResult zero = builder.getIndexAttr(0);
  Location loc = getLoc();
  for (auto [index, input] : llvm::enumerate(getInputs())) {
    SmallVector<OpFoldResult> inputShape =
        tensor::getMixedSizes(builder, input.getLoc(), input);
    if (index == 0) {
      outputShape = inputShape;
      concatOffsets.push_back(zero);
    } else {
      // Each input starts where the accumulated output currently ends.
      concatOffsets.push_back(outputShape[concatDim]);
      outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
          builder, loc, addExpr,
          {outputShape[concatDim], inputShape[concatDim]});
    }
    inputShapes.emplace_back(std::move(inputShape));
  }

  Value replacement = builder.create<tensor::EmptyOp>(
      loc, outputShape, getType().getElementType());

  // Unit strides and zero offsets everywhere except the concat dim.
  int64_t rank = getType().getRank();
  OpFoldResult one = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> strides(rank, one);
  SmallVector<OpFoldResult> offsets(rank, zero);
  for (auto [index, input] : llvm::enumerate(getInputs())) {
    offsets[concatDim] = concatOffsets[index];
    auto insertSlice = builder.create<tensor::InsertSliceOp>(
        loc, input, replacement, offsets, inputShapes[index], strides);
    replacement = insertSlice.getResult();
  }
  // The built chain may end up with a more static type than declared; cast
  // back to the op's result type when they differ.
  if (replacement.getType() != getType()) {
    replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
  }
  return SmallVector<Value>{replacement};
}
699
/// Reify the single result's dimension sizes: static sizes become index
/// attributes, and the concat dim (when dynamic) is materialized as an affine
/// sum of tensor.dim queries over all inputs.
LogicalResult
ConcatOp::reifyResultShapes(OpBuilder &builder,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  ValueRange inputs = getInputs();
  int64_t dim = getDim();
  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());

  Value init = inputs[0];
  int64_t rank = getType().getRank();

  // One result, `rank` entries.
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));

  // Pre-populate the result sizes with as much static information as possible
  // from the given result type, as well as the inferred result type, otherwise
  // use the dim sizes from the first input.
  for (int64_t i = 0; i < rank; ++i) {
    if (i == dim)
      continue;
    if (!getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    } else if (!inferredResultType.isDynamicDim(i)) {
      // Known statically only via inference: materialize a constant index.
      reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
          builder, getLoc(),
          builder.getIndexAttr(inferredResultType.getDimSize(i)));
    } else {
      // Fully dynamic: query the first input at runtime.
      reifiedReturnShapes[0][i] =
          builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
    }
  }

  if (getType().isDynamicDim(dim)) {
    // Take the sum of the input sizes along the concatenated dim.
    AffineExpr sum = builder.getAffineDimExpr(0);
    SmallVector<OpFoldResult> sizes = {
        builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
    for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
      sum = sum + builder.getAffineDimExpr(idx + 1);
      sizes.push_back(
          builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
    }
    reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
        builder, getLoc(),
        affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
  } else {
    // If the result shape is static along the concatenated dim, use the static
    // shape.
    reifiedReturnShapes[0][dim] =
        builder.getIndexAttr(getType().getDimSize(dim));
  }
  return success();
}
751
/// Suggest "concat" as the printed SSA name prefix for tensor.concat results.
void ConcatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "concat");
}
756
757OpFoldResult ConcatOp::fold(FoldAdaptor) {
758 ValueRange inputs = getInputs();
759 if (inputs.size() == 1 && inputs[0].getType() == getResultType())
760 return inputs[0];
761 return {};
762}
763
764namespace {
765/// Fold a concat op with a single input to a cast.
766struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
767 using OpRewritePattern<ConcatOp>::OpRewritePattern;
768
769 LogicalResult matchAndRewrite(ConcatOp concatOp,
770 PatternRewriter &rewriter) const override {
771 if (concatOp.getInputs().size() != 1)
772 return failure();
773 rewriter.replaceOpWithNewOp<CastOp>(op: concatOp, args: concatOp.getResultType(),
774 args: concatOp.getInputs()[0]);
775 return success();
776 }
777};
778
/// Propagate static shapes into the operands of a `tensor.concat`.
///
/// `tensor.concat` requires every operand to match on all dimensions except the
/// concatenation dimension. If one operand is already static in those
/// dimensions, the other operands may safely be refined to that same static
/// shape.
///
/// Example:
///
/// ```mlir
///   %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
///   tensor<?x12xi32>
/// ```
/// ->
/// ```mlir
///   %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
///   %2 = tensor.concat dim(0) %0, %cast :
///        (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
/// ```
struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());

    // Find operands for which a more static shape can be inferred.
    LogicalResult matched = failure();
    // Inferred operand shapes are identical in every dimension except the
    // concatenation dimension.
    SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
    for (auto [operandIdx, operandType] :
         llvm::enumerate(concatOp->getOperandTypes())) {
      // Compute inferred type for operand: keep the operand's own size along
      // the concat dim, take the inferred sizes everywhere else.
      inferredOperandShape[dim] =
          cast<RankedTensorType>(operandType).getDimSize(dim);
      auto inferredOperandType = RankedTensorType::get(
          inferredOperandShape, inferredResultType.getElementType());

      // Check if inferred type is more static.
      if (!preservesStaticInformation(inferredOperandType, operandType)) {
        matched = success();

        // Use refined operand type and create cast from original operand.
        auto castOp =
            rewriter.create<CastOp>(concatOp->getLoc(), inferredOperandType,
                                    concatOp.getOperand(operandIdx));
        // Capture operandIdx by value: the structured-binding variable cannot
        // be captured directly pre-C++20.
        rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
          concatOp->setOperand(operandIdx, castOp->getResult(0));
        });
      }
    }

    // Success iff at least one operand was refined.
    return matched;
  }
};
837
/// Ensure `tensor.concat`'s result type is at least as static as can be
/// inferred from its operand types.
840///
841/// Example:
842/// ```mlir
843/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
844/// tensor<?x?xi32>
845/// ```
846/// ->
847/// ```mlir
848/// %2 = tensor.concat dim(0) %0, %cast : (tensor<?x12xi32>, tensor<?x12xi32>)
849/// -> tensor<?x12xi32> %cast = tensor.cast %2 : tensor<?x12xi32> to
850/// tensor<?x?xi32>
851/// ```
852struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
853 using OpRewritePattern<ConcatOp>::OpRewritePattern;
854
855 LogicalResult matchAndRewrite(ConcatOp concatOp,
856 PatternRewriter &rewriter) const override {
857 int64_t dim = concatOp.getDim();
858 RankedTensorType inferredResultType =
859 ConcatOp::inferResultType(dim, inputTypes: concatOp->getOperandTypes());
860
861 // The result type should be at least as static as inferred result type.
862 if (preservesStaticInformation(source: inferredResultType,
863 target: concatOp.getResultType())) {
864 return failure();
865 }
866
867 auto newConcatOp = rewriter.create<ConcatOp>(
868 location: concatOp->getLoc(), args&: inferredResultType, args&: dim, args: concatOp->getOperands());
869 rewriter.replaceOpWithNewOp<CastOp>(op: concatOp, args: concatOp.getResultType(),
870 args&: newConcatOp);
871
872 return success();
873 }
874};
875} // namespace
876
877void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
878 MLIRContext *context) {
879 results
880 .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
881 arg&: context);
882}
883
884//===----------------------------------------------------------------------===//
885// DimOp
886//===----------------------------------------------------------------------===//
887
/// Suggest "dim" as the printed SSA name prefix for this op's result.
void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "dim");
}
891
892void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
893 int64_t index) {
894 auto loc = result.location;
895 Value indexValue = builder.create<arith::ConstantIndexOp>(location: loc, args&: index);
896 build(odsBuilder&: builder, odsState&: result, source, index: indexValue);
897}
898
/// Return the index operand's value if it is a compile-time constant,
/// std::nullopt otherwise.
std::optional<int64_t> DimOp::getConstantIndex() {
  return getConstantIntValue(ofr: getIndex());
}
902
903Speculation::Speculatability DimOp::getSpeculatability() {
904 auto constantIndex = getConstantIndex();
905 if (!constantIndex)
906 return Speculation::NotSpeculatable;
907
908 auto rankedSourceType = dyn_cast<RankedTensorType>(Val: getSource().getType());
909 if (!rankedSourceType)
910 return Speculation::NotSpeculatable;
911
912 if (rankedSourceType.getRank() <= constantIndex)
913 return Speculation::NotSpeculatable;
914
915 return Speculation::Speculatable;
916}
917
/// Integer-range inference: derive the result's range from the source shape
/// and the index range (argRanges[1]) via the shared shaped-dim helper.
void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
                                          SetIntLatticeFn setResultRange) {
  setResultRange(getResult(),
                 intrange::inferShapedDimOpInterface(op: *this, maybeDim: argRanges[1]));
}
923
/// Fold tensor.dim when the answer is statically known: a static extent, the
/// matching dynamic extent of a producing tensor.generate/extract_slice, or
/// through a producing tensor.cast.
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  // All forms of folding require a known index.
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(Val: adaptor.getIndex());
  if (!index)
    return {};

  // Folding for unranked types (UnrankedTensorType) is not supported.
  auto tensorType = llvm::dyn_cast<RankedTensorType>(Val: getSource().getType());
  if (!tensorType)
    return {};

  // Out of bound indices produce undefined behavior but are still valid IR.
  // Don't choke on them.
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= tensorType.getRank())
    return {};

  // Fold if the shape extent along the given index is known.
  if (!tensorType.isDynamicDim(idx: index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(value: tensorType.getShape()[index.getInt()]);
  }

  Operation *definingOp = getSource().getDefiningOp();

  // Fold dim to the operand of tensor.generate.
  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(Val: definingOp)) {
    auto resultType =
        llvm::cast<RankedTensorType>(Val: fromElements.getResult().getType());
    // The case where the type encodes the size of the dimension is handled
    // above.
    assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));

    // Find the operand of the fromElements that corresponds to this index.
    // Dynamic extents appear one per dynamic dim, in dimension order, so
    // advance past one extent per dynamic dim preceding the queried one.
    auto dynExtents = fromElements.getDynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(N: index.getInt()))
      if (ShapedType::isDynamic(dValue: dim))
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(Val: definingOp)) {
    // Fold only for non-rank reduced ops. For the rank-reduced version, rely on
    // `resolve-shaped-type-result-dims` pass.
    if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
        sliceOp.isDynamicSize(idx: unsignedIndex)) {
      return {sliceOp.getDynamicSize(idx: unsignedIndex)};
    }
  }

  // dim(cast) -> dim
  // NOTE(review): foldTensorCast presumably rewires this op's source in
  // place on success, so returning getResult() yields the updated op.
  if (succeeded(Result: foldTensorCast(op: *this)))
    return getResult();

  return {};
}
984
985namespace {
986/// Fold dim of a cast into the dim of the source of the tensor cast.
987struct DimOfCastOp : public OpRewritePattern<DimOp> {
988 using OpRewritePattern<DimOp>::OpRewritePattern;
989
990 LogicalResult matchAndRewrite(DimOp dimOp,
991 PatternRewriter &rewriter) const override {
992 auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
993 if (!castOp)
994 return failure();
995 Value newSource = castOp.getOperand();
996 rewriter.replaceOpWithNewOp<DimOp>(op: dimOp, args&: newSource, args: dimOp.getIndex());
997 return success();
998 }
999};
1000
1001/// Fold dim of a destination passing style op into the dim of the corresponding
1002/// init.
1003struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
1004 using OpRewritePattern<DimOp>::OpRewritePattern;
1005
1006 LogicalResult matchAndRewrite(DimOp dimOp,
1007 PatternRewriter &rewriter) const override {
1008 auto source = dimOp.getSource();
1009 auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
1010 if (!destOp)
1011 return failure();
1012
1013 auto resultIndex = cast<OpResult>(Val&: source).getResultNumber();
1014 auto *initOperand = destOp.getDpsInitOperand(i: resultIndex);
1015
1016 rewriter.modifyOpInPlace(
1017 root: dimOp, callable: [&]() { dimOp.getSourceMutable().assign(value: initOperand->get()); });
1018 return success();
1019 }
1020};
1021
1022/// Fold dim of a tensor reshape operation to a extract into the reshape's shape
1023/// operand.
1024struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
1025 using OpRewritePattern<DimOp>::OpRewritePattern;
1026
1027 LogicalResult matchAndRewrite(DimOp dim,
1028 PatternRewriter &rewriter) const override {
1029 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1030
1031 if (!reshape)
1032 return failure();
1033
1034 // Since tensors are immutable we don't need to worry about where to place
1035 // the extract call
1036 rewriter.setInsertionPointAfter(dim);
1037 Location loc = dim.getLoc();
1038 Value extract =
1039 rewriter.create<ExtractOp>(location: loc, args: reshape.getShape(), args: dim.getIndex());
1040 if (extract.getType() != dim.getType())
1041 extract =
1042 rewriter.create<arith::IndexCastOp>(location: loc, args: dim.getType(), args&: extract);
1043 rewriter.replaceOp(op: dim, newValues: extract);
1044 return success();
1045 }
1046};
1047} // namespace
1048
1049void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1050 MLIRContext *context) {
1051 results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(arg&: context);
1052}
1053
1054//===----------------------------------------------------------------------===//
1055// EmptyOp
1056//===----------------------------------------------------------------------===//
1057
/// Build a tensor.empty with a fully-static shape; no dynamic size operands
/// are required (and none are passed).
void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<int64_t> staticShape, Type elementType,
                    Attribute encoding) {
  assert(none_of(staticShape, ShapedType::isDynamic) &&
         "expected only static sizes");
  build(odsBuilder&: builder, odsState&: result, staticShape, elementType, dynamicSizes: ValueRange{}, encoding);
}
1065
/// Build a tensor.empty from a possibly-dynamic shape plus one SSA size value
/// per dynamic dimension.
void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<int64_t> staticShape, Type elementType,
                    ValueRange dynamicSizes, Attribute encoding) {
  auto tensorType = RankedTensorType::get(shape: staticShape, elementType, encoding);
  build(odsBuilder&: builder, odsState&: result, result: tensorType, dynamicSizes);
}
1072
/// Build a tensor.empty from mixed OpFoldResult sizes: attribute entries
/// become static extents, Value entries become dynamic size operands.
void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<OpFoldResult> sizes, Type elementType,
                    Attribute encoding) {
  SmallVector<int64_t> staticShape;
  SmallVector<Value> dynamicSizes;
  // Splits `sizes` into the static shape (with sentinels for dynamic dims)
  // and the list of dynamic size values.
  dispatchIndexOpFoldResults(ofrs: sizes, dynamicVec&: dynamicSizes, staticVec&: staticShape);
  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
}
1081
1082LogicalResult EmptyOp::verify() {
1083 if (getType().getNumDynamicDims() != getDynamicSizes().size())
1084 return emitOpError(message: "incorrect number of dynamic sizes, has ")
1085 << getDynamicSizes().size() << ", expected "
1086 << getType().getNumDynamicDims();
1087 return success();
1088}
1089
1090LogicalResult
1091EmptyOp::reifyResultShapes(OpBuilder &builder,
1092 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1093 reifiedReturnShapes.resize(N: 1, NV: SmallVector<OpFoldResult>(getType().getRank()));
1094 unsigned ctr = 0;
1095 for (int64_t i = 0; i < getType().getRank(); ++i) {
1096 if (getType().isDynamicDim(idx: i)) {
1097 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
1098 } else {
1099 reifiedReturnShapes[0][i] = builder.getIndexAttr(value: getType().getDimSize(idx: i));
1100 }
1101 }
1102 return success();
1103}
1104
1105Value EmptyOp::getDynamicSize(unsigned idx) {
1106 assert(getType().isDynamicDim(idx) && "expected dynamic dim");
1107 unsigned ctr = 0;
1108 for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
1109 if (getType().isDynamicDim(idx: i))
1110 ++ctr;
1111 return getDynamicSizes()[ctr];
1112}
1113
1114SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
1115 SmallVector<OpFoldResult> result;
1116 unsigned ctr = 0;
1117 OpBuilder b(getContext());
1118 for (int64_t i = 0; i < getType().getRank(); ++i) {
1119 if (getType().isDynamicDim(idx: i)) {
1120 result.push_back(Elt: getDynamicSizes()[ctr++]);
1121 } else {
1122 result.push_back(Elt: b.getIndexAttr(value: getType().getShape()[i]));
1123 }
1124 }
1125 return result;
1126}
1127
1128namespace {
1129/// Change the type of the result of a `tensor.empty` by making the result
1130/// type statically sized along dimensions that in the original operation were
1131/// defined as dynamic, but the size was defined using a `constant` op. For
1132/// example
1133///
1134/// %c5 = arith.constant 5: index
1135/// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
1136///
1137/// to
1138///
1139/// %0 = tensor.empty(%arg0) : tensor<?x5xf32>
1140struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
1141 using OpRewritePattern<EmptyOp>::OpRewritePattern;
1142
1143 LogicalResult matchAndRewrite(EmptyOp op,
1144 PatternRewriter &rewriter) const override {
1145 SmallVector<Value> foldedDynamicSizes;
1146 RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1147 type: op.getType(), dynamicSizes: op.getDynamicSizes(), foldedDynamicSizes);
1148
1149 // Stop here if no dynamic size was promoted to static.
1150 if (foldedTensorType == op.getType())
1151 return failure();
1152
1153 auto newOp = rewriter.create<EmptyOp>(location: op.getLoc(), args&: foldedTensorType,
1154 args&: foldedDynamicSizes);
1155 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, args: op.getType(), args&: newOp);
1156 return success();
1157 }
1158};
1159
1160struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
1161 using OpRewritePattern<DimOp>::OpRewritePattern;
1162
1163 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1164 PatternRewriter &rewriter) const override {
1165 std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
1166 auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
1167 if (!emptyTensorOp || !maybeConstantIndex)
1168 return failure();
1169 auto emptyTensorType = emptyTensorOp.getType();
1170 if (*maybeConstantIndex < 0 ||
1171 *maybeConstantIndex >= emptyTensorType.getRank() ||
1172 !emptyTensorType.isDynamicDim(idx: *maybeConstantIndex))
1173 return failure();
1174 rewriter.replaceOp(op: dimOp,
1175 newValues: emptyTensorOp.getDynamicSize(idx: *maybeConstantIndex));
1176 return success();
1177 }
1178};
1179
1180/// Canonicalize
1181///
1182/// ```mlir
1183/// %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
1184/// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
1185/// ```
1186///
1187/// into
1188///
1189/// ```mlir
1190/// %0 = tensor.empty(%d1) : tensor<4x?xf32>
1191/// ```
1192///
1193/// This assumes the input program is correct in terms of its shape. So it is
1194/// safe to assume that `%d0` is in fact 4.
struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp castOp,
                                PatternRewriter &rewriter) const override {
    // Only casts that do not lose static information may be folded into the
    // producer, and the producer must be a tensor.empty.
    if (!canFoldIntoProducerOp(castOp))
      return failure();
    auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
    if (!producer)
      return failure();

    auto resultType =
        llvm::cast<RankedTensorType>(Val: castOp->getResult(idx: 0).getType());
    ArrayRef<int64_t> resultShape = resultType.getShape();
    SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
    SmallVector<OpFoldResult> newMixedSizes;
    newMixedSizes.reserve(N: currMixedSizes.size());
    assert(resultShape.size() == currMixedSizes.size() &&
           "mismatch in result shape and sizes of empty op");
    // Merge, per dimension, the cast's result extent with the empty op's
    // current mixed size, preferring whichever is static.
    for (auto it : llvm::zip(t&: resultShape, u&: currMixedSizes)) {
      int64_t newDim = std::get<0>(t&: it);
      OpFoldResult currDim = std::get<1>(t&: it);
      // Case 1: The empty tensor dim is static. Check that the tensor cast
      // result dim matches.
      if (auto attr = llvm::dyn_cast_if_present<Attribute>(Val&: currDim)) {
        if (ShapedType::isDynamic(dValue: newDim) ||
            newDim != llvm::cast<IntegerAttr>(Val&: attr).getInt()) {
          // Something is off, the cast result shape cannot be more dynamic
          // than the empty tensor result shape (enforced by
          // `canFoldIntoProducer`). Abort for now.
          return rewriter.notifyMatchFailure(
              arg&: producer, msg: "mismatch in static value of shape of empty tensor "
                        "result and cast result");
        }
        newMixedSizes.push_back(Elt: attr);
        continue;
      }

      // Case 2 : The tensor cast shape is static, but empty tensor result
      // shape is dynamic.
      if (ShapedType::isStatic(dValue: newDim)) {
        newMixedSizes.push_back(Elt: rewriter.getIndexAttr(value: newDim));
        continue;
      }

      // Case 3 : The tensor cast shape is dynamic and empty tensor result
      // shape is dynamic. Use the dynamic value from the empty tensor op.
      newMixedSizes.push_back(Elt: currDim);
    }

    // TODO: Do not drop tensor encoding.
    rewriter.replaceOpWithNewOp<EmptyOp>(op: castOp, args&: newMixedSizes,
                                         args: resultType.getElementType());
    return success();
  }
};
1251
1252} // namespace
1253
1254void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1255 MLIRContext *context) {
1256 results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1257 ReplaceEmptyTensorStaticShapeDims>(arg&: context);
1258}
1259
1260//===----------------------------------------------------------------------===//
1261// ExtractOp
1262//===----------------------------------------------------------------------===//
1263
1264namespace {
1265
1266/// Canonicalizes the pattern of the form
1267///
1268/// %val = tensor.cast %source : : tensor<?xi32> to tensor<2xi32>
1269/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
1270///
1271/// to
1272///
1273/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
1274struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
1275 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1276
1277 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1278 PatternRewriter &rewriter) const final {
1279 auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
1280 if (!tensorCast)
1281 return failure();
1282 if (!llvm::isa<RankedTensorType>(Val: tensorCast.getSource().getType()))
1283 return failure();
1284 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1285 op: extract, args: tensorCast.getSource(), args: extract.getIndices());
1286 return success();
1287 }
1288};
1289
1290/// Canonicalizes the pattern of the form
1291///
1292/// %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
1293/// tensor<12xf64>
1294/// %extracted_element = tensor.extract %val[%c10] :
1295/// tensor<12xf64>
1296///
1297/// to
1298///
1299/// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const final {
    auto collapseOp =
        extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return failure();
    // The delinearization below needs a known extent for every source dim.
    if (!collapseOp.getSrcType().hasStaticShape())
      return failure();

    auto sourceSizes = collapseOp.getSrcType().getShape();

    SmallVector<Value> indices(extractOp.getIndices().begin(),
                               extractOp.getIndices().end());
    SmallVector<Value> sourceIndices;
    // For each reassociation group, split the single collapsed index back
    // into one index per source dimension of that group.
    for (auto [index, group] :
         llvm::zip(t&: indices, u: collapseOp.getReassociationIndices())) {
      assert(!group.empty() && "association indices groups cannot be empty");
      auto groupSize = group.size();

      // Singleton groups map the index through unchanged.
      if (groupSize == 1) {
        sourceIndices.push_back(Elt: index);
        continue;
      }

      // Delinearize against the static extents of the grouped source dims.
      SmallVector<int64_t> basis =
          llvm::map_to_vector(C&: group, F: [&](int64_t d) { return sourceSizes[d]; });
      auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
          location: extractOp.getLoc(), args&: index, args&: basis, /*hasOuterBound=*/args: true);
      llvm::append_range(C&: sourceIndices, R: delinearize.getResults());
    }
    // No reassociation groups means the result is 0-d; every source index is
    // then the constant 0 produced by the zero affine map.
    if (collapseOp.getReassociationIndices().empty()) {
      auto zeroAffineMap = rewriter.getConstantAffineMap(val: 0);
      int64_t srcRank =
          cast<RankedTensorType>(Val: collapseOp.getSrcType()).getRank();
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          b&: rewriter, loc: extractOp.getLoc(), map: zeroAffineMap,
          operands: ArrayRef<OpFoldResult>{});
      for (int64_t i = 0; i < srcRank; i++) {
        sourceIndices.push_back(
            Elt: getValueOrCreateConstantIndexOp(b&: rewriter, loc: extractOp.getLoc(), ofr));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        op: extractOp, args: collapseOp.getSrc(), args&: sourceIndices);
    return success();
  }
};
1351
1352} // namespace
1353
/// Suggest "extracted" as the printed SSA name prefix for this op's result.
void ExtractOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted");
}
1358
1359LogicalResult ExtractOp::verify() {
1360 // Verify the # indices match if we have a ranked type.
1361 auto tensorType = llvm::cast<RankedTensorType>(Val: getTensor().getType());
1362 if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
1363 return emitOpError(message: "incorrect number of indices for extract_element");
1364 return success();
1365}
1366
1367/// If we have an ExtractOp consuming an InsertOp with the same
1368/// indices, we can return the InsertOp's scalar directly.
1369// TODO: This only checks the immediate producer; extend to go up the
1370// insert/extract chain if the slices are disjoint.
1371static Value foldExtractAfterInsert(ExtractOp extractOp) {
1372 auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();
1373
1374 auto isSame = [](Value a, Value b) {
1375 return getAsOpFoldResult(val: a) == getAsOpFoldResult(val: b);
1376 };
1377 if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
1378 llvm::equal(LRange: insertOp.getIndices(), RRange: extractOp.getIndices(), P: isSame))
1379 return insertOp.getScalar();
1380
1381 return {};
1382}
1383
1384OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1385 if (Attribute tensor = adaptor.getTensor()) {
1386 // If this is a splat elements attribute, simply return the value.
1387 // All of the elements of a splat attribute are the same.
1388 if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(Val&: tensor))
1389 return splatTensor.getSplatValue<Attribute>();
1390
1391 // If this is a dense resource elements attribute, return.
1392 if (isa<DenseResourceElementsAttr>(Val: tensor))
1393 return {};
1394 }
1395
1396 // Collect the constant indices into the tensor.
1397 SmallVector<uint64_t, 8> indices;
1398 for (Attribute indice : adaptor.getIndices()) {
1399 if (!indice || !llvm::isa<IntegerAttr>(Val: indice))
1400 return {};
1401 indices.push_back(Elt: llvm::cast<IntegerAttr>(Val&: indice).getInt());
1402 }
1403
1404 // Fold extract(from_elements(...)).
1405 if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1406 auto tensorType = llvm::cast<RankedTensorType>(Val: fromElementsOp.getType());
1407 auto rank = tensorType.getRank();
1408 assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
1409 "rank mismatch");
1410 int flatIndex = 0;
1411 int stride = 1;
1412 for (int i = rank - 1; i >= 0; --i) {
1413 flatIndex += indices[i] * stride;
1414 stride *= tensorType.getDimSize(idx: i);
1415 }
1416 // Prevent out of bounds accesses. This can happen in invalid code that
1417 // will never execute.
1418 if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1419 flatIndex < 0)
1420 return {};
1421 return fromElementsOp.getElements()[flatIndex];
1422 }
1423
1424 // If this is an elements attribute, query the value at the given indices.
1425 if (Attribute tensor = adaptor.getTensor()) {
1426 auto elementsAttr = llvm::dyn_cast<ElementsAttr>(Val&: tensor);
1427 if (elementsAttr && elementsAttr.isValidIndex(index: indices))
1428 return elementsAttr.getValues<Attribute>()[indices];
1429 }
1430
1431 if (Value result = foldExtractAfterInsert(extractOp: *this))
1432 return result;
1433
1434 return {};
1435}
1436
/// Register extract canonicalizations (extract-through-cast only; the
/// collapse-shape variant is opt-in via populateFoldCollapseExtractPatterns).
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractFromTensorCast>(arg&: context);
}
1441
/// Populate the opt-in pattern that folds extract-of-collapse_shape into an
/// extract on the uncollapsed source.
void mlir::tensor::populateFoldCollapseExtractPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ExtractFromCollapseShape>(arg: patterns.getContext());
}
1446
1447//===----------------------------------------------------------------------===//
1448// FromElementsOp
1449//===----------------------------------------------------------------------===//
1450
/// Suggest "from_elements" as the printed SSA name prefix for the result.
void FromElementsOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "from_elements");
}
1455
1456void FromElementsOp::build(OpBuilder &builder, OperationState &result,
1457 ValueRange elements) {
1458 assert(!elements.empty() && "expected at least one element");
1459 Type resultType = RankedTensorType::get(
1460 shape: {static_cast<int64_t>(elements.size())}, elementType: elements.front().getType());
1461 build(odsBuilder&: builder, odsState&: result, result: resultType, elements);
1462}
1463
1464OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1465 if (!llvm::is_contained(Range: adaptor.getElements(), Element: nullptr))
1466 return DenseElementsAttr::get(type: getType(), values: adaptor.getElements());
1467 return {};
1468}
1469
1470namespace {
1471
1472// Pushes the index_casts that occur before extractions to after the extract.
1473// This minimizes type conversion in some cases and enables the extract
1474// canonicalizer. This changes:
1475//
1476// %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
1477// %extract = tensor.extract %cast[%index] : tensor<1xindex>
1478//
1479// to the following:
1480//
1481// %extract = tensor.extract %tensor[%index] : tensor<1xindex>
1482// %cast = arith.index_cast %extract : i32 to index
1483//
1484// to just %element.
1485//
1486// Consider expanding this to a template and handle all tensor cast
1487// operations.
1488struct ExtractElementFromIndexCast
1489 : public OpRewritePattern<tensor::ExtractOp> {
1490 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1491
1492 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1493 PatternRewriter &rewriter) const final {
1494 Location loc = extract.getLoc();
1495 auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
1496 if (!indexCast)
1497 return failure();
1498
1499 Type elementTy = getElementTypeOrSelf(val: indexCast.getIn());
1500
1501 auto newExtract = rewriter.create<tensor::ExtractOp>(
1502 location: loc, args&: elementTy, args: indexCast.getIn(), args: extract.getIndices());
1503
1504 rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op: extract, args: extract.getType(),
1505 args&: newExtract);
1506
1507 return success();
1508 }
1509};
1510
1511} // namespace
1512
/// Register from_elements canonicalizations.
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<ExtractElementFromIndexCast>(arg&: context);
}
1517
1518//===----------------------------------------------------------------------===//
1519// GatherOp
1520//===----------------------------------------------------------------------===//
1521
/// Suggest "gather" as the printed SSA name prefix for this op's result.
void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
}
1526
1527/// Return the inferred result type for a gatherOp where:
1528/// - sourceType is the type of the source tensor gathered from
1529/// - indicesType is the type of the indices used to gather
1530/// - gatherDims are the dims along which the gather occurs.
1531/// Return a full rank or ranked-reduced variant of the type depending on
1532/// the value of rankReduced.
1533///
1534/// The leading dimensions of the index tensor give the result tensor its
1535/// leading dimensions.
1536/// The trailing dimensions of the result tensor are obtained from the source
1537/// tensor by setting the dimensions specified in gather_dims to `1` (if
1538/// rankedReduced is false), or skipping them (otherwise).
1539RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1540 RankedTensorType indicesType,
1541 ArrayRef<int64_t> gatherDims,
1542 bool rankReduced) {
1543 SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
1544 resultShape.reserve(N: resultShape.size() + sourceType.getRank());
1545 for (int64_t idx : llvm::seq<int64_t>(Begin: 0, End: sourceType.getRank())) {
1546 if (llvm::binary_search(Range&: gatherDims, Value&: idx)) {
1547 if (!rankReduced)
1548 resultShape.push_back(Elt: 1);
1549 continue;
1550 }
1551 resultShape.push_back(Elt: sourceType.getDimSize(idx));
1552 }
1553 return RankedTensorType::Builder(sourceType).setShape(resultShape);
1554}
1555
/// Shared verifier for gather_dims/scatter_dims: the dims must be non-empty,
/// in-bounds for `rank`, strictly increasing, and their count must equal the
/// last extent of the indices shape.
static LogicalResult
verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
                          ArrayRef<int64_t> indices, int64_t rank,
                          StringRef gatherOrScatter, StringRef sourceOrDest) {
  if (dims.empty())
    return op->emitOpError(message: gatherOrScatter) << "_dims must be non-empty";

  int64_t numGatherDims = dims.size();
  if (numGatherDims > rank)
    return op->emitOpError(message: gatherOrScatter)
           << "_dims overflow " << sourceOrDest << " rank";
  // The last dimension of the indices tensor holds one coordinate per
  // gathered/scattered dim.
  if (indices.empty() || indices.back() != numGatherDims)
    return op->emitOpError(message: gatherOrScatter)
           << "_dims length must match the size of last dimension of indices";
  for (int64_t val : dims) {
    if (val < 0)
      return op->emitOpError(message: gatherOrScatter)
             << "_dims value must be non-negative";
    if (val >= rank)
      return op->emitOpError(message: gatherOrScatter)
             << "_dims value must be smaller than " << sourceOrDest << " rank";
  }
  // Strictly increasing also rules out duplicates.
  for (int64_t i = 1; i < numGatherDims; ++i) {
    if (dims[i - 1] >= dims[i])
      return op->emitOpError(message: gatherOrScatter)
             << "_dims values must be strictly increasing";
  }
  return success();
}
1585
/// Verify gather_dims against the source rank and indices shape, then check
/// the declared result type against the two inferable forms.
LogicalResult GatherOp::verify() {
  int64_t sourceRank = getSourceType().getRank();
  ArrayRef<int64_t> gatherDims = getGatherDims();
  if (failed(Result: verifyGatherOrScatterDims(op: getOperation(), dims: gatherDims,
                                        indices: getIndicesType().getShape(), rank: sourceRank,
                                        gatherOrScatter: "gather", sourceOrDest: "source")))
    return failure();

  // Both the full-rank form (gathered dims kept as 1) and the rank-reduced
  // form (gathered dims dropped) are accepted result types.
  RankedTensorType expectedResultType = GatherOp::inferResultType(
      sourceType: getSourceType(), indicesType: getIndicesType(), gatherDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
      sourceType: getSourceType(), indicesType: getIndicesType(), gatherDims, /*rankReduced=*/true);
  if (getResultType() != expectedResultType &&
      getResultType() != expectedRankReducedResultType) {
    return emitOpError(message: "result type "
                       "mismatch: "
                       "expected ")
           << expectedResultType << " or its rank-reduced variant "
           << expectedRankReducedResultType << " (got: " << getResultType()
           << ")";
  }

  return success();
}
1610
1611OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1612 if (OpFoldResult reshapedSource = reshapeConstantSource(
1613 source: llvm::dyn_cast_if_present<DenseElementsAttr>(Val: adaptor.getSource()),
1614 result: getResult().getType()))
1615 return reshapedSource;
1616 return {};
1617}
1618
1619//===----------------------------------------------------------------------===//
1620// InsertOp
1621//===----------------------------------------------------------------------===//
1622
/// Suggest "inserted" as the printed SSA name prefix for this op's result.
void InsertOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted");
}
1627
1628LogicalResult InsertOp::verify() {
1629 // Verify the # indices match if we have a ranked type.
1630 auto destType = llvm::cast<RankedTensorType>(Val: getDest().getType());
1631 if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
1632 return emitOpError(message: "incorrect number of indices");
1633 return success();
1634}
1635
1636OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1637 Attribute scalar = adaptor.getScalar();
1638 Attribute dest = adaptor.getDest();
1639 if (scalar && dest)
1640 if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(Val&: dest))
1641 if (scalar == splatDest.getSplatValue<Attribute>())
1642 return dest;
1643 return {};
1644}
1645
1646//===----------------------------------------------------------------------===//
1647// GenerateOp
1648//===----------------------------------------------------------------------===//
1649
/// Suggest "generated" as the printed SSA name prefix for this op's result.
void GenerateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "generated");
}
1654
1655LogicalResult GenerateOp::reifyResultShapes(
1656 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1657 reifiedReturnShapes.resize(N: 1, NV: SmallVector<OpFoldResult>(getType().getRank()));
1658 int idx = 0;
1659 for (auto dim : llvm::seq<int64_t>(Begin: 0, End: getType().getRank())) {
1660 if (getType().isDynamicDim(idx: dim)) {
1661 reifiedReturnShapes[0][dim] = getOperand(i: idx++);
1662 } else {
1663 reifiedReturnShapes[0][dim] =
1664 builder.getIndexAttr(value: getType().getDimSize(idx: dim));
1665 }
1666 }
1667 return success();
1668}
1669
1670LogicalResult GenerateOp::verify() {
1671 // Ensure that the tensor type has as many dynamic dimensions as are
1672 // specified by the operands.
1673 RankedTensorType resultType = llvm::cast<RankedTensorType>(Val: getType());
1674 if (getNumOperands() != resultType.getNumDynamicDims())
1675 return emitError(message: "must have as many index operands as dynamic extents "
1676 "in the result type");
1677 return success();
1678}
1679
1680LogicalResult GenerateOp::verifyRegions() {
1681 RankedTensorType resultTy = llvm::cast<RankedTensorType>(Val: getType());
1682 // Ensure that region arguments span the index space.
1683 if (!llvm::all_of(Range: getBody().getArgumentTypes(),
1684 P: [](Type ty) { return ty.isIndex(); }))
1685 return emitError(message: "all body arguments must be index");
1686 if (getBody().getNumArguments() != resultTy.getRank())
1687 return emitError(message: "must have one body argument per input dimension");
1688
1689 // Ensure that the region yields an element of the right type.
1690 auto yieldOp = cast<YieldOp>(Val: getBody().getBlocks().front().getTerminator());
1691
1692 if (yieldOp.getValue().getType() != resultTy.getElementType())
1693 return emitOpError(
1694 message: "body must be terminated with a `yield` operation of the tensor "
1695 "element type");
1696
1697 return success();
1698}
1699
1700void GenerateOp::build(
1701 OpBuilder &b, OperationState &result, Type resultTy,
1702 ValueRange dynamicExtents,
1703 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1704 build(odsBuilder&: b, odsState&: result, result: resultTy, dynamicExtents);
1705
1706 // Build and populate body.
1707 OpBuilder::InsertionGuard guard(b);
1708 Region *bodyRegion = result.regions.front().get();
1709 auto rank = llvm::cast<RankedTensorType>(Val&: resultTy).getRank();
1710 SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
1711 SmallVector<Location, 2> argumentLocs(rank, result.location);
1712 Block *bodyBlock =
1713 b.createBlock(parent: bodyRegion, insertPt: bodyRegion->end(), argTypes: argumentTypes, locs: argumentLocs);
1714 bodyBuilder(b, result.location, bodyBlock->getArguments());
1715}
1716
1717namespace {
1718
1719/// Canonicalizes tensor.generate operations with a constant
1720/// operand into the equivalent operation with the operand expressed in the
1721/// result type, instead. We also insert a type cast to make sure that the
1722/// resulting IR is still well-typed.
1723struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
1724 using OpRewritePattern<GenerateOp>::OpRewritePattern;
1725
1726 LogicalResult matchAndRewrite(GenerateOp generateOp,
1727 PatternRewriter &rewriter) const final {
1728 SmallVector<Value> foldedDynamicSizes;
1729 RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
1730 type: generateOp.getType(), dynamicSizes: generateOp.getDynamicExtents(),
1731 foldedDynamicSizes);
1732
1733 // Stop here if no dynamic size was promoted to static.
1734 if (foldedTensorType == generateOp.getType())
1735 return failure();
1736
1737 auto loc = generateOp.getLoc();
1738 auto newOp =
1739 rewriter.create<GenerateOp>(location: loc, args&: foldedTensorType, args&: foldedDynamicSizes);
1740 rewriter.inlineRegionBefore(region&: generateOp.getBody(), parent&: newOp.getBody(),
1741 before: newOp.getBody().begin());
1742 rewriter.replaceOpWithNewOp<tensor::CastOp>(op: generateOp,
1743 args: generateOp.getType(), args&: newOp);
1744 return success();
1745 }
1746};
1747
1748/// Canonicalizes the pattern of the form
1749///
1750/// %tensor = tensor.generate %x {
1751/// ^bb0(%arg0: index):
1752/// <computation>
1753/// yield %1 : index
1754/// } : tensor<?xindex>
1755/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xi32>
1756///
1757/// to just <computation> with %arg0 replaced by %c0. We only do this if the
1758/// tensor.generate operation has no side-effects.
1759struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
1760 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1761
1762 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1763 PatternRewriter &rewriter) const final {
1764 auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1765 if (!tensorFromElements || !wouldOpBeTriviallyDead(op: tensorFromElements))
1766 return failure();
1767
1768 IRMapping mapping;
1769 Block *body = &tensorFromElements.getBody().front();
1770 mapping.map(from: body->getArguments(), to: extract.getIndices());
1771 for (auto &op : body->without_terminator())
1772 rewriter.clone(op, mapper&: mapping);
1773
1774 auto yield = cast<YieldOp>(Val: body->getTerminator());
1775
1776 rewriter.replaceOp(op: extract, newValues: mapping.lookupOrDefault(from: yield.getValue()));
1777 return success();
1778 }
1779};
1780
1781} // namespace
1782
/// Register the generate-op canonicalizations: inlining extract-of-generate
/// and promoting constant dynamic extents to static dims.
void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // TODO: Move extract pattern to tensor::ExtractOp.
  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(arg&: context);
}
1788
1789//===----------------------------------------------------------------------===//
1790// RankOp
1791//===----------------------------------------------------------------------===//
1792
1793void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1794 setNameFn(getResult(), "rank");
1795}
1796
1797OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1798 // Constant fold rank when the rank of the operand is known.
1799 auto type = getOperand().getType();
1800 auto shapedType = llvm::dyn_cast<ShapedType>(Val&: type);
1801 if (shapedType && shapedType.hasRank())
1802 return IntegerAttr::get(type: IndexType::get(context: getContext()), value: shapedType.getRank());
1803 return IntegerAttr();
1804}
1805
1806//===----------------------------------------------------------------------===//
1807// ReshapeOp
1808//===----------------------------------------------------------------------===//
1809
1810void ReshapeOp::getAsmResultNames(
1811 function_ref<void(Value, StringRef)> setNameFn) {
1812 setNameFn(getResult(), "reshape");
1813}
1814
1815static int64_t getNumElements(ShapedType type) {
1816 int64_t numElements = 1;
1817 for (auto dim : type.getShape())
1818 numElements *= dim;
1819 return numElements;
1820}
1821
1822LogicalResult ReshapeOp::verify() {
1823 TensorType operandType = llvm::cast<TensorType>(Val: getSource().getType());
1824 TensorType resultType = llvm::cast<TensorType>(Val: getResult().getType());
1825
1826 if (operandType.getElementType() != resultType.getElementType())
1827 return emitOpError(message: "element types of source and destination tensor "
1828 "types should be the same");
1829
1830 int64_t shapeSize =
1831 llvm::cast<RankedTensorType>(Val: getShape().getType()).getDimSize(idx: 0);
1832 auto resultRankedType = llvm::dyn_cast<RankedTensorType>(Val&: resultType);
1833 auto operandRankedType = llvm::dyn_cast<RankedTensorType>(Val&: operandType);
1834
1835 if (resultRankedType) {
1836 if (operandRankedType && resultRankedType.hasStaticShape() &&
1837 operandRankedType.hasStaticShape()) {
1838 if (getNumElements(type: operandRankedType) != getNumElements(type: resultRankedType))
1839 return emitOpError(message: "source and destination tensor should have the "
1840 "same number of elements");
1841 }
1842 if (ShapedType::isDynamic(dValue: shapeSize))
1843 return emitOpError(message: "cannot use shape operand with dynamic length to "
1844 "reshape to statically-ranked tensor type");
1845 if (shapeSize != resultRankedType.getRank())
1846 return emitOpError(
1847 message: "length of shape operand differs from the result's tensor rank");
1848 }
1849 return success();
1850}
1851
/// Fold tensor.reshape in several cases: constant source, reshape-of-reshape,
/// and identity reshapes that can be proven to be no-ops.
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  // Case 1: constant source -> constant of the result type.
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          source: llvm::dyn_cast_if_present<DenseElementsAttr>(Val: adaptor.getSource()),
          result: getResult().getType()))
    return reshapedSource;

  // If the producer of operand 'source' is another 'tensor.reshape' op, use the
  // producer's input instead as the original tensor to reshape. This could
  // render such producer dead code.
  if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
    getSourceMutable().assign(value: reshapeOpProducer.getSource());
    return getResult();
  }

  // The remaining cases require identical (ranked) source and result types.
  auto source = getSource();
  auto sourceTy = dyn_cast<RankedTensorType>(Val: source.getType());
  auto resultTy = dyn_cast<RankedTensorType>(Val: getType());
  if (!sourceTy || !resultTy || sourceTy != resultTy)
    return {};

  // If the source and result are both 0D or 1D tensors and have the same type,
  // the reshape has no effect, even if the tensor is dynamically shaped.
  if (sourceTy.getRank() <= 1)
    return source;

  // Otherwise, the reshape is a no-op only if every element of the shape
  // operand can be proven equal to the corresponding source dimension.
  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
    auto elements = fromElements.getElements();
    bool dynamicNoop =
        sourceTy.getRank() == static_cast<int64_t>(elements.size());
    for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
      auto element = elements[id];

      // A constant shape element must match the static dimension size.
      if (auto cst = getConstantIntValue(ofr: element)) {
        dynamicNoop &= cst.value() == sourceTy.getDimSize(idx: id);
        continue;
      }

      // A tensor.dim(source, id) element trivially matches dimension `id`.
      if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
        dynamicNoop &= dimOp.getSource() == source;

        auto cst = getConstantIntValue(ofr: dimOp.getIndex());
        dynamicNoop &=
            cst.has_value() && cst.value() == static_cast<int64_t>(id);
        continue;
      }

      // Anything else: cannot prove the reshape is a no-op.
      dynamicNoop = false;
      break;
    }

    if (dynamicNoop)
      return source;
  }

  return {};
}
1908
1909//===----------------------------------------------------------------------===//
1910// Reassociative reshape ops
1911//===----------------------------------------------------------------------===//
1912
1913void CollapseShapeOp::getAsmResultNames(
1914 function_ref<void(Value, StringRef)> setNameFn) {
1915 setNameFn(getResult(), "collapsed");
1916}
1917
1918void ExpandShapeOp::getAsmResultNames(
1919 function_ref<void(Value, StringRef)> setNameFn) {
1920 setNameFn(getResult(), "expanded");
1921}
1922
1923int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1924 assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1925 "invalid resultDim");
1926 for (const auto &it : llvm::enumerate(First: getReassociationIndices()))
1927 if (llvm::is_contained(Range&: it.value(), Element: resultDim))
1928 return it.index();
1929 llvm_unreachable("could not find reassociation group");
1930}
1931
1932FailureOr<SmallVector<OpFoldResult>>
1933ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
1934 RankedTensorType expandedType,
1935 ArrayRef<ReassociationIndices> reassociation,
1936 ArrayRef<OpFoldResult> inputShape) {
1937 std::optional<SmallVector<OpFoldResult>> outputShape =
1938 inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
1939 inputShape);
1940 if (!outputShape)
1941 return failure();
1942 return *outputShape;
1943}
1944
1945SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
1946 return getMixedValues(staticValues: getStaticOutputShape(), dynamicValues: getOutputShape(), context: getContext());
1947}
1948
1949void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1950 Type resultType, Value src,
1951 ArrayRef<ReassociationIndices> reassociation,
1952 ArrayRef<OpFoldResult> outputShape) {
1953 auto [staticOutputShape, dynamicOutputShape] =
1954 decomposeMixedValues(mixedValues: SmallVector<OpFoldResult>(outputShape));
1955 build(odsBuilder&: builder, odsState&: result, result: cast<RankedTensorType>(Val&: resultType), src,
1956 reassociation: getReassociationIndicesAttribute(b&: builder, reassociation),
1957 output_shape: dynamicOutputShape, static_output_shape: staticOutputShape);
1958}
1959
1960void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
1961 Type resultType, Value src,
1962 ArrayRef<ReassociationIndices> reassociation) {
1963 SmallVector<OpFoldResult> inputShape =
1964 getMixedSizes(builder, loc: result.location, value: src);
1965 auto tensorResultTy = cast<RankedTensorType>(Val&: resultType);
1966 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
1967 b&: builder, loc: result.location, expandedType: tensorResultTy, reassociation, inputShape);
1968 SmallVector<OpFoldResult> outputShapeOrEmpty;
1969 if (succeeded(Result: outputShape)) {
1970 outputShapeOrEmpty = *outputShape;
1971 }
1972 build(builder, result, resultType: tensorResultTy, src, reassociation,
1973 outputShape: outputShapeOrEmpty);
1974}
1975
1976SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1977 return getSymbolLessAffineMaps(reassociation: getReassociationExprs());
1978}
1979SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1980 return convertReassociationIndicesToExprs(context: getContext(),
1981 reassociationIndices: getReassociationIndices());
1982}
1983
1984SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1985 return getSymbolLessAffineMaps(reassociation: getReassociationExprs());
1986}
1987SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1988 return convertReassociationIndicesToExprs(context: getContext(),
1989 reassociationIndices: getReassociationIndices());
1990}
1991
1992RankedTensorType CollapseShapeOp::inferCollapsedType(
1993 RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
1994 return inferCollapsedType(
1995 type, reassociation: getSymbolLessAffineMaps(reassociation: convertReassociationIndicesToExprs(
1996 context: type.getContext(), reassociationIndices: reassociation)));
1997}
1998
1999/// Compute the RankedTensorType obtained by applying `reassociation` to
2000/// `type`.
2001RankedTensorType
2002CollapseShapeOp::inferCollapsedType(RankedTensorType type,
2003 ArrayRef<AffineMap> reassociation) {
2004 auto shape = type.getShape();
2005 SmallVector<int64_t, 4> newShape;
2006 newShape.reserve(N: reassociation.size());
2007
2008 // Use the fact that reassociation is valid to simplify the logic: only use
2009 // each map's rank.
2010 assert(isReassociationValid(reassociation) && "invalid reassociation");
2011 unsigned currentDim = 0;
2012 for (AffineMap m : reassociation) {
2013 unsigned dim = m.getNumResults();
2014 auto band = shape.slice(N: currentDim, M: dim);
2015 int64_t size = 1;
2016 if (llvm::is_contained(Range&: band, Element: ShapedType::kDynamic))
2017 size = ShapedType::kDynamic;
2018 else
2019 for (unsigned d = 0; d < dim; ++d)
2020 size *= shape[currentDim + d];
2021 newShape.push_back(Elt: size);
2022 currentDim += dim;
2023 }
2024
2025 return RankedTensorType::get(shape: newShape, elementType: type.getElementType());
2026}
2027
2028void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2029 ArrayRef<ReassociationIndices> reassociation,
2030 ArrayRef<NamedAttribute> attrs) {
2031 auto resultType = inferCollapsedType(
2032 type: llvm::cast<RankedTensorType>(Val: src.getType()),
2033 reassociation: getSymbolLessAffineMaps(
2034 reassociation: convertReassociationIndicesToExprs(context: b.getContext(), reassociationIndices: reassociation)));
2035 result.addAttribute(name: getReassociationAttrStrName(),
2036 attr: getReassociationIndicesAttribute(b, reassociation));
2037 build(b, odsState&: result, resultTypes: resultType, operands: src, attributes: attrs);
2038}
2039
2040template <typename TensorReshapeOp, bool isExpansion = std::is_same<
2041 TensorReshapeOp, ExpandShapeOp>::value>
2042static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
2043 RankedTensorType expandedType,
2044 RankedTensorType collapsedType) {
2045 if (failed(
2046 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
2047 return failure();
2048
2049 auto maps = op.getReassociationMaps();
2050 RankedTensorType expectedType =
2051 CollapseShapeOp::inferCollapsedType(expandedType, maps);
2052 if (!isSameTypeWithoutEncoding(tp1: collapsedType, tp2: expectedType))
2053 return op.emitOpError("expected collapsed type to be ")
2054 << expectedType << ", but got " << collapsedType;
2055 return success();
2056}
2057
2058LogicalResult ExpandShapeOp::verify() {
2059 auto srcType = getSrcType();
2060 auto resultType = getResultType();
2061
2062 if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2063 return emitOpError(message: "expected number of static shape dims to be equal to "
2064 "the output rank (")
2065 << resultType.getRank() << ") but found "
2066 << getStaticOutputShape().size() << " inputs instead";
2067
2068 if ((int64_t)getOutputShape().size() !=
2069 llvm::count(Range: getStaticOutputShape(), Element: ShapedType::kDynamic))
2070 return emitOpError(message: "mismatch in dynamic dims in output_shape and "
2071 "static_output_shape: static_output_shape has ")
2072 << llvm::count(Range: getStaticOutputShape(), Element: ShapedType::kDynamic)
2073 << " dynamic dims while output_shape has " << getOutputShape().size()
2074 << " values";
2075
2076 return verifyTensorReshapeOp(op: *this, expandedType: resultType, collapsedType: srcType);
2077}
2078
/// Verify collapse_shape by reusing the shared reshape verifier, with the
/// source type playing the "expanded" role.
LogicalResult CollapseShapeOp::verify() {
  return verifyTensorReshapeOp(op: *this, expandedType: getSrcType(), collapsedType: getResultType());
}
2082
2083namespace {
2084/// Reshape of a splat constant can be replaced with a constant of the result
2085/// type.
2086template <typename TensorReshapeOp>
2087struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
2088 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2089 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2090 PatternRewriter &rewriter) const override {
2091 DenseElementsAttr attr;
2092 if (!matchPattern(reshapeOp.getSrc(), m_Constant(bind_value: &attr)))
2093 return failure();
2094 if (!attr || !attr.isSplat())
2095 return failure();
2096 DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
2097 type: reshapeOp.getResultType(), rawBuffer: attr.getRawData());
2098 rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
2099 return success();
2100 }
2101};
2102
2103// Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
2104template <typename TensorReshapeOp>
2105class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
2106public:
2107 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2108
2109 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2110 PatternRewriter &rewriter) const override {
2111 auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
2112 if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
2113 return failure();
2114
2115 rewriter.replaceOpWithNewOp<tensor::SplatOp>(
2116 reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
2117 return success();
2118 }
2119};
2120
2121/// Reshape of a FromElements can be replaced with a FromElements of the
2122/// result type
2123template <typename TensorReshapeOp>
2124struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
2125 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2126 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2127 PatternRewriter &rewriter) const override {
2128 auto fromElements =
2129 reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
2130 if (!fromElements)
2131 return failure();
2132
2133 auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
2134
2135 if (!shapedTy.hasStaticShape())
2136 return failure();
2137
2138 rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
2139 fromElements.getElements());
2140 return success();
2141 }
2142};
2143
// Fold CastOp into CollapseShapeOp when adding static information.
struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
                                PatternRewriter &rewriter) const override {
    // Only fire when the producer cast can be absorbed by a consumer (see
    // tensor::canFoldIntoConsumerOp for the exact condition).
    auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
    if (!tensor::canFoldIntoConsumerOp(castOp))
      return failure();

    // Re-infer the collapsed type from the cast's source type.
    RankedTensorType srcType =
        llvm::cast<RankedTensorType>(Val: castOp.getSource().getType());
    RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
        type: srcType, reassociation: collapseShapeOp.getReassociationMaps());

    if (newResultType == collapseShapeOp.getResultType()) {
      // Same result type: simply bypass the cast in place.
      rewriter.modifyOpInPlace(root: collapseShapeOp, callable: [&]() {
        collapseShapeOp.getSrcMutable().assign(value: castOp.getSource());
      });
    } else {
      // More static result type: create a new collapse and cast back to the
      // original result type for existing users.
      auto newOp = rewriter.create<CollapseShapeOp>(
          location: collapseShapeOp.getLoc(), args&: newResultType, args: castOp.getSource(),
          args: collapseShapeOp.getReassociation());
      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          op: collapseShapeOp, args: collapseShapeOp.getResultType(), args&: newOp);
    }
    return success();
  }
};
2173
/// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
/// matching constant output_shape operands of the expand. This makes the
/// `tensor.expand_shape` more static and creates a consumer cast that can be
/// propagated further.
struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
    SmallVector<ReassociationIndices, 4> reassoc =
        expandOp.getReassociationIndices();

    // Start from the current result shape; try to replace dynamic entries
    // with constants matched from the output_shape operands.
    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
    SmallVector<Value> dynamicOutputShape;
    // Walks the dynamic output_shape operands in tandem with the dynamic
    // entries of the result shape (one operand per dynamic entry).
    auto outputIt = expandOp.getOutputShape().begin();

    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(First&: reassoc)) {
      for (uint64_t outDim : innerReassoc) {
        if (ShapedType::isStatic(dValue: newOutputShape[outDim]))
          continue;

        // If the cast's src type is dynamic, don't infer any of the
        // corresponding expanded dimensions. `tensor.expand_shape` requires at
        // least one of the expanded dimensions to be dynamic if the input is
        // dynamic.
        Value val = *outputIt;
        ++outputIt;
        if (ShapedType::isDynamic(dValue: castSrcShape[inputDim])) {
          dynamicOutputShape.push_back(Elt: val);
          continue;
        }

        // Promote constant output_shape operands to static dims.
        APInt cst;
        if (matchPattern(value: val, pattern: m_ConstantInt(bind_value: &cst))) {
          newOutputShape[outDim] = cst.getSExtValue();
        } else {
          dynamicOutputShape.push_back(Elt: val);
        }
      }
    }

    // Couldn't match any values, nothing to change
    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
      return failure();

    // Calculate the input shape from the output
    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
    for (auto inDim : llvm::seq<int>(Begin: 0, End: newInputShape.size())) {
      for (auto outDim : reassoc[inDim]) {
        auto ofr = newOutputShape[outDim];
        // Any dynamic group member makes the whole collapsed input dim
        // dynamic; otherwise the input dim is the product of its group.
        if (ShapedType::isDynamic(dValue: ofr)) {
          newInputShape[inDim] = ShapedType::kDynamic;
          break;
        }
        newInputShape[inDim] *= ofr;
      }
    }

    // Rebuild as cast -> (more static) expand_shape -> cast-to-original-type.
    SmallVector<OpFoldResult> outputOfr =
        getMixedValues(staticValues: newOutputShape, dynamicValues: dynamicOutputShape, b&: rewriter);
    auto inputType = RankedTensorType::get(
        shape: newInputShape, elementType: expandOp.getSrcType().getElementType());
    auto outputType = RankedTensorType::get(
        shape: newOutputShape, elementType: expandOp.getSrcType().getElementType());
    auto inputCast = rewriter.create<CastOp>(location: expandOp.getLoc(), args&: inputType,
                                             args: expandOp.getSrc());
    auto newExpand = rewriter.create<ExpandShapeOp>(
        location: expandOp.getLoc(), args&: outputType, args: inputCast.getResult(),
        args: expandOp.getReassociationIndices(), args&: outputOfr);
    rewriter.replaceOpWithNewOp<CastOp>(op: expandOp, args: expandOp.getType(),
                                        args: newExpand.getResult());
    return success();
  }
};
2253} // namespace
2254
/// Register expand_shape canonicalizations: composing with adjacent reshape
/// ops, sinking producer casts, and folding constant/splat/from_elements
/// sources.
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
      ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
      FoldReshapeWithSplat<ExpandShapeOp>,
      FoldReshapeWithFromElements<ExpandShapeOp>>(arg&: context);
}
2264
/// Register collapse_shape canonicalizations: composing with adjacent reshape
/// ops, folding producer casts, and folding constant/splat/from_elements
/// sources.
void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
      ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
                                tensor::DimOp, RankedTensorType>,
      FoldReshapeWithConstant<CollapseShapeOp>,
      FoldReshapeWithSplat<CollapseShapeOp>,
      FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
      arg&: context);
}
2276
/// Fold expand_shape(collapse_shape) round-trips via the shared helper.
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(reshapeOp: *this,
                                                       operands: adaptor.getOperands());
}
2281
/// Fold collapse_shape(expand_shape) round-trips via the shared helper.
OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(reshapeOp: *this,
                                                       operands: adaptor.getOperands());
}
2286
2287//===----------------------------------------------------------------------===//
2288// ExtractSliceOp
2289//===----------------------------------------------------------------------===//
2290
2291void ExtractSliceOp::getAsmResultNames(
2292 function_ref<void(Value, StringRef)> setNameFn) {
2293 setNameFn(getResult(), "extracted_slice");
2294}
2295
2296/// An extract_slice result type can be inferred, when it is not
2297/// rank-reduced, from the source type and the static representation of
2298/// offsets, sizes and strides. Special sentinels encode the dynamic case.
2299RankedTensorType ExtractSliceOp::inferResultType(
2300 RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
2301 ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
2302 // An extract_slice op may specify only a leading subset of offset/sizes/
2303 // strides in which case we complete with offset=0, sizes from memref type
2304 // and strides=1.
2305 assert(static_cast<int64_t>(staticSizes.size()) ==
2306 sourceTensorType.getRank() &&
2307 "unexpected staticSizes not equal to rank of source");
2308 return RankedTensorType::get(shape: staticSizes, elementType: sourceTensorType.getElementType(),
2309 encoding: sourceTensorType.getEncoding());
2310}
2311
2312RankedTensorType ExtractSliceOp::inferResultType(
2313 RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
2314 ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
2315 SmallVector<int64_t> staticSizes;
2316 std::tie(args&: staticSizes, args: std::ignore) = decomposeMixedValues(mixedValues: sizes);
2317 assert(static_cast<int64_t>(staticSizes.size()) ==
2318 sourceTensorType.getRank() &&
2319 "unexpected staticSizes not equal to rank of source");
2320 return RankedTensorType::get(shape: staticSizes, elementType: sourceTensorType.getElementType(),
2321 encoding: sourceTensorType.getEncoding());
2322}
2323
2324/// If the rank is reduced (i.e. the desiredResultRank is smaller than the
2325/// number of sizes), drop as many size 1 as needed to produce an inferred
2326/// type with the desired rank.
2327///
2328/// Note that there may be multiple ways to compute this rank-reduced type:
2329/// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
2330///
2331/// To disambiguate, this function always drops the first 1 sizes occurrences.
2332RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2333 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2334 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2335 ArrayRef<int64_t> strides) {
2336 // Type inferred in the absence of rank-reducing behavior.
2337 auto inferredType = llvm::cast<RankedTensorType>(
2338 Val: inferResultType(sourceTensorType: sourceRankedTensorType, staticOffsets: offsets, staticSizes: sizes, staticStrides: strides));
2339 int rankDiff = inferredType.getRank() - desiredResultRank;
2340 if (rankDiff > 0) {
2341 auto shape = inferredType.getShape();
2342 llvm::SmallBitVector dimsToProject =
2343 getPositionsOfShapeOne(rank: rankDiff, shape);
2344 SmallVector<int64_t> projectedShape;
2345 // Best effort rank-reducing: drop 1s in order.
2346 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2347 if (!dimsToProject.test(Idx: pos))
2348 projectedShape.push_back(Elt: shape[pos]);
2349 inferredType =
2350 RankedTensorType::get(shape: projectedShape, elementType: inferredType.getElementType());
2351 }
2352 return inferredType;
2353}
2354
2355RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2356 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2357 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2358 ArrayRef<OpFoldResult> strides) {
2359 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2360 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2361 dispatchIndexOpFoldResults(ofrs: offsets, dynamicVec&: dynamicOffsets, staticVec&: staticOffsets);
2362 dispatchIndexOpFoldResults(ofrs: sizes, dynamicVec&: dynamicSizes, staticVec&: staticSizes);
2363 dispatchIndexOpFoldResults(ofrs: strides, dynamicVec&: dynamicStrides, staticVec&: staticStrides);
2364 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2365 desiredResultRank, sourceRankedTensorType, offsets: staticOffsets, sizes: staticSizes,
2366 strides: staticStrides);
2367}
2368
/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
/// result type. If the type passed is nullptr, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  // Split each mixed list into a static array (with kDynamic placeholders)
  // and the dynamic SSA values.
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(ofrs: offsets, dynamicVec&: dynamicOffsets, staticVec&: staticOffsets);
  dispatchIndexOpFoldResults(ofrs: sizes, dynamicVec&: dynamicSizes, staticVec&: staticSizes);
  dispatchIndexOpFoldResults(ofrs: strides, dynamicVec&: dynamicStrides, staticVec&: staticStrides);
  auto sourceRankedTensorType = llvm::cast<RankedTensorType>(Val: source.getType());
  // Structuring implementation this way avoids duplication between builders.
  if (!resultType) {
    resultType = llvm::cast<RankedTensorType>(Val: ExtractSliceOp::inferResultType(
        sourceTensorType: sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
  }
  result.addAttributes(newAttributes: attrs);
  build(odsBuilder&: b, odsState&: result, result: resultType, source, offsets: dynamicOffsets, sizes: dynamicSizes,
        strides: dynamicStrides, static_offsets: b.getDenseI64ArrayAttr(values: staticOffsets),
        static_sizes: b.getDenseI64ArrayAttr(values: staticSizes),
        static_strides: b.getDenseI64ArrayAttr(values: staticStrides));
}
2394
2395/// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
2396/// result type.
2397void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2398 ArrayRef<OpFoldResult> offsets,
2399 ArrayRef<OpFoldResult> sizes,
2400 ArrayRef<OpFoldResult> strides,
2401 ArrayRef<NamedAttribute> attrs) {
2402 build(b, result, resultType: RankedTensorType(), source, offsets, sizes, strides, attrs);
2403}
2404
2405/// Build an ExtractSliceOp with mixed static and dynamic entries packed into
2406/// a Range vector.
2407void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2408 ArrayRef<Range> ranges,
2409 ArrayRef<NamedAttribute> attrs) {
2410 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2411 build(b, result, resultType: RankedTensorType(), source, offsets, sizes, strides, attrs);
2412}
2413
2414/// Build an ExtractSliceOp with dynamic entries and custom result type. If
2415/// the type passed is nullptr, it is inferred.
2416void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
2417 RankedTensorType resultType, Value source,
2418 ValueRange offsets, ValueRange sizes,
2419 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2420 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2421 Range: llvm::map_range(C&: offsets, F: [](Value v) -> OpFoldResult { return v; }));
2422 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2423 Range: llvm::map_range(C&: sizes, F: [](Value v) -> OpFoldResult { return v; }));
2424 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2425 Range: llvm::map_range(C&: strides, F: [](Value v) -> OpFoldResult { return v; }));
2426 build(b, result, resultType, source, offsets: offsetValues, sizes: sizeValues, strides: strideValues);
2427}
2428
2429/// Build an ExtractSliceOp with dynamic entries and inferred result type.
2430void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2431 ValueRange offsets, ValueRange sizes,
2432 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2433 build(b, result, resultType: RankedTensorType(), source, offsets, sizes, strides, attrs);
2434}
2435
2436static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
2437 Operation *op,
2438 RankedTensorType expectedType) {
2439 switch (result) {
2440 case SliceVerificationResult::Success:
2441 return success();
2442 case SliceVerificationResult::RankTooLarge:
2443 return op->emitError(message: "expected rank to be smaller or equal to ")
2444 << "the other rank. ";
2445 case SliceVerificationResult::SizeMismatch:
2446 return op->emitError(message: "expected type to be ")
2447 << expectedType << " or a rank-reduced version. (size mismatch) ";
2448 case SliceVerificationResult::ElemTypeMismatch:
2449 return op->emitError(message: "expected element type to be ")
2450 << expectedType.getElementType();
2451 default:
2452 llvm_unreachable("unexpected extract_slice op verification result");
2453 }
2454}
2455
2456/// Verifier for ExtractSliceOp.
2457LogicalResult ExtractSliceOp::verify() {
2458 RankedTensorType sourceType = getSourceType();
2459
2460 // Verify result type against inferred type.
2461 RankedTensorType expectedType = ExtractSliceOp::inferResultType(
2462 sourceTensorType: sourceType, offsets: getMixedOffsets(), sizes: getMixedSizes(), strides: getMixedStrides());
2463 SliceVerificationResult result = isRankReducedType(originalType: expectedType, candidateReducedType: getType());
2464 if (result != SliceVerificationResult::Success)
2465 return produceSliceErrorMsg(result, op: *this, expectedType);
2466
2467 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2468 // to the source tensor.
2469 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2470 shape: sourceType.getShape(), staticOffsets: getStaticOffsets(), staticSizes: getStaticSizes(),
2471 staticStrides: getStaticStrides(), /*generateErrorMessage=*/true);
2472 if (!boundsResult.isValid)
2473 return getOperation()->emitError(message: boundsResult.errorMessage);
2474
2475 return success();
2476}
2477
2478llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
2479 return ::getDroppedDims(reducedShape: getType().getShape(), mixedSizes: getMixedSizes());
2480}
2481
2482FailureOr<Value>
2483ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
2484 ArrayRef<int64_t> desiredShape) {
2485 auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(Val: value.getType());
2486 assert(sourceTensorType && "not a ranked tensor type");
2487 auto sourceShape = sourceTensorType.getShape();
2488 if (sourceShape.equals(RHS: desiredShape))
2489 return value;
2490 auto maybeRankReductionMask =
2491 mlir::computeRankReductionMask(originalShape: sourceShape, reducedShape: desiredShape);
2492 if (!maybeRankReductionMask)
2493 return failure();
2494 return createCanonicalRankReducingExtractSliceOp(
2495 b, loc, tensor: value,
2496 targetType: RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2497}
2498
2499LogicalResult ExtractSliceOp::reifyResultShapes(
2500 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2501 reifiedReturnShapes.resize(N: 1);
2502 reifiedReturnShapes[0].reserve(N: getType().getRank());
2503 SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2504 llvm::SmallBitVector droppedDims = getDroppedDims();
2505 for (const auto &size : enumerate(First&: mixedSizes)) {
2506 if (droppedDims.test(Idx: size.index()))
2507 continue;
2508 reifiedReturnShapes[0].push_back(Elt: size.value());
2509 }
2510 return success();
2511}
2512
2513namespace {
/// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
/// This essentially pushes memref_cast past its consuming slice when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
/// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
/// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
/// tensor<3x4xf32>
/// ```
/// is rewritten into:
/// ```
/// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
/// tensor<3x4xf32> %1 = tensor.cast %0: tensor<3x4xf32> to tensor<3x4xf32>
/// ```
class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
public:
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let the constant folder kick in.
    if (llvm::any_of(Range: sliceOp.getOperands(), P: [](Value operand) {
          return matchPattern(value: operand, pattern: matchConstantIndex());
        }))
      return failure();

    // The slice source must be produced by a tensor.cast...
    auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    // ...and folding that cast into its consumer must be legal (i.e. not
    // lose static type information the consumer relies on).
    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    // Pattern does not apply if the produced op would not verify.
    // Check bounds against the cast's *source* shape, which is what the new
    // extract_slice will read from.
    SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
        shape: cast<RankedTensorType>(Val: castOp.getSource().getType()).getShape(),
        staticOffsets: sliceOp.getStaticOffsets(), staticSizes: sliceOp.getStaticSizes(),
        staticStrides: sliceOp.getStaticStrides());
    if (!sliceResult.isValid)
      return failure();

    // Create folded extract.
    // Same slice parameters, but reading directly from the cast's source;
    // the original result type is kept so all users remain valid.
    Location loc = sliceOp.getLoc();
    Value newResult = rewriter.create<ExtractSliceOp>(
        location: loc, args: sliceOp.getType(), args: castOp.getSource(), args: sliceOp.getOffsets(),
        args: sliceOp.getSizes(), args: sliceOp.getStrides(), args: sliceOp.getStaticOffsets(),
        args: sliceOp.getStaticSizes(), args: sliceOp.getStaticStrides());
    rewriter.replaceOp(op: sliceOp, newValues: newResult);
    return success();
  }
};
2566
/// Slice elements from `values` into `outValues`. `counts` represents the
/// numbers of elements to stride in the original values for each dimension.
/// The output values can be used to construct a DenseElementsAttr.
///
/// Works recursively, one dimension at a time: each recursion level consumes
/// the front entry of `counts`/`offsets`/`sizes`/`strides`.
template <typename IterTy, typename ElemTy>
static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
                          ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                          ArrayRef<int64_t> strides,
                          llvm::SmallVectorImpl<ElemTy> *outValues) {
  assert(offsets.size() == sizes.size());
  assert(offsets.size() == strides.size());
  if (offsets.empty())
    return;

  int64_t offset = offsets.front();
  int64_t size = sizes.front();
  int64_t stride = strides.front();
  // Base case: innermost dimension — copy the selected elements directly.
  if (offsets.size() == 1) {
    for (int64_t i = 0; i < size; ++i, offset += stride)
      outValues->push_back(*(values + offset));

    return;
  }

  // Recursive case: advance by whole sub-blocks (`counts.front()` elements
  // per step along this dimension) and slice the remaining dimensions.
  for (int64_t i = 0; i < size; ++i, offset += stride) {
    auto begin = values + offset * counts.front();
    sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
                                  offsets.drop_front(), sizes.drop_front(),
                                  strides.drop_front(), outValues);
  }
}
2597
2598/// Fold arith.constant and tensor.extract_slice into arith.constant. The
2599/// folded operation might introduce more constant data; Users can control
2600/// their heuristics by the control function.
2601class ConstantOpExtractSliceFolder final
2602 : public OpRewritePattern<ExtractSliceOp> {
2603public:
2604 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2605
2606 ConstantOpExtractSliceFolder(MLIRContext *context,
2607 ControlConstantExtractSliceFusionFn controlFn)
2608 : OpRewritePattern<ExtractSliceOp>(context),
2609 controlFn(std::move(controlFn)) {}
2610
2611 LogicalResult matchAndRewrite(ExtractSliceOp op,
2612 PatternRewriter &rewriter) const override {
2613 DenseElementsAttr attr;
2614 if (!matchPattern(value: op.getSource(), pattern: m_Constant(bind_value: &attr)))
2615 return failure();
2616
2617 // A constant splat is handled by fold().
2618 if (attr.isSplat())
2619 return failure();
2620
2621 // Dynamic result shape is not supported.
2622 auto sourceType = llvm::cast<ShapedType>(Val: op.getSource().getType());
2623 auto resultType = llvm::cast<ShapedType>(Val: op.getResult().getType());
2624 if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2625 return failure();
2626
2627 // Customized control over the folding.
2628 if (!controlFn(op))
2629 return failure();
2630
2631 int64_t count = sourceType.getNumElements();
2632 if (count == 0)
2633 return failure();
2634
2635 // Check if there are any dynamic parts, which are not supported.
2636 auto offsets = op.getStaticOffsets();
2637 if (llvm::is_contained(Range&: offsets, Element: ShapedType::kDynamic))
2638 return failure();
2639 auto sizes = op.getStaticSizes();
2640 if (llvm::is_contained(Range&: sizes, Element: ShapedType::kDynamic))
2641 return failure();
2642 auto strides = op.getStaticStrides();
2643 if (llvm::is_contained(Range&: strides, Element: ShapedType::kDynamic))
2644 return failure();
2645
2646 // Compute the stride for each dimension.
2647 SmallVector<int64_t> counts;
2648 ArrayRef<int64_t> shape = sourceType.getShape();
2649 counts.reserve(N: shape.size());
2650 for (int64_t v : shape) {
2651 count = count / v;
2652 counts.push_back(Elt: count);
2653 }
2654
2655 // New attribute constructed by the sliced values.
2656 DenseElementsAttr newAttr;
2657
2658 if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(Val&: attr)) {
2659 SmallVector<APInt> outValues;
2660 outValues.reserve(N: sourceType.getNumElements());
2661 sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
2662 values: elems.begin(), counts, offsets, sizes, strides, outValues: &outValues);
2663 newAttr = DenseElementsAttr::get(type: resultType, values: outValues);
2664 } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(Val&: attr)) {
2665 SmallVector<APFloat> outValues;
2666 outValues.reserve(N: sourceType.getNumElements());
2667 sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
2668 values: elems.begin(), counts, offsets, sizes, strides, outValues: &outValues);
2669 newAttr = DenseElementsAttr::get(type: resultType, values: outValues);
2670 }
2671
2672 if (newAttr) {
2673 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, args&: resultType, args&: newAttr);
2674 return success();
2675 }
2676
2677 return failure();
2678 }
2679
2680private:
2681 /// This additionally controls whether the fold happens or not. Users can
2682 /// impose their heuristics in the function.
2683 ControlConstantExtractSliceFusionFn controlFn;
2684};
2685
2686} // namespace
2687
2688void mlir::tensor::populateFoldConstantExtractSlicePatterns(
2689 RewritePatternSet &patterns,
2690 const ControlConstantExtractSliceFusionFn &controlFn) {
2691 patterns.add<ConstantOpExtractSliceFolder>(arg: patterns.getContext(), args: controlFn);
2692}
2693
2694/// Return the canonical type of the result of an extract_slice op.
2695struct SliceReturnTypeCanonicalizer {
2696 RankedTensorType operator()(ExtractSliceOp op,
2697 ArrayRef<OpFoldResult> mixedOffsets,
2698 ArrayRef<OpFoldResult> mixedSizes,
2699 ArrayRef<OpFoldResult> mixedStrides) {
2700 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2701 desiredResultRank: op.getType().getRank(), sourceRankedTensorType: op.getSourceType(), offsets: mixedOffsets, sizes: mixedSizes,
2702 strides: mixedStrides);
2703 }
2704};
2705
2706/// A canonicalizer wrapper to replace ExtractSliceOps.
2707struct SliceCanonicalizer {
2708 void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
2709 ExtractSliceOp newOp) {
2710 Value replacement = newOp.getResult();
2711 if (replacement.getType() != op.getType())
2712 replacement = rewriter.create<tensor::CastOp>(location: op.getLoc(), args: op.getType(),
2713 args&: replacement);
2714 rewriter.replaceOp(op, newValues: replacement);
2715 }
2716};
2717
2718void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2719 MLIRContext *context) {
2720 results.add<
2721 OpWithOffsetSizesAndStridesConstantArgumentFolder<
2722 ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2723 ExtractSliceOpCastFolder>(arg&: context);
2724}
2725
2726//
2727static LogicalResult
2728foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
2729 ShapedType shapedType) {
2730 OpBuilder b(op.getContext());
2731 for (OpFoldResult ofr : op.getMixedOffsets())
2732 if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
2733 return failure();
2734 // Rank-reducing noops only need to inspect the leading dimensions:
2735 // llvm::zip is appropriate.
2736 auto shape = shapedType.getShape();
2737 for (auto it : llvm::zip(t: op.getMixedSizes(), u&: shape))
2738 if (getConstantIntValue(ofr: std::get<0>(t&: it)) != std::get<1>(t&: it))
2739 return failure();
2740 for (OpFoldResult ofr : op.getMixedStrides())
2741 if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
2742 return failure();
2743 return success();
2744}
2745
2746/// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
2747/// slice, we can return the InsertSliceOp's source directly.
2748// TODO: This only checks the immediate producer; extend to go up the
2749// insert/extract chain if the slices are disjoint.
2750static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
2751 auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2752
2753 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2754 if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2755 insertOp.isSameAs(other: extractOp, cmp: isSame))
2756 return insertOp.getSource();
2757
2758 return {};
2759}
2760
2761OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2762 if (OpFoldResult reshapedSource = reshapeConstantSource(
2763 source: llvm::dyn_cast_if_present<SplatElementsAttr>(Val: adaptor.getSource()),
2764 result: getResult().getType()))
2765 return reshapedSource;
2766 if (getSourceType() == getType() &&
2767 succeeded(Result: foldIdentityOffsetSizeAndStrideOpInterface(op: *this, shapedType: getType())))
2768 return this->getSource();
2769 if (Value slice = foldExtractAfterInsertSlice(extractOp: *this))
2770 return slice;
2771
2772 return OpFoldResult();
2773}
2774
2775Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
2776 OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
2777 auto rankedTensorType = llvm::cast<RankedTensorType>(Val: tensor.getType());
2778 unsigned rank = rankedTensorType.getRank();
2779 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(value: 0));
2780 SmallVector<OpFoldResult> sizes = getMixedSizes(builder&: b, loc, value: tensor);
2781 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(value: 1));
2782 return b.createOrFold<tensor::ExtractSliceOp>(location: loc, args&: targetType, args&: tensor,
2783 args&: offsets, args&: sizes, args&: strides);
2784}
2785
2786//===----------------------------------------------------------------------===//
2787// InsertSliceOp
2788//===----------------------------------------------------------------------===//
2789
// Assign a human-readable default SSA name ("inserted_slice") to the result
// when the IR is printed.
void InsertSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted_slice");
}
2794
2795// Build a InsertSliceOp with mixed static and dynamic entries.
2796void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2797 Value dest, ArrayRef<OpFoldResult> offsets,
2798 ArrayRef<OpFoldResult> sizes,
2799 ArrayRef<OpFoldResult> strides,
2800 ArrayRef<NamedAttribute> attrs) {
2801 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2802 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2803 dispatchIndexOpFoldResults(ofrs: offsets, dynamicVec&: dynamicOffsets, staticVec&: staticOffsets);
2804 dispatchIndexOpFoldResults(ofrs: sizes, dynamicVec&: dynamicSizes, staticVec&: staticSizes);
2805 dispatchIndexOpFoldResults(ofrs: strides, dynamicVec&: dynamicStrides, staticVec&: staticStrides);
2806 result.addAttributes(newAttributes: attrs);
2807 build(odsBuilder&: b, odsState&: result, result: dest.getType(), source, dest, offsets: dynamicOffsets, sizes: dynamicSizes,
2808 strides: dynamicStrides, static_offsets: b.getDenseI64ArrayAttr(values: staticOffsets),
2809 static_sizes: b.getDenseI64ArrayAttr(values: staticSizes),
2810 static_strides: b.getDenseI64ArrayAttr(values: staticStrides));
2811}
2812
2813/// Build an InsertSliceOp with mixed static and dynamic entries packed into a
2814/// Range vector.
2815void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2816 Value dest, ArrayRef<Range> ranges,
2817 ArrayRef<NamedAttribute> attrs) {
2818 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
2819 build(b, result, source, dest, offsets, sizes, strides, attrs);
2820}
2821
2822// Build a InsertSliceOp with dynamic entries.
2823void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
2824 Value dest, ValueRange offsets, ValueRange sizes,
2825 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2826 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2827 Range: llvm::map_range(C&: offsets, F: [](Value v) -> OpFoldResult { return v; }));
2828 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2829 Range: llvm::map_range(C&: sizes, F: [](Value v) -> OpFoldResult { return v; }));
2830 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2831 Range: llvm::map_range(C&: strides, F: [](Value v) -> OpFoldResult { return v; }));
2832 build(b, result, source, dest, offsets: offsetValues, sizes: sizeValues, strides: strideValues);
2833}
2834
2835/// Rank-reducing type verification for both InsertSliceOp and
2836/// ParallelInsertSliceOp.
2837static SliceVerificationResult verifyInsertSliceOp(
2838 RankedTensorType srcType, RankedTensorType dstType,
2839 ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
2840 ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
2841 // insert_slice is the inverse of extract_slice, use the same type
2842 // inference.
2843 RankedTensorType expected = ExtractSliceOp::inferResultType(
2844 sourceTensorType: dstType, staticOffsets, staticSizes, staticStrides);
2845 if (expectedType)
2846 *expectedType = expected;
2847 return isRankReducedType(originalType: expected, candidateReducedType: srcType);
2848}
2849
2850/// Verifier for InsertSliceOp.
2851LogicalResult InsertSliceOp::verify() {
2852 // Verify result type against inferred type.
2853 RankedTensorType expectedType;
2854 SliceVerificationResult result =
2855 verifyInsertSliceOp(srcType: getSourceType(), dstType: getType(), staticOffsets: getStaticOffsets(),
2856 staticSizes: getStaticSizes(), staticStrides: getStaticStrides(), expectedType: &expectedType);
2857 if (result != SliceVerificationResult::Success)
2858 return produceSliceErrorMsg(result, op: *this, expectedType);
2859
2860 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2861 // to the destination tensor.
2862 SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2863 shape: getDestType().getShape(), staticOffsets: getStaticOffsets(), staticSizes: getStaticSizes(),
2864 staticStrides: getStaticStrides(), /*generateErrorMessage=*/true);
2865 if (!boundsResult.isValid)
2866 return getOperation()->emitError(message: boundsResult.errorMessage);
2867
2868 return success();
2869}
2870
2871/// If we have two consecutive InsertSliceOp writing to the same slice, we
2872/// can mutate the second InsertSliceOp's destination to the first one's.
2873///
2874/// Example:
2875///
2876/// ```mlir
2877/// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
2878/// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
2879/// ```
2880///
2881/// folds into:
2882///
2883/// ```mlir
2884/// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
2885/// ```
2886///
2887/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2888static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
2889 auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2890
2891 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2892 if (!prevInsertOp ||
2893 prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2894 !prevInsertOp.isSameAs(other: insertOp, cmp: isSame))
2895 return failure();
2896
2897 insertOp.getDestMutable().assign(value: prevInsertOp.getDest());
2898 return success();
2899}
2900
2901/// Folds round-trip extract/insert slice op pairs.
2902/// Example:
2903/// ```mlir
2904/// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2905/// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
2906/// ```
2907/// can be folded into %val.
2908static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
2909 auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
2910
2911 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
2912 if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2913 !extractOp.isSameAs(other: insertOp, cmp: isSame))
2914 return nullptr;
2915
2916 return extractOp.getSource();
2917}
2918
2919OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2920 if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2921 getSourceType() == getType() &&
2922 succeeded(Result: foldIdentityOffsetSizeAndStrideOpInterface(op: *this, shapedType: getType())))
2923 return this->getSource();
2924 if (succeeded(Result: foldInsertAfterInsertSlice(insertOp: *this)))
2925 return getResult();
2926 if (auto result = foldInsertAfterExtractSlice(insertOp: *this))
2927 return result;
2928 if (llvm::any_of(Range: getMixedSizes(), P: isZeroInteger))
2929 return getDest();
2930 return OpFoldResult();
2931}
2932
2933LogicalResult InsertSliceOp::reifyResultShapes(
2934 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2935 reifiedReturnShapes.resize(N: 1, NV: SmallVector<OpFoldResult>(getType().getRank()));
2936 reifiedReturnShapes[0] = tensor::getMixedSizes(builder, loc: getLoc(), value: getDest());
2937 return success();
2938}
2939
2940namespace {
/// Pattern to rewrite a insert_slice op with constant arguments.
///
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
class InsertSliceOpConstantArgumentFolder final
    : public OpRewritePattern<InsertOpTy> {
public:
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    // Fold constant-valued dynamic operands into static entries, in place.
    SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());

    // No constant operands were folded, just return;
    if (failed(Result: foldDynamicOffsetSizeList(offsetsOrSizes&: mixedOffsets)) &&
        failed(Result: foldDynamicOffsetSizeList(offsetsOrSizes&: mixedSizes)) &&
        failed(Result: foldDynamicStrideList(strides&: mixedStrides)))
      return failure();

    // Pattern does not apply if the produced op would not verify.
    SliceBoundsVerificationResult sliceResult =
        verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
                            mixedOffsets, mixedSizes, mixedStrides);
    if (!sliceResult.isValid)
      return failure();

    // Create the new op in canonical form.
    // If folding made the inferred source type more static than the current
    // source type, a cast of the source is required.
    auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
        insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
        mixedOffsets, mixedSizes, mixedStrides);
    Value toInsert = insertSliceOp.getSource();
    if (sourceType != insertSliceOp.getSourceType()) {
      OpBuilder::InsertionGuard g(rewriter);
      // The only difference between InsertSliceOp and ParallelInsertSliceOp
      // is that the insertion point is just before the ParallelCombiningOp in
      // the parallel case.
      if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
                                                 sourceType, toInsert);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
        mixedSizes, mixedStrides);
    return success();
  }
};
2990
/// Fold tensor_casts with insert_slice operations. If the source or
/// destination tensor is a tensor_cast that removes static type information,
/// the cast is folded into the insert_slice operation. E.g.:
///
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
/// ```
///
/// Note: When folding a cast on the destination tensor, the result of the
/// insert_slice operation is casted to ensure that the type of the result did
/// not change.
///
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let the constant folder kick in.
    if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(value: operand, pattern: matchConstantIndex());
        }))
      return failure();

    // Return the cast's source when `v` is produced by a foldable
    // tensor.cast, std::nullopt otherwise.
    auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
      auto castOp = v.getDefiningOp<tensor::CastOp>();
      if (!castOp || !canFoldIntoConsumerOp(castOp))
        return std::nullopt;
      return castOp.getSource();
    };
    std::optional<Value> sourceCastSource =
        getSourceOfCastOp(insertSliceOp.getSource());
    std::optional<Value> destCastSource =
        getSourceOfCastOp(insertSliceOp.getDest());
    // At least one of the two operands must come from a foldable cast.
    if (!sourceCastSource && !destCastSource)
      return failure();

    auto src =
        (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
    auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
    auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
    auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
    if (!srcType || !dstType)
      return failure();

    // The tensor.cast source could have additional static information not seen
    // in the insert slice op static sizes, so we ignore dynamic dims when
    // computing the rank reduction mask.
    SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
    auto rankReductionMask = computeRankReductionMask(
        staticSizes, srcType.getShape(), /*matchDynamic=*/true);
    if (!rankReductionMask.has_value())
      return failure();
    // Replace dimensions in the insert slice op with corresponding static dims
    // from the cast source type. If the insert slice sizes have static dims
    // that are not static in the tensor.cast source (i.e., when the cast op
    // casts a dynamic dim to static), the dim should not be replaced, and the
    // pattern will fail later in `verifyInsertSliceOp`.
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    // `rankReducedIdx` tracks the position in the (lower-rank) source shape
    // corresponding to each non-dropped destination dimension.
    int64_t rankReducedIdx = 0;
    for (auto [idx, size] : enumerate(First&: staticSizes)) {
      if (!rankReductionMask.value().contains(idx) &&
          !srcType.isDynamicDim(rankReducedIdx)) {
        mixedSizes[idx] = getAsIndexOpFoldResult(
            rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
        size = srcType.getDimSize(rankReducedIdx++);
      }
    }

    // Pattern does not apply if the produced op would not verify.
    if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
                            staticSizes, insertSliceOp.getStaticStrides()) !=
        SliceVerificationResult::Success)
      return failure();
    SliceBoundsVerificationResult sliceResult =
        verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
                            mixedSizes, insertSliceOp.getMixedStrides());
    if (!sliceResult.isValid)
      return failure();

    // Rebuild the insert_slice on the cast-free operands.
    Operation *replacement = rewriter.create<InsertOpTy>(
        insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
        mixedSizes, insertSliceOp.getMixedStrides());

    // In the parallel case there is no result and so nothing to cast.
    bool isParallelInsert =
        std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
    if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
      replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
                                                    insertSliceOp.getDestType(),
                                                    replacement->getResult(idx: 0));
    }
    rewriter.replaceOp(insertSliceOp, replacement->getResults());
    return success();
  }
};
3094
3095/// If additional static type information can be deduced from a insert_slice's
3096/// size operands, insert an explicit cast of the op's source operand. This
3097/// enables other canonicalization patterns that are matching for tensor_cast
3098/// ops such as `ForOpTensorCastFolder` in SCF.
3099///
3100/// Example:
3101///
3102/// ```mlir
3103/// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
3104/// : tensor<?x?xf32> into ...
3105/// ```
3106///
3107/// folds into:
3108///
3109/// ```mlir
3110/// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
3111/// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
3112/// : tensor<64x64xf32> into ...
3113/// ```
3114///
3115/// This patterns works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
struct InsertSliceOpSourceCastInserter final
    : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType srcType = insertSliceOp.getSourceType();
    // Only handle non-rank-reducing inserts: each slice size then maps 1:1
    // onto a source dimension.
    if (srcType.getRank() != insertSliceOp.getDestType().getRank())
      return failure();
    // Refine every source dimension with the corresponding constant slice
    // size, when one is known.
    SmallVector<int64_t> newSrcShape(srcType.getShape());
    for (int64_t i = 0; i < srcType.getRank(); ++i) {
      if (std::optional<int64_t> constInt =
              getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
        // Bail on invalid IR.
        if (*constInt < 0)
          return failure();
        newSrcShape[i] = *constInt;
      }
    }
    if (!hasValidSizesOffsets(sizesOrOffsets: newSrcShape))
      return failure();

    RankedTensorType newSrcType = RankedTensorType::get(
        shape: newSrcShape, elementType: srcType.getElementType(), encoding: srcType.getEncoding());
    // Only rewrite when the refined type genuinely adds static information
    // and a cast between the two types is legal.
    if (srcType == newSrcType ||
        !preservesStaticInformation(source: srcType, target: newSrcType) ||
        !tensor::CastOp::areCastCompatible(inputs: srcType, outputs: newSrcType))
      return failure();

    // newSrcType is:
    //   1) Different from srcType.
    //   2) "More static" than srcType.
    //   3) Cast-compatible with srcType.
    // Insert the cast.
    OpBuilder::InsertionGuard g(rewriter);
    // The only difference between InsertSliceOp and ParallelInsertSliceOp is
    // that the insertion point is just before the ParallelCombiningOp in the
    // parallel case.
    if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
      rewriter.setInsertionPoint(insertSliceOp->getParentOp());
    Value cast = rewriter.create<tensor::CastOp>(
        insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, cast, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
3166} // namespace
3167
3168llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
3169 return ::getDroppedDims(reducedShape: getSourceType().getShape(), mixedSizes: getMixedSizes());
3170}
3171
3172void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3173 MLIRContext *context) {
3174 results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
3175 InsertSliceOpCastFolder<InsertSliceOp>,
3176 InsertSliceOpSourceCastInserter<InsertSliceOp>>(arg&: context);
3177}
3178
3179Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
3180 Location loc,
3181 Value tensor,
3182 Value dest) {
3183 auto rankedTensorType = llvm::cast<RankedTensorType>(Val: dest.getType());
3184 unsigned rank = rankedTensorType.getRank();
3185 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(value: 0));
3186 SmallVector<OpFoldResult> sizes = getMixedSizes(builder&: b, loc, value: dest);
3187 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(value: 1));
3188 return b.createOrFold<tensor::InsertSliceOp>(location: loc, args&: tensor, args&: dest, args&: offsets,
3189 args&: sizes, args&: strides);
3190}
3191
3192//===----------------------------------------------------------------------===//
3193// PadOp
3194//===----------------------------------------------------------------------===//
3195
3196void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3197 setNameFn(getResult(), "padded");
3198}
3199
// TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
// supports optional types.
// Printer half of the custom<InferType> directive. Intentionally prints
// nothing: the inferred type is reconstructed at parse time (see
// parseInferType below) from the reference type.
void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
                    Type typeToInfer, Type typeToInferFrom) {}
3204
3205ParseResult
3206parseInferType(OpAsmParser &parser,
3207 std::optional<OpAsmParser::UnresolvedOperand> optOperand,
3208 Type &typeToInfer, Type typeToInferFrom) {
3209 if (optOperand)
3210 typeToInfer = typeToInferFrom;
3211 return success();
3212}
3213
3214LogicalResult PadOp::verify() {
3215 auto sourceType = llvm::cast<RankedTensorType>(Val: getSource().getType());
3216 auto resultType = llvm::cast<RankedTensorType>(Val: getResult().getType());
3217 auto expectedType =
3218 PadOp::inferResultType(sourceType, staticLow: getStaticLow(), staticHigh: getStaticHigh());
3219 if (!expectedType) {
3220 return emitError(message: "failed to infer expectedType from sourceType ")
3221 << sourceType << ", specified resultType is " << resultType;
3222 }
3223 if (resultType.getRank() != expectedType.getRank()) {
3224 return emitError(message: "specified type ")
3225 << resultType << " does not match the inferred type "
3226 << expectedType;
3227 }
3228 for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
3229 if (resultType.getDimSize(idx: i) == expectedType.getDimSize(idx: i))
3230 continue;
3231 if (expectedType.isDynamicDim(idx: i))
3232 continue;
3233 return emitError(message: "specified type ")
3234 << resultType << " does not match the inferred type "
3235 << expectedType;
3236 }
3237
3238 return success();
3239}
3240
3241LogicalResult PadOp::verifyRegions() {
3242 auto &region = getRegion();
3243 unsigned rank = llvm::cast<RankedTensorType>(Val: getResult().getType()).getRank();
3244 Block &block = region.front();
3245 if (block.getNumArguments() != rank)
3246 return emitError(message: "expected the block to have ") << rank << " arguments";
3247
3248 // Note: the number and type of yield values are checked in the YieldOp.
3249 for (const auto &en : llvm::enumerate(First: block.getArgumentTypes())) {
3250 if (!en.value().isIndex())
3251 return emitOpError(message: "expected block argument ")
3252 << (en.index() + 1) << " to be an index";
3253 }
3254
3255 // Ensure that the region yields an element of the right type.
3256 auto yieldOp = llvm::cast<YieldOp>(Val: block.getTerminator());
3257 if (yieldOp.getValue().getType() !=
3258 llvm::cast<ShapedType>(Val: getType()).getElementType())
3259 return emitOpError(message: "expected yield type to match shape element type");
3260
3261 return success();
3262}
3263
3264RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
3265 ArrayRef<int64_t> staticLow,
3266 ArrayRef<int64_t> staticHigh,
3267 ArrayRef<int64_t> resultShape) {
3268 unsigned rank = sourceType.getRank();
3269 if (staticLow.size() != rank)
3270 return RankedTensorType();
3271 if (staticHigh.size() != rank)
3272 return RankedTensorType();
3273 if (!resultShape.empty() && resultShape.size() != rank)
3274 return RankedTensorType();
3275
3276 SmallVector<int64_t, 4> inferredShape;
3277 for (auto i : llvm::seq<unsigned>(Begin: 0, End: rank)) {
3278 if (sourceType.isDynamicDim(idx: i) || staticLow[i] == ShapedType::kDynamic ||
3279 staticHigh[i] == ShapedType::kDynamic) {
3280 inferredShape.push_back(Elt: resultShape.empty() ? ShapedType::kDynamic
3281 : resultShape[i]);
3282 } else {
3283 int64_t size = sourceType.getDimSize(idx: i) + staticLow[i] + staticHigh[i];
3284 assert((resultShape.empty() || size == resultShape[i] ||
3285 resultShape[i] == ShapedType::kDynamic) &&
3286 "mismatch between inferred shape and result shape");
3287 inferredShape.push_back(Elt: size);
3288 }
3289 }
3290
3291 return RankedTensorType::get(shape: inferredShape, elementType: sourceType.getElementType());
3292}
3293
3294void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3295 Value source, ArrayRef<int64_t> staticLow,
3296 ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
3297 bool nofold, ArrayRef<NamedAttribute> attrs) {
3298 auto sourceType = llvm::cast<RankedTensorType>(Val: source.getType());
3299 if (!resultType)
3300 resultType = inferResultType(sourceType, staticLow, staticHigh);
3301 result.addAttributes(newAttributes: attrs);
3302 build(odsBuilder&: b, odsState&: result, result: resultType, source, low, high,
3303 static_low: b.getDenseI64ArrayAttr(values: staticLow), static_high: b.getDenseI64ArrayAttr(values: staticHigh),
3304 nofold: nofold ? b.getUnitAttr() : UnitAttr());
3305}
3306
3307void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
3308 Value source, ValueRange low, ValueRange high, bool nofold,
3309 ArrayRef<NamedAttribute> attrs) {
3310 auto sourceType = llvm::cast<RankedTensorType>(Val: source.getType());
3311 unsigned rank = sourceType.getRank();
3312 SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
3313 build(b, result, resultType, source, staticLow: staticVector, staticHigh: staticVector, low, high,
3314 nofold, attrs);
3315}
3316
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ArrayRef<OpFoldResult> low,
                  ArrayRef<OpFoldResult> high, bool nofold,
                  ArrayRef<NamedAttribute> attrs) {
  auto sourceType = llvm::cast<RankedTensorType>(Val: source.getType());
  // Split each OpFoldResult padding entry into its static (attribute) and
  // dynamic (SSA value) components.
  SmallVector<Value, 4> dynamicLow, dynamicHigh;
  SmallVector<int64_t, 4> staticLow, staticHigh;
  // staticLow and staticHigh have full information of the padding config.
  // This will grow staticLow and staticHigh with 1 value. If the config is
  // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1
  // value as well.
  dispatchIndexOpFoldResults(ofrs: low, dynamicVec&: dynamicLow, staticVec&: staticLow);
  dispatchIndexOpFoldResults(ofrs: high, dynamicVec&: dynamicHigh, staticVec&: staticHigh);
  // Infer the result type from the static padding when not provided.
  if (!resultType) {
    resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
  }
  assert(llvm::isa<RankedTensorType>(resultType));
  result.addAttributes(newAttributes: attrs);
  build(odsBuilder&: b, odsState&: result, result: resultType, source, low: dynamicLow, high: dynamicHigh,
        static_low: b.getDenseI64ArrayAttr(values: staticLow), static_high: b.getDenseI64ArrayAttr(values: staticHigh),
        nofold: nofold ? b.getUnitAttr() : UnitAttr());
}
3339
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ArrayRef<OpFoldResult> low,
                  ArrayRef<OpFoldResult> high, Value constantPadValue,
                  bool nofold, ArrayRef<NamedAttribute> attrs) {
  // Delegate to the OpFoldResult-based builder for the op itself.
  build(b, result, resultType, source, low, high, nofold, attrs);

  // Add a region and a block to yield the pad value.
  Region *region = result.regions[0].get();
  // One index block argument per source dimension.
  int sourceRank = llvm::cast<RankedTensorType>(Val: source.getType()).getRank();
  SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
  SmallVector<Location> blockArgLocs(sourceRank, result.location);

  // `builder.createBlock` changes the insertion point within the block. Create
  // a guard to reset the insertion point of the builder after it is destroyed.
  OpBuilder::InsertionGuard guard(b);
  b.createBlock(parent: region, insertPt: region->end(), argTypes: blockArgTypes, locs: blockArgLocs);
  b.create<tensor::YieldOp>(location: result.location, args&: constantPadValue);
}
3358
3359llvm::SmallBitVector PadOp::getPaddedDims() {
3360 llvm::SmallBitVector paddedDims(getSourceType().getRank());
3361 auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3362 for (const auto &en : enumerate(First&: paddingWidths))
3363 if (getConstantIntValue(ofr: en.value()) != static_cast<int64_t>(0))
3364 paddedDims.set(en.index());
3365 };
3366 extractPaddedDims(getMixedLowPad());
3367 extractPaddedDims(getMixedHighPad());
3368 return paddedDims;
3369}
3370
3371namespace {
3372// Folds tensor.pad when padding is static zeros and the attribute
3373// doesn't request otherwise.
3374struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
3375 using OpRewritePattern<PadOp>::OpRewritePattern;
3376
3377 LogicalResult matchAndRewrite(PadOp padTensorOp,
3378 PatternRewriter &rewriter) const override {
3379 if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3380 return failure();
3381 if (padTensorOp.getNofold())
3382 return failure();
3383 rewriter.replaceOpWithNewOp<tensor::CastOp>(
3384 op: padTensorOp, args: padTensorOp.getResult().getType(),
3385 args: padTensorOp.getSource());
3386 return success();
3387 }
3388};
3389
// Fold CastOp into PadOp when adding static information.
struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    // Only applies when the pad's source is produced by a foldable cast.
    auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
    if (!tensor::canFoldIntoConsumerOp(castOp))
      return failure();

    // Result type the pad would have when reading directly from the cast's
    // (more static) source.
    auto newResultType = PadOp::inferResultType(
        sourceType: llvm::cast<RankedTensorType>(Val: castOp.getSource().getType()),
        staticLow: padTensorOp.getStaticLow(), staticHigh: padTensorOp.getStaticHigh(),
        resultShape: padTensorOp.getResultType().getShape());

    if (newResultType == padTensorOp.getResultType()) {
      // Result type unchanged: just bypass the cast in place.
      rewriter.modifyOpInPlace(root: padTensorOp, callable: [&]() {
        padTensorOp.getSourceMutable().assign(value: castOp.getSource());
      });
    } else {
      // Result type refined: build a new pad with the refined type and cast
      // back to the original result type for existing users.
      auto newOp = rewriter.create<PadOp>(
          location: padTensorOp->getLoc(), args&: newResultType, args: padTensorOp.getSource(),
          args: padTensorOp.getStaticLow(), args: padTensorOp.getStaticHigh(),
          args: padTensorOp.getLow(), args: padTensorOp.getHigh(), args: padTensorOp.getNofold(),
          args: getPrunedAttributeList(op: padTensorOp, elidedAttrs: PadOp::getAttributeNames()));
      // The pad-value region must be cloned onto the new op.
      IRMapping mapper;
      padTensorOp.getRegion().cloneInto(dest: &newOp.getRegion(), mapper);

      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          op: padTensorOp, args: padTensorOp.getResultType(), args&: newOp);
    }
    return success();
  }
};
3424
// Fold CastOp using the result of PadOp back into the latter if it adds
// static information.
struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    // Only fold when the cast is the sole user of the pad result.
    if (!padTensorOp.getResult().hasOneUse())
      return failure();
    auto tensorCastOp =
        dyn_cast<tensor::CastOp>(Val: *padTensorOp->getUsers().begin());
    if (!tensorCastOp)
      return failure();
    // The cast must refine (not lose) static shape information.
    if (!tensor::preservesStaticInformation(source: padTensorOp.getResult().getType(),
                                            target: tensorCastOp.getDest().getType()))
      return failure();

    // Rebuild the pad directly with the cast's target type.
    auto replacementOp = rewriter.create<PadOp>(
        location: padTensorOp.getLoc(), args: tensorCastOp.getDest().getType(),
        args: padTensorOp.getSource(), args: padTensorOp.getStaticLow(),
        args: padTensorOp.getStaticHigh(), args: padTensorOp.getLow(),
        args: padTensorOp.getHigh(), args: padTensorOp.getNofold(),
        args: getPrunedAttributeList(op: padTensorOp, elidedAttrs: PadOp::getAttributeNames()));
    // Move (not clone) the pad-value region onto the replacement.
    replacementOp.getRegion().takeBody(other&: padTensorOp.getRegion());

    // Replace both the pad and its single cast user with the new result.
    rewriter.replaceOp(op: padTensorOp, newValues: replacementOp.getResult());
    rewriter.replaceOp(op: tensorCastOp, newValues: replacementOp.getResult());
    return success();
  }
};
3455
3456/// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
3457/// different dimensions. The pattern applies if the following preconditions
3458/// hold:
3459/// 1) the tensor::ExtractSliceOps are not rank-reducing,
3460/// 2) the tensor::ExtractSliceOps have only unit-strides,
3461/// 3) the tensor::PadOps perform only high-padding,
3462/// 4) the tensor::PadOps have the same constant padding value,
3463/// 5) the tensor::PadOps do not have common padding dimensions,
3464/// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
3465/// zero-offset for every dimension.
3466/// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for
3467/// the
3468/// padded source dimensions.
3469///
3470/// Example:
3471///
3472/// ```mlir
3473/// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
3474/// : tensor<64x64xf32> to tensor<?x64xf32>
3475/// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
3476/// } : tensor<?x64xf32> to tensor<8x64xf32>
3477/// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
3478/// : tensor<8x64xf32> to tensor<8x?xf32>
3479/// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
3480/// } : tensor<8x?xf32> to tensor<8x4xf32>
3481/// ```
3482///
3483/// folds into:
3484///
3485/// ```mlir
3486/// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
3487/// : tensor<64x64xf32> to tensor<?x?xf32>
3488/// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
3489/// } : tensor<?x?xf32> to tensor<8x4xf32>
3490/// ```
struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padOp,
                                PatternRewriter &rewriter) const override {
    // Match the producer chain:
    //   outerSliceOp -> outerPadOp -> innerSliceOp -> padOp.
    auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!innerSliceOp)
      return failure();
    auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
    if (!outerPadOp || outerPadOp.getNofold())
      return failure();
    auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!outerSliceOp)
      return failure();

    // 1) Fail if the chain is rank-reducing.
    int64_t rank = padOp.getSourceType().getRank();
    if (outerSliceOp.getSourceType().getRank() != rank) {
      return rewriter.notifyMatchFailure(arg&: padOp,
                                         msg: "cannot fold rank-reducing chain");
    }

    // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
    if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          arg&: padOp, msg: "cannot fold non-unit stride ExtractSliceOps");
    }

    // 3) Fail if the tensor::PadOps have non-zero low padding.
    if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
      return rewriter.notifyMatchFailure(arg&: padOp,
                                         msg: "cannot fold PadOps with low padding");
    }

    // 4) Fail if the tensor::PadOps padding values do not match.
    // Both pads must yield the same matched constant attribute.
    Attribute innerAttr, outerAttr;
    Value innerValue = padOp.getConstantPaddingValue();
    Value outerValue = outerPadOp.getConstantPaddingValue();
    if (!innerValue || !outerValue ||
        !matchPattern(value: innerValue, pattern: m_Constant(bind_value: &innerAttr)) ||
        !matchPattern(value: outerValue, pattern: m_Constant(bind_value: &outerAttr)) ||
        innerAttr != outerAttr) {
      return rewriter.notifyMatchFailure(
          arg&: padOp, msg: "cannot fold PadOps with different padding values");
    }

    // 5) Fail if a dimension is padded by both tensor::PadOps.
    llvm::SmallBitVector innerDims = padOp.getPaddedDims();
    llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
    if (innerDims.anyCommon(RHS: outerDims)) {
      return rewriter.notifyMatchFailure(
          arg&: padOp, msg: "cannot fold PadOps with common padding dimensions");
    }

    // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
    // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
    // for every dimension, and use the offset the other pair. Fail if no
    // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
    // exists.
    SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(value: 0));
    for (auto en : enumerate(First&: newOffsets)) {
      OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
      OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
      if (!innerDims.test(Idx: en.index()) &&
          (getConstantIntValue(ofr: innerOffset) == static_cast<int64_t>(0))) {
        en.value() = outerOffset;
        continue;
      }
      if (!outerDims.test(Idx: en.index()) &&
          (getConstantIntValue(ofr: outerOffset) == static_cast<int64_t>(0))) {
        en.value() = innerOffset;
        continue;
      }
      return rewriter.notifyMatchFailure(
          arg&: padOp, msg: "cannot find zero-offset and zero-padding pair");
    }

    // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
    // of the outer tensor::ExtractSliceOp for the dimensions padded by the
    // outer tensor::PadOp and fail if the size of the inner
    // tensor::ExtractSliceOp does not match the size of the padded dimension.
    // Otherwise, take the size of the inner tensor::ExtractSliceOp.
    SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
    for (auto en : enumerate(First&: newSizes)) {
      if (!outerDims.test(Idx: en.index()))
        continue;
      OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
      int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
      assert(ShapedType::isStatic(sourceSize) &&
             "expected padded dimension to have a static size");
      if (getConstantIntValue(ofr: sliceSize) != sourceSize) {
        return rewriter.notifyMatchFailure(
            arg&: padOp, msg: "cannot fold since the inner ExtractSliceOp size does not "
                       "match the size of the outer padding");
      }
      en.value() = outerSliceOp.getMixedSizes()[en.index()];
    }

    // Combine the high paddings of the two tensor::PadOps. The dims are
    // disjoint (checked in 5), so at most one branch fires per dimension.
    SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(value: 0));
    for (auto en : enumerate(First&: newHighPad)) {
      if (innerDims.test(Idx: en.index()))
        newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
      if (outerDims.test(Idx: en.index()))
        newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
    }

    // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
    // the two paddings in one step.
    auto newSliceOp = rewriter.create<ExtractSliceOp>(
        location: padOp.getLoc(), args: outerSliceOp.getSource(), args&: newOffsets, args&: newSizes,
        args: innerSliceOp.getMixedStrides());
    auto newPadOp = rewriter.create<PadOp>(
        location: padOp.getLoc(), args: padOp.getResultType(), args: newSliceOp.getResult(),
        args: padOp.getMixedLowPad(), args&: newHighPad, args: padOp.getNofold(),
        args: getPrunedAttributeList(op: padOp, elidedAttrs: PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(region&: padOp.getRegion(), parent&: newPadOp.getRegion(),
                                before: newPadOp.getRegion().begin());
    rewriter.replaceOp(op: padOp, newValues: newPadOp.getResult());
    return success();
  }
};
3613
/// Folds constant dynamic `low`/`high` padding operands into the static
/// padding attributes, computing a more static result type. The original
/// result type is restored for existing users with a trailing tensor.cast.
struct FoldStaticPadding : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    Value input = padTensorOp.getSource();
    if (!llvm::isa<RankedTensorType>(Val: input.getType()))
      return failure();
    auto inputDims = llvm::cast<RankedTensorType>(Val: input.getType()).getShape();
    auto inputRank = inputDims.size();

    auto oldResultType =
        dyn_cast<RankedTensorType>(Val: padTensorOp.getResult().getType());
    if (!oldResultType)
      return failure();

    auto outputDims = oldResultType.getShape();

    // Extract the static info from the high and low operands.
    // constOperandsLow/High hold one entry per dynamic operand, in operand
    // order: the matched constant, or kDynamic when the operand is not a
    // constant. newLows/newHighs keep the SSA values that stay dynamic.
    SmallVector<int64_t> constOperandsLow;
    SmallVector<Value> newLows;
    for (auto operand : padTensorOp.getLow()) {
      APSInt intOp;
      if (!matchPattern(value: operand, pattern: m_ConstantInt(bind_value: &intOp))) {
        constOperandsLow.push_back(Elt: ShapedType::kDynamic);
        newLows.push_back(Elt: operand);
        continue;
      }
      constOperandsLow.push_back(Elt: intOp.getExtValue());
    }
    SmallVector<int64_t> constOperandsHigh;
    SmallVector<Value> newHighs;
    for (auto operand : padTensorOp.getHigh()) {
      APSInt intOp;
      if (!matchPattern(value: operand, pattern: m_ConstantInt(bind_value: &intOp))) {
        constOperandsHigh.push_back(Elt: ShapedType::kDynamic);
        newHighs.push_back(Elt: operand);
        continue;
      }
      constOperandsHigh.push_back(Elt: intOp.getExtValue());
    }

    SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
    SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());

    // Verify the op is well-formed.
    if (inputDims.size() != outputDims.size() ||
        inputDims.size() != constLow.size() ||
        inputDims.size() != constHigh.size())
      return failure();

    // Replace each kDynamic placeholder in the static arrays with the value
    // extracted from the corresponding dynamic operand. The counters walk
    // the dynamic-operand lists in step with the kDynamic placeholders.
    auto lowCount = 0;
    auto highCount = 0;
    for (size_t i = 0; i < inputRank; i++) {
      if (constLow[i] == ShapedType::kDynamic)
        constLow[i] = constOperandsLow[lowCount++];
      if (constHigh[i] == ShapedType::kDynamic)
        constHigh[i] = constOperandsHigh[highCount++];
    }

    auto staticLow = ArrayRef<int64_t>(constLow);
    auto staticHigh = ArrayRef<int64_t>(constHigh);

    // Calculate the output sizes with the static information.
    SmallVector<int64_t> newOutDims;
    for (size_t i = 0; i < inputRank; i++) {
      if (outputDims[i] == ShapedType::kDynamic) {
        newOutDims.push_back(
            Elt: (staticLow[i] == ShapedType::kDynamic ||
                 staticHigh[i] == ShapedType::kDynamic ||
                 inputDims[i] == ShapedType::kDynamic
                 ? ShapedType::kDynamic
                 : inputDims[i] + staticLow[i] + staticHigh[i]));
      } else {
        newOutDims.push_back(Elt: outputDims[i]);
      }
    }

    // Bail when nothing became more static.
    if (SmallVector<int64_t>(outputDims) == newOutDims ||
        llvm::all_of(Range&: newOutDims,
                     P: [&](int64_t x) { return x == ShapedType::kDynamic; }))
      return failure();

    // Rewrite the op using the new static type.
    auto newResultType = RankedTensorType::get(
        shape: newOutDims, elementType: padTensorOp.getType().getElementType());
    auto newOp = rewriter.create<PadOp>(
        location: padTensorOp->getLoc(), args&: newResultType, args&: input, args&: staticLow, args&: staticHigh,
        args&: newLows, args&: newHighs, args: padTensorOp.getNofold(),
        args: getPrunedAttributeList(op: padTensorOp, elidedAttrs: PadOp::getAttributeNames()));

    // Clone the pad-value region and cast back to the original result type.
    IRMapping mapper;
    padTensorOp.getRegion().cloneInto(dest: &newOp.getRegion(), mapper);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op: padTensorOp, args&: oldResultType,
                                                args&: newOp);

    return success();
  }
};
3713
3714/// Folds a chain of `tensor.pad` ops with the same constant padding value.
3715///
3716/// Example:
3717///
3718/// ```mlir
3719/// %1 = tensor.pad %0 low[0, 1] high[0, 2] {
3720/// tensor.yield %val
3721/// } : tensor<1x2xf32> to tensor<2x5xf32>
3722/// %res = tensor.pad %1 low[0, 2] high[3, 0] {
3723/// tensor.yield %val
3724/// } : tensor<1x5xf32> to tensor<5x7xf32>
3725/// ```
3726///
3727/// folds into:
3728///
3729/// ```mlir
3730/// %res = tensor.pad %0 low[0, 3] high[3, 2] {
3731/// tensor.yield %val
3732/// } : tensor<1x2xf32> to tensor<5x7xf32>
3733/// ```
struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    if (padOp.getNofold()) {
      return rewriter.notifyMatchFailure(arg&: padOp, msg: "skipping unfoldable pad");
    }

    auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!producerPad || producerPad.getNofold()) {
      return rewriter.notifyMatchFailure(
          arg&: padOp, msg: "producer is not a foldable tensor.pad op");
    }

    // Fail if the tensor::PadOps padding values do not match.
    // Note: compared by SSA value identity, which is conservative (two
    // distinct-but-equal constants will not fold).
    Value consumerPadValue = padOp.getConstantPaddingValue();
    Value producerPadValue = producerPad.getConstantPaddingValue();
    if (!consumerPadValue || !producerPadValue ||
        consumerPadValue != producerPadValue) {
      return rewriter.notifyMatchFailure(
          arg&: padOp,
          msg: "cannot fold PadOps with different or non-constant padding values");
    }

    Location loc = padOp.getLoc();
    AffineExpr d0, d1;
    bindDims(ctx: rewriter.getContext(), exprs&: d0, exprs&: d1);

    // Combine the low/high paddings of the two tensor::PadOps.
    // Element-wise sum, folded to a constant via affine d0 + d1 whenever
    // both operands are static.
    auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
                           ArrayRef<OpFoldResult> producerPaddings) {
      SmallVector<OpFoldResult> sumPaddings;
      for (auto [consumerIndex, producerIndex] :
           llvm::zip_equal(t&: consumerPaddings, u&: producerPaddings)) {
        sumPaddings.push_back(Elt: affine::makeComposedFoldedAffineApply(
            b&: rewriter, loc, expr: d0 + d1, operands: {consumerIndex, producerIndex}));
      }
      return sumPaddings;
    };

    SmallVector<OpFoldResult> newHighPad =
        addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
    SmallVector<OpFoldResult> newLowPad =
        addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());

    // Build the merged pad reading directly from the producer's source and
    // inherit the consumer's pad-value region.
    auto newPadOp = rewriter.create<tensor::PadOp>(
        location: padOp.getLoc(), args: padOp.getResultType(), args: producerPad.getSource(),
        args&: newLowPad, args&: newHighPad, args: padOp.getNofold(),
        args: getPrunedAttributeList(op: padOp, elidedAttrs: tensor::PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(region&: padOp.getRegion(), parent&: newPadOp.getRegion(),
                                before: newPadOp.getRegion().begin());
    rewriter.replaceOp(op: padOp, newValues: newPadOp.getResult());
    return success();
  }
};
3790
3791} // namespace
3792
/// Reify each result dimension: static dims become index attributes; dynamic
/// dims are computed as dim(source, i) + low[i] + high[i].
LogicalResult
PadOp::reifyResultShapes(OpBuilder &b,
                         ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(N: 1, NV: SmallVector<OpFoldResult>(getType().getRank()));
  SmallVector<OpFoldResult> lp = getMixedLowPad();
  SmallVector<OpFoldResult> hp = getMixedHighPad();
  for (int64_t i = 0; i < getResultType().getRank(); ++i) {
    // Static dimension: reify as a constant attribute.
    if (!getType().isDynamicDim(idx: i)) {
      reifiedReturnShapes[0][i] = b.getIndexAttr(value: getType().getDimSize(idx: i));
      continue;
    }
    Location loc = getLoc();
    Value dim = b.createOrFold<tensor::DimOp>(
        location: loc, args: getSource(), args: b.create<arith::ConstantIndexOp>(location: loc, args&: i));

    // Dynamic dimension: source size plus both padding amounts, folded via
    // an affine apply so constant paddings still produce attributes.
    AffineExpr d0, d1, d2;
    bindDims(ctx: b.getContext(), exprs&: d0, exprs&: d1, exprs&: d2);
    reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
        b, loc, expr: {d0 + d1 + d2}, operands: {dim, lp[i], hp[i]});
  }
  return success();
}
3815
3816void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3817 MLIRContext *context) {
3818 results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3819 FoldOrthogonalPaddings, FoldStaticPadding,
3820 FoldConsecutiveConstantPadding>(arg&: context);
3821}
3822
3823/// Return the padding value of the PadOp if it constant. In this context,
3824/// "constant" means an actual constant or "defined outside of the block".
3825///
3826/// Values are considered constant in three cases:
3827/// - A ConstantLike value.
3828/// - A basic block argument from a different block.
3829/// - A value defined outside of the block.
3830///
3831/// If the padding value is not constant, an empty Value is returned.
3832Value PadOp::getConstantPaddingValue() {
3833 auto yieldOp = dyn_cast<YieldOp>(Val: getRegion().front().getTerminator());
3834 if (!yieldOp)
3835 return {};
3836 Value padValue = yieldOp.getValue();
3837 // Check if yield value is a constant.
3838 if (matchPattern(value: padValue, pattern: m_Constant()))
3839 return padValue;
3840 // Check if yield value is defined inside the PadOp block.
3841 if (padValue.getParentBlock() == &getRegion().front())
3842 return {};
3843 // Else: Yield value defined outside of the PadOp block.
3844 return padValue;
3845}
3846
3847OpFoldResult PadOp::fold(FoldAdaptor) {
3848 if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3849 !getNofold())
3850 return getSource();
3851 return {};
3852}
3853
3854//===----------------------------------------------------------------------===//
3855// ParallelInsertSliceOp
3856//===----------------------------------------------------------------------===//
3857
3858OpResult ParallelInsertSliceOp::getTiedOpResult() {
3859 ParallelCombiningOpInterface parallelCombiningParent =
3860 getParallelCombiningParent();
3861 for (const auto &it :
3862 llvm::enumerate(First: parallelCombiningParent.getYieldingOps())) {
3863 Operation &nextOp = it.value();
3864 if (&nextOp == getOperation())
3865 return parallelCombiningParent.getParentResult(idx: it.index());
3866 }
3867 llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
3868}
3869
3870// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
3871void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3872 Value source, Value dest,
3873 ArrayRef<OpFoldResult> offsets,
3874 ArrayRef<OpFoldResult> sizes,
3875 ArrayRef<OpFoldResult> strides,
3876 ArrayRef<NamedAttribute> attrs) {
3877 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3878 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3879 dispatchIndexOpFoldResults(ofrs: offsets, dynamicVec&: dynamicOffsets, staticVec&: staticOffsets);
3880 dispatchIndexOpFoldResults(ofrs: sizes, dynamicVec&: dynamicSizes, staticVec&: staticSizes);
3881 dispatchIndexOpFoldResults(ofrs: strides, dynamicVec&: dynamicStrides, staticVec&: staticStrides);
3882 result.addAttributes(newAttributes: attrs);
3883 build(odsBuilder&: b, odsState&: result, resultTypes: {}, source, dest, offsets: dynamicOffsets, sizes: dynamicSizes,
3884 strides: dynamicStrides, static_offsets: b.getDenseI64ArrayAttr(values: staticOffsets),
3885 static_sizes: b.getDenseI64ArrayAttr(values: staticSizes),
3886 static_strides: b.getDenseI64ArrayAttr(values: staticStrides));
3887}
3888
/// Build an ParallelInsertSliceOp with mixed static and dynamic entries
/// packed into a Range vector.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest,
                                  ArrayRef<Range> ranges,
                                  ArrayRef<NamedAttribute> attrs) {
  // Unpack the ranges into separate offset/size/stride lists and delegate to
  // the mixed static/dynamic builder.
  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
  build(b, result, source, dest, offsets, sizes, strides, attrs);
}
3898
3899// Build a ParallelInsertSliceOp with dynamic entries.
3900void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
3901 Value source, Value dest, ValueRange offsets,
3902 ValueRange sizes, ValueRange strides,
3903 ArrayRef<NamedAttribute> attrs) {
3904 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3905 Range: llvm::map_range(C&: offsets, F: [](Value v) -> OpFoldResult { return v; }));
3906 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3907 Range: llvm::map_range(C&: sizes, F: [](Value v) -> OpFoldResult { return v; }));
3908 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3909 Range: llvm::map_range(C&: strides, F: [](Value v) -> OpFoldResult { return v; }));
3910 build(b, result, source, dest, offsets: offsetValues, sizes: sizeValues, strides: strideValues);
3911}
3912
/// Verify a tensor.parallel_insert_slice op: the op must sit inside a
/// ParallelCombiningOpInterface parent, its source type must match the type
/// inferred from the dest type and the slice parameters, and the slice must
/// stay within the bounds of the destination tensor.
LogicalResult ParallelInsertSliceOp::verify() {
  if (!isa<ParallelCombiningOpInterface>(Val: getOperation()->getParentOp()))
    return this->emitError(message: "expected ParallelCombiningOpInterface parent, got:")
           << *(getOperation()->getParentOp());

  // Verify result type against inferred type.
  RankedTensorType expectedType;
  SliceVerificationResult result =
      verifyInsertSliceOp(srcType: getSourceType(), dstType: getDestType(), staticOffsets: getStaticOffsets(),
                          staticSizes: getStaticSizes(), staticStrides: getStaticStrides(), expectedType: &expectedType);
  if (result != SliceVerificationResult::Success)
    return produceSliceErrorMsg(result, op: *this, expectedType);

  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
  // to the destination tensor.
  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
      shape: getDestType().getShape(), staticOffsets: getStaticOffsets(), staticSizes: getStaticSizes(),
      staticStrides: getStaticStrides(), /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(message: boundsResult.errorMessage);

  return success();
}
3936
/// Register the canonicalization patterns for tensor.parallel_insert_slice:
/// the generic insert_slice-style folders instantiated for this op.
void ParallelInsertSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
              InsertSliceOpCastFolder<ParallelInsertSliceOp>,
              InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(arg&: context);
}
3943
/// Return which dimensions of the dest are dropped (rank-reduced) by this op,
/// computed from the source shape and the mixed slice sizes by the file-local
/// ::getDroppedDims helper.
llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
  return ::getDroppedDims(reducedShape: getSourceType().getShape(), mixedSizes: getMixedSizes());
}
3947
3948//===----------------------------------------------------------------------===//
3949// ScatterOp
3950//===----------------------------------------------------------------------===//
3951
/// Suggest "scatter" as the printed SSA name for this op's result.
void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}
3956
/// Verify a tensor.scatter op: the scatter dims must be valid w.r.t. the dest
/// rank and indices shape, `unique` must be set, and the source type must
/// match the type a gather on `dest` over the same dims would produce (either
/// the full or the rank-reduced variant).
LogicalResult ScatterOp::verify() {
  int64_t destRank = getDestType().getRank();
  ArrayRef<int64_t> scatterDims = getScatterDims();
  if (failed(Result: verifyGatherOrScatterDims(op: getOperation(), dims: scatterDims,
                                        indices: getIndicesType().getShape(), rank: destRank,
                                        gatherOrScatter: "scatter", sourceOrDest: "dest")))
    return failure();

  if (!getUnique())
    return emitOpError(message: "requires 'unique' attribute to be set");
  // TODO: we could also check statically that there are fewer leading index
  // tensor dims than the dest dims. If this is not the case, the unique
  // attribute cannot be true.

  // Use the GatherOp::inferResultType on the `dest` type and verify the
  // expected type matches the source type.
  RankedTensorType expectedSourceType = GatherOp::inferResultType(
      sourceType: getDestType(), indicesType: getIndicesType(), gatherDims: scatterDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
      sourceType: getDestType(), indicesType: getIndicesType(), gatherDims: scatterDims, /*rankReduced=*/true);
  // Accept either the full or the rank-reduced gather-inferred type.
  if (getSourceType() != expectedSourceType &&
      getSourceType() != expectedRankReducedSourceType) {
    return emitOpError(message: "source type "
                       "mismatch: "
                       "expected ")
           << expectedSourceType << " or its rank-reduced variant "
           << expectedRankReducedSourceType << " (got: " << getSourceType()
           << ")";
  }

  return success();
}
3989
3990//===----------------------------------------------------------------------===//
3991// SplatOp
3992//===----------------------------------------------------------------------===//
3993
/// Build a SplatOp with an explicit aggregate result type; `dynamicSizes`
/// provides the extents of the dynamic dimensions of `aggregateType`.
void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    Type aggregateType, ValueRange dynamicSizes) {
  build(odsBuilder&: builder, odsState&: result, aggregate: aggregateType, input: element, dynamicSizes);
}
3998
3999void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4000 ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
4001 auto aggregateType = RankedTensorType::get(shape: staticShape, elementType: element.getType());
4002 build(odsBuilder&: builder, odsState&: result, aggregate: aggregateType, input: element, dynamicSizes);
4003}
4004
4005void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
4006 ArrayRef<OpFoldResult> sizes) {
4007 SmallVector<int64_t> staticShape;
4008 SmallVector<Value> dynamicSizes;
4009 dispatchIndexOpFoldResults(ofrs: sizes, dynamicVec&: dynamicSizes, staticVec&: staticShape);
4010 build(builder, result, element, staticShape, dynamicSizes);
4011}
4012
/// Suggest "splat" as the printed SSA name for this op's result.
void SplatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "splat");
}
4017
4018LogicalResult SplatOp::verify() {
4019 if (getType().getNumDynamicDims() != getDynamicSizes().size())
4020 return emitOpError(message: "incorrect number of dynamic sizes, has ")
4021 << getDynamicSizes().size() << ", expected "
4022 << getType().getNumDynamicDims();
4023 return success();
4024}
4025
4026LogicalResult
4027SplatOp::reifyResultShapes(OpBuilder &builder,
4028 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4029 reifiedReturnShapes.resize(N: 1, NV: SmallVector<OpFoldResult>(getType().getRank()));
4030 unsigned ctr = 0;
4031 for (int64_t i = 0; i < getType().getRank(); ++i) {
4032 if (getType().isDynamicDim(idx: i)) {
4033 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
4034 } else {
4035 reifiedReturnShapes[0][i] = builder.getIndexAttr(value: getType().getDimSize(idx: i));
4036 }
4037 }
4038 return success();
4039}
4040
4041OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
4042 auto constOperand = adaptor.getInput();
4043 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(Val: constOperand))
4044 return {};
4045
4046 // Do not fold if the splat is not statically shaped
4047 if (!getType().hasStaticShape())
4048 return {};
4049
4050 // SplatElementsAttr::get treats single value for second arg as being a
4051 // splat.
4052 return SplatElementsAttr::get(type: getType(), list: {constOperand});
4053}
4054
4055//===----------------------------------------------------------------------===//
4056// Common Canonicalizers and Folders.
4057//===----------------------------------------------------------------------===//
4058bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
4059 // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
4060 // 2. Exclude DPS ops that are also LoopLike from this interface as they
4061 // might need special handling of attached regions.
4062 if (isa<InsertSliceOp>(Val: op.getOperation()) ||
4063 isa<LoopLikeOpInterface>(Val: op.getOperation()))
4064 return false;
4065
4066 return hasFoldableTensorCastOperand(op);
4067}
4068
4069/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
4070/// the `tensor.cast` has source that is more static than the consuming op.
4071///
4072/// Example:
4073/// ```mlir
4074/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4075/// %2 = consumer %1 ... : tensor<?x?xf32> ...
4076/// ```
4077///
4078/// folds into:
4079///
4080/// ```mlir
4081/// %2 = consumer %0 ... : tensor<8x16xf32> ...
4082/// ```
4083/// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
4084/// can add the pattern to their canonicalizers.
4085struct FoldTensorCastProducerOp
4086 : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
4087 using OpInterfaceRewritePattern<
4088 DestinationStyleOpInterface>::OpInterfaceRewritePattern;
4089
4090 LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
4091 PatternRewriter &rewriter) const override {
4092
4093 // Reject PackOp/UnpackOp (i.e. RelayoutOps) - there are dedicated patterns
4094 // for that instead.
4095 if (!foldTensorCastPrecondition(op) ||
4096 isa<linalg::RelayoutOpInterface>(Val: *op))
4097 return failure();
4098
4099 SmallVector<Type> newResultTypes(op->getResultTypes());
4100 SmallVector<Value> newOperands =
4101 getUpdatedOperandsAfterCastOpFolding(op, newResTy&: newResultTypes);
4102
4103 // Clone op
4104 auto newOp = clone(b&: rewriter, op, newResultTypes, newOperands);
4105
4106 SmallVector<Value, 4> replacements;
4107 replacements.reserve(N: newOp->getNumResults());
4108 for (auto [oldResult, newResult] :
4109 llvm::zip(t: op->getResults(), u: newOp->getResults())) {
4110 if (newResult.getType() != oldResult.getType()) {
4111 replacements.push_back(Elt: rewriter.create<tensor::CastOp>(
4112 location: op->getLoc(), args: oldResult.getType(), args&: newResult));
4113 } else {
4114 replacements.push_back(Elt: newResult);
4115 }
4116 }
4117 rewriter.replaceOp(op, newValues: replacements);
4118
4119 return success();
4120 }
4121};
4122
4123//===----------------------------------------------------------------------===//
4124// TensorDialect
4125//===----------------------------------------------------------------------===//
4126
/// Register dialect-wide canonicalization patterns: currently only the
/// generic tensor.cast-producer folding for DPS ops.
void TensorDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<FoldTensorCastProducerOp>(arg: getContext());
}
4131
4132//===----------------------------------------------------------------------===//
4133// TableGen'd op method definitions
4134//===----------------------------------------------------------------------===//
4135
4136#define GET_OP_CLASSES
4137#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
4138

source code of mlir/lib/Dialect/Tensor/IR/TensorOps.cpp