1//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements convenience types for working with super-vectorization
10// operations, in particular super-vector loads and stores.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Vector/IR/VectorOps.h"
15
16#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
17#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
18#include "mlir/Dialect/Arith/IR/Arith.h"
19#include "mlir/Dialect/Arith/Utils/Utils.h"
20#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
21#include "mlir/Dialect/MemRef/IR/MemRef.h"
22#include "mlir/Dialect/Tensor/IR/Tensor.h"
23#include "mlir/Dialect/UB/IR/UBOps.h"
24#include "mlir/Dialect/Utils/IndexingUtils.h"
25#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
26#include "mlir/IR/AffineExpr.h"
27#include "mlir/IR/AffineMap.h"
28#include "mlir/IR/Builders.h"
29#include "mlir/IR/BuiltinAttributes.h"
30#include "mlir/IR/BuiltinTypes.h"
31#include "mlir/IR/DialectImplementation.h"
32#include "mlir/IR/IRMapping.h"
33#include "mlir/IR/OpImplementation.h"
34#include "mlir/IR/PatternMatch.h"
35#include "mlir/IR/TypeUtilities.h"
36#include "mlir/IR/ValueRange.h"
37#include "mlir/Interfaces/SubsetOpInterface.h"
38#include "mlir/Interfaces/ValueBoundsOpInterface.h"
39#include "mlir/Support/LLVM.h"
40#include "mlir/Transforms/InliningUtils.h"
41#include "llvm/ADT/ArrayRef.h"
42#include "llvm/ADT/STLExtras.h"
43#include "llvm/ADT/SmallVector.h"
44#include "llvm/ADT/StringSet.h"
45#include "llvm/ADT/TypeSwitch.h"
46#include "llvm/Support/Casting.h"
47
48#include <cassert>
49#include <cstdint>
50
51#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
52// Pull in all enum type and utility function definitions.
53#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
54
55using namespace mlir;
56using namespace mlir::vector;
57
/// Helper enum to classify mask value.
enum class MaskFormat {
  AllTrue = 0,  // All mask elements are statically known to be set.
  AllFalse = 1, // All mask elements are statically known to be cleared.
  Unknown = 2,  // Mask contents cannot be determined statically.
};
64
65/// Helper method to classify a mask value. Currently, the method
66/// looks "under the hood" of a constant value with dense attributes
67/// and a constant mask operation (since the client may be called at
68/// various stages during progressive lowering).
69static MaskFormat getMaskFormat(Value mask) {
70 if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
71 // Inspect constant dense values. We count up for bits that
72 // are set, count down for bits that are cleared, and bail
73 // when a mix is detected.
74 if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(Val: c.getValue())) {
75 int64_t val = 0;
76 for (bool b : denseElts.getValues<bool>())
77 if (b && val >= 0)
78 val++;
79 else if (!b && val <= 0)
80 val--;
81 else
82 return MaskFormat::Unknown;
83 if (val > 0)
84 return MaskFormat::AllTrue;
85 if (val < 0)
86 return MaskFormat::AllFalse;
87 }
88 } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
89 // Inspect constant mask index. If the index exceeds the
90 // dimension size, all bits are set. If the index is zero
91 // or less, no bits are set.
92 ArrayRef<int64_t> masks = m.getMaskDimSizes();
93 auto shape = m.getType().getShape();
94 bool allTrue = true;
95 bool allFalse = true;
96 for (auto [maskIdx, dimSize] : llvm::zip_equal(t&: masks, u&: shape)) {
97 if (maskIdx < dimSize)
98 allTrue = false;
99 if (maskIdx > 0)
100 allFalse = false;
101 }
102 if (allTrue)
103 return MaskFormat::AllTrue;
104 if (allFalse)
105 return MaskFormat::AllFalse;
106 } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
107 // Finds all-false create_masks. An all-true create_mask requires all
108 // dims to be constants, so that'll be folded to a constant_mask, then
109 // detected in the constant_mask case.
110 auto maskOperands = m.getOperands();
111 for (Value operand : maskOperands) {
112 if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
113 int64_t dimSize =
114 llvm::cast<IntegerAttr>(Val: constantOp.getValue()).getInt();
115 if (dimSize <= 0)
116 return MaskFormat::AllFalse;
117 }
118 }
119 return MaskFormat::Unknown;
120 }
121 return MaskFormat::Unknown;
122}
123
124/// Default callback to build a region with a 'vector.yield' terminator with no
125/// arguments.
126void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
127 builder.create<vector::YieldOp>(location: loc);
128}
129
130// Helper for verifying combining kinds in contractions and reductions.
131static bool isSupportedCombiningKind(CombiningKind combiningKind,
132 Type elementType) {
133 switch (combiningKind) {
134 case CombiningKind::ADD:
135 case CombiningKind::MUL:
136 return elementType.isIntOrIndexOrFloat();
137 case CombiningKind::MINUI:
138 case CombiningKind::MINSI:
139 case CombiningKind::MAXUI:
140 case CombiningKind::MAXSI:
141 case CombiningKind::AND:
142 case CombiningKind::OR:
143 case CombiningKind::XOR:
144 return elementType.isIntOrIndex();
145 case CombiningKind::MINNUMF:
146 case CombiningKind::MAXNUMF:
147 case CombiningKind::MINIMUMF:
148 case CombiningKind::MAXIMUMF:
149 return llvm::isa<FloatType>(Val: elementType);
150 }
151 return false;
152}
153
154/// Returns the effective rank of the vector to read/write for Xfer Ops
155///
156/// When the element type of the shaped type is _a scalar_, this will simply
157/// return the rank of the vector ( the result for xfer_read or the value to
158/// store for xfer_write).
159///
160/// When the element type of the base shaped type is _a vector_, returns the
161/// difference between the original vector type and the element type of the
162/// shaped type.
163///
164/// EXAMPLE 1 (element type is _a scalar_):
165/// - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32>
166/// - shapedType.getElementType() = f32 (rank 0)
167/// - vectorType.getRank() = 2
168/// - Result = 2 - 0 = 2
169///
170/// EXAMPLE 2 (element type is _a vector_):
171/// - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32>
172/// - shapedType.getElementType() = vector<20xf32> (rank 1)
173/// - vectorType.getRank() = 1
174/// - Result = 1 - 1 = 0
175///
176/// This is used to determine the number of minor dimensions for identity maps
177/// in vector transfer Ops.
178static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType,
179 VectorType vectorType) {
180 unsigned elementVectorRank = 0;
181 VectorType elementVectorType =
182 llvm::dyn_cast<VectorType>(Val: shapedType.getElementType());
183 if (elementVectorType)
184 elementVectorRank += elementVectorType.getRank();
185 return vectorType.getRank() - elementVectorRank;
186}
187
188AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
189 VectorType vectorType) {
190 // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
191 // TODO: replace once we have 0-d vectors.
192 if (shapedType.getRank() == 0 &&
193 vectorType.getShape() == ArrayRef<int64_t>{1})
194 return AffineMap::get(
195 /*numDims=*/dimCount: 0, /*numSymbols=*/symbolCount: 0,
196 result: getAffineConstantExpr(constant: 0, context: shapedType.getContext()));
197 return AffineMap::getMinorIdentityMap(
198 dims: shapedType.getRank(),
199 results: getEffectiveVectorRankForXferOp(shapedType, vectorType),
200 context: shapedType.getContext());
201}
202
203/// Check if `write` is of a constant splat and the masked `read` is padded with
204/// the same splat value -- meaning it could be the same value as the initial
205/// constant splat.
206static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
207 vector::TransferReadOp read) {
208 auto readMask = read.getMask();
209 auto writeMask = write.getMask();
210 // Check if the masks are consistent. The splat value could be the same if the
211 // read is masked (and padded with the splat value), and the write is unmasked
212 // or has the same mask. Note this does not allow the case where the write is
213 // masked and the read is unmasked, as then the read could be of more elements
214 // than the write (which may not be the same value).
215 bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
216 if (!couldBeSameSplat)
217 return false;
218 // Check for constant splat (as the source of the write).
219 DenseElementsAttr splatAttr;
220 if (!matchPattern(value: write.getVector(),
221 pattern: m_Constant<DenseElementsAttr>(bind_value: &splatAttr)) ||
222 !splatAttr.isSplat()) {
223 return false;
224 }
225 // The padding of the read and the constant splat value must be the same.
226 Attribute padAttr;
227 if (!matchPattern(value: read.getPadding(), pattern: m_Constant(bind_value: &padAttr)))
228 return false;
229 return padAttr == splatAttr.getSplatValue<Attribute>();
230}
231
232bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
233 vector::TransferReadOp read) {
234 return !defWrite.hasOutOfBoundsDim() &&
235 defWrite.getIndices() == read.getIndices() &&
236 defWrite.getVectorType() == read.getVectorType() &&
237 defWrite.getPermutationMap() == read.getPermutationMap() &&
238 ((!defWrite.getMask() && !read.getMask()) ||
239 isSplatWriteConsistentWithMaskedRead(write: defWrite, read));
240}
241
242bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
243 vector::TransferWriteOp priorWrite) {
244 return priorWrite.getIndices() == write.getIndices() &&
245 priorWrite.getMask() == write.getMask() &&
246 priorWrite.getVectorType() == write.getVectorType() &&
247 priorWrite.getPermutationMap() == write.getPermutationMap();
248}
249
250bool mlir::vector::isDisjointTransferIndices(
251 VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
252 bool testDynamicValueUsingBounds) {
253 // For simplicity only look at transfer of same type.
254 if (transferA.getVectorType() != transferB.getVectorType())
255 return false;
256 unsigned rankOffset = transferA.getLeadingShapedRank();
257 for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
258 Value indexA = transferA.getIndices()[i];
259 Value indexB = transferB.getIndices()[i];
260 std::optional<int64_t> cstIndexA = getConstantIntValue(ofr: indexA);
261 std::optional<int64_t> cstIndexB = getConstantIntValue(ofr: indexB);
262
263 if (i < rankOffset) {
264 // For leading dimensions, if we can prove that index are different we
265 // know we are accessing disjoint slices.
266 if (cstIndexA.has_value() && cstIndexB.has_value()) {
267 if (*cstIndexA != *cstIndexB)
268 return true;
269 continue;
270 }
271 if (testDynamicValueUsingBounds) {
272 // First try to see if we can fully compose and simplify the affine
273 // expression as a fast track.
274 FailureOr<uint64_t> delta =
275 affine::fullyComposeAndComputeConstantDelta(value1: indexA, value2: indexB);
276 if (succeeded(Result: delta) && *delta != 0)
277 return true;
278
279 FailureOr<bool> testEqual =
280 ValueBoundsConstraintSet::areEqual(var1: indexA, var2: indexB);
281 if (succeeded(Result: testEqual) && !testEqual.value())
282 return true;
283 }
284 } else {
285 // For this dimension, we slice a part of the memref we need to make sure
286 // the intervals accessed don't overlap.
287 int64_t vectorDim = transferA.getVectorType().getDimSize(idx: i - rankOffset);
288 if (cstIndexA.has_value() && cstIndexB.has_value()) {
289 int64_t distance = std::abs(i: *cstIndexA - *cstIndexB);
290 if (distance >= vectorDim)
291 return true;
292 continue;
293 }
294 if (testDynamicValueUsingBounds) {
295 // First try to see if we can fully compose and simplify the affine
296 // expression as a fast track.
297 FailureOr<int64_t> delta =
298 affine::fullyComposeAndComputeConstantDelta(value1: indexA, value2: indexB);
299 if (succeeded(Result: delta) && std::abs(i: *delta) >= vectorDim)
300 return true;
301
302 FailureOr<int64_t> computeDelta =
303 ValueBoundsConstraintSet::computeConstantDelta(value1: indexA, value2: indexB);
304 if (succeeded(Result: computeDelta)) {
305 if (std::abs(i: computeDelta.value()) >= vectorDim)
306 return true;
307 }
308 }
309 }
310 }
311 return false;
312}
313
314bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
315 VectorTransferOpInterface transferB,
316 bool testDynamicValueUsingBounds) {
317 if (transferA.getBase() != transferB.getBase())
318 return false;
319 return isDisjointTransferIndices(transferA, transferB,
320 testDynamicValueUsingBounds);
321}
322
323// Helper to iterate over n-D vector slice elements. Calculate the next
324// `position` in the n-D vector of size `shape`, applying an offset `offsets`.
325// Modifies the `position` in place. Returns a failure when `position` becomes
326// the end position.
327static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
328 ArrayRef<int64_t> shape,
329 ArrayRef<int64_t> offsets) {
330 for (auto [posInDim, dimSize, offsetInDim] :
331 llvm::reverse(C: llvm::zip_equal(t&: position, u&: shape, args&: offsets))) {
332 ++posInDim;
333 if (posInDim < dimSize + offsetInDim)
334 return success();
335
336 // Carry the overflow to the next loop iteration.
337 posInDim = offsetInDim;
338 }
339
340 return failure();
341}
342
343/// Returns the integer numbers in `values`. `values` are expected to be
344/// constant operations.
345SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
346 SmallVector<int64_t> ints;
347 llvm::transform(Range&: values, d_first: std::back_inserter(x&: ints), F: [](Value value) {
348 auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
349 assert(constOp && "Unexpected non-constant index");
350 return constOp.value();
351 });
352 return ints;
353}
354
355/// Returns the integer numbers in `foldResults`. `foldResults` are expected to
356/// be constant operations.
357SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
358 SmallVector<int64_t> ints;
359 llvm::transform(
360 Range&: foldResults, d_first: std::back_inserter(x&: ints), F: [](OpFoldResult foldResult) {
361 assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
362 return cast<IntegerAttr>(Val: cast<Attribute>(Val&: foldResult)).getInt();
363 });
364 return ints;
365}
366
367/// Convert `foldResults` into Values. Integer attributes are converted to
368/// constant op.
369SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
370 ArrayRef<OpFoldResult> foldResults) {
371 SmallVector<Value> values;
372 llvm::transform(Range&: foldResults, d_first: std::back_inserter(x&: values),
373 F: [&](OpFoldResult foldResult) {
374 if (auto attr = dyn_cast<Attribute>(Val&: foldResult))
375 return builder
376 .create<arith::ConstantIndexOp>(
377 location: loc, args: cast<IntegerAttr>(Val&: attr).getInt())
378 .getResult();
379
380 return cast<Value>(Val&: foldResult);
381 });
382 return values;
383}
384
385std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
386 if (value.getDefiningOp<vector::VectorScaleOp>())
387 return 1;
388 auto mul = value.getDefiningOp<arith::MulIOp>();
389 if (!mul)
390 return {};
391 auto lhs = mul.getLhs();
392 auto rhs = mul.getRhs();
393 if (lhs.getDefiningOp<vector::VectorScaleOp>())
394 return getConstantIntValue(ofr: rhs);
395 if (rhs.getDefiningOp<vector::VectorScaleOp>())
396 return getConstantIntValue(ofr: lhs);
397 return {};
398}
399
400/// Converts an IntegerAttr to have the specified type if needed.
401/// This handles cases where constant attributes have a different type than the
402/// target element type. If the input attribute is not an IntegerAttr or already
403/// has the correct type, returns it unchanged.
404static Attribute convertIntegerAttr(Attribute attr, Type expectedType) {
405 if (auto intAttr = mlir::dyn_cast<IntegerAttr>(Val&: attr)) {
406 if (intAttr.getType() != expectedType)
407 return IntegerAttr::get(type: expectedType, value: intAttr.getInt());
408 }
409 return attr;
410}
411
412//===----------------------------------------------------------------------===//
413// CombiningKindAttr
414//===----------------------------------------------------------------------===//
415
namespace mlir {
namespace vector {
namespace detail {
/// Uniqued storage for bitmask-style enum attributes (e.g. CombiningKind).
/// The attribute is keyed solely by the raw enum bits.
struct BitmaskEnumStorage : public AttributeStorage {
  // The key is the raw bit pattern of the enum value.
  using KeyTy = uint64_t;

  BitmaskEnumStorage(KeyTy val) : value(val) {}

  bool operator==(const KeyTy &key) const { return value == key; }

  /// Allocates a new storage instance in the context's allocator.
  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
                                       const KeyTy &key) {
    return new (allocator.allocate<BitmaskEnumStorage>())
        BitmaskEnumStorage(key);
  }

  // Raw enum bits backing the attribute.
  KeyTy value = 0;
};
} // namespace detail
} // namespace vector
} // namespace mlir
437
438//===----------------------------------------------------------------------===//
439// VectorDialect
440//===----------------------------------------------------------------------===//
441
namespace {
/// This class defines the interface for handling inlining with vector dialect
/// operations.
struct VectorInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All vector dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace
454
/// Registers the dialect's attributes, operations, interfaces, and the
/// promised (lazily-registered) external interface implementations.
void VectorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
      >();

  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
      >();

  addInterfaces<VectorInlinerInterface>();

  // Promised interfaces are implemented in separate libraries and attached
  // only when those libraries are linked/loaded.
  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
                            TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
                            YieldOp>();
  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
                            TransferWriteOp>();
  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
  declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
}
477
478/// Materialize a single constant operation from a given attribute value with
479/// the desired resultant type.
480Operation *VectorDialect::materializeConstant(OpBuilder &builder,
481 Attribute value, Type type,
482 Location loc) {
483 if (isa<ub::PoisonAttrInterface>(Val: value))
484 return value.getDialect().materializeConstant(builder, value, type, loc);
485
486 return arith::ConstantOp::materialize(builder, value, type, loc);
487}
488
489IntegerType vector::getVectorSubscriptType(Builder &builder) {
490 return builder.getIntegerType(width: 64);
491}
492
/// Wraps `values` as an i64 array attribute, the canonical encoding for
/// static vector subscripts.
ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
                                         ArrayRef<int64_t> values) {
  return builder.getI64ArrayAttr(values);
}
497
498//===----------------------------------------------------------------------===//
499// MultiDimReductionOp
500//===----------------------------------------------------------------------===//
501
502void vector::MultiDimReductionOp::build(OpBuilder &builder,
503 OperationState &result, Value source,
504 Value acc, ArrayRef<bool> reductionMask,
505 CombiningKind kind) {
506 SmallVector<int64_t> reductionDims;
507 for (const auto &en : llvm::enumerate(First&: reductionMask))
508 if (en.value())
509 reductionDims.push_back(Elt: en.index());
510 build(odsBuilder&: builder, odsState&: result, kind, source, acc, reduction_dims: reductionDims);
511}
512
513OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
514 // Single parallel dim, this is a noop.
515 if (getSourceVectorType().getRank() == 1 && !isReducedDim(d: 0))
516 return getSource();
517 return {};
518}
519
520std::optional<SmallVector<int64_t, 4>>
521MultiDimReductionOp::getShapeForUnroll() {
522 return llvm::to_vector<4>(Range: getSourceVectorType().getShape());
523}
524
525LogicalResult MultiDimReductionOp::verify() {
526 SmallVector<int64_t> targetShape;
527 SmallVector<bool> scalableDims;
528 Type inferredReturnType;
529 auto sourceScalableDims = getSourceVectorType().getScalableDims();
530 for (auto [dimIdx, dimSize] :
531 llvm::enumerate(First: getSourceVectorType().getShape()))
532 if (!llvm::any_of(Range: getReductionDims(),
533 P: [dimIdx = dimIdx](int64_t reductionDimIdx) {
534 return reductionDimIdx == static_cast<int64_t>(dimIdx);
535 })) {
536 targetShape.push_back(Elt: dimSize);
537 scalableDims.push_back(Elt: sourceScalableDims[dimIdx]);
538 }
539 // TODO: update to also allow 0-d vectors when available.
540 if (targetShape.empty())
541 inferredReturnType = getSourceVectorType().getElementType();
542 else
543 inferredReturnType = VectorType::get(
544 shape: targetShape, elementType: getSourceVectorType().getElementType(), scalableDims);
545 if (getType() != inferredReturnType)
546 return emitOpError() << "destination type " << getType()
547 << " is incompatible with source type "
548 << getSourceVectorType();
549
550 return success();
551}
552
553/// Returns the mask type expected by this operation.
554Type MultiDimReductionOp::getExpectedMaskType() {
555 auto vecType = getSourceVectorType();
556 return VectorType::get(shape: vecType.getShape(),
557 elementType: IntegerType::get(context: vecType.getContext(), /*width=*/1),
558 scalableDims: vecType.getScalableDims());
559}
560
561namespace {
562// Only unit dimensions that are being reduced are folded. If the dimension is
563// unit, but not reduced, it is not folded, thereby keeping the output type the
564// same. If not all dimensions which are reduced are of unit dimension, this
565// transformation does nothing. This is just a generalization of
566// ElideSingleElementReduction for ReduceOp.
567struct ElideUnitDimsInMultiDimReduction
568 : public OpRewritePattern<MultiDimReductionOp> {
569 using OpRewritePattern::OpRewritePattern;
570
571 LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
572 PatternRewriter &rewriter) const override {
573 ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
574 for (const auto &dim : enumerate(First&: shape)) {
575 if (reductionOp.isReducedDim(d: dim.index()) && dim.value() != 1)
576 return failure();
577 }
578
579 // Vector mask setup.
580 OpBuilder::InsertionGuard guard(rewriter);
581 Operation *rootOp;
582 Value mask;
583 if (reductionOp.isMasked()) {
584 rewriter.setInsertionPoint(reductionOp.getMaskingOp());
585 rootOp = reductionOp.getMaskingOp();
586 mask = reductionOp.getMaskingOp().getMask();
587 } else {
588 rootOp = reductionOp;
589 }
590
591 Location loc = reductionOp.getLoc();
592 Value acc = reductionOp.getAcc();
593 Value cast;
594 if (auto dstVecType = dyn_cast<VectorType>(Val: reductionOp.getDestType())) {
595 if (mask) {
596 VectorType newMaskType =
597 VectorType::get(shape: dstVecType.getShape(), elementType: rewriter.getI1Type(),
598 scalableDims: dstVecType.getScalableDims());
599 mask = rewriter.create<vector::ShapeCastOp>(location: loc, args&: newMaskType, args&: mask);
600 }
601 cast = rewriter.create<vector::ShapeCastOp>(
602 location: loc, args: reductionOp.getDestType(), args: reductionOp.getSource());
603 } else {
604 // This means we are reducing all the dimensions, and all reduction
605 // dimensions are of size 1. So a simple extraction would do.
606 if (mask)
607 mask = rewriter.create<vector::ExtractOp>(location: loc, args&: mask);
608 cast = rewriter.create<vector::ExtractOp>(location: loc, args: reductionOp.getSource());
609 }
610
611 Value result =
612 vector::makeArithReduction(b&: rewriter, loc, kind: reductionOp.getKind(), v1: acc,
613 acc: cast, /*fastmath=*/nullptr, mask);
614 rewriter.replaceOp(op: rootOp, newValues: result);
615 return success();
616 }
617};
618} // namespace
619
620void MultiDimReductionOp::getCanonicalizationPatterns(
621 RewritePatternSet &results, MLIRContext *context) {
622 results.add<ElideUnitDimsInMultiDimReduction>(arg&: context);
623}
624
625//===----------------------------------------------------------------------===//
626// ReductionOp
627//===----------------------------------------------------------------------===//
628
629void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
630 CombiningKind kind, Value vector,
631 arith::FastMathFlags fastMathFlags) {
632 build(odsBuilder&: builder, odsState&: result, kind, vector, /*acc=*/Value(), fastMathFlags);
633}
634
635void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
636 CombiningKind kind, Value vector, Value acc,
637 arith::FastMathFlags fastMathFlags) {
638 build(odsBuilder&: builder, odsState&: result,
639 dest: llvm::cast<VectorType>(Val: vector.getType()).getElementType(), kind, vector,
640 acc, fastmath: fastMathFlags);
641}
642
643LogicalResult ReductionOp::verify() {
644 // Verify for 0-D and 1-D vector.
645 int64_t rank = getSourceVectorType().getRank();
646 if (rank > 1)
647 return emitOpError(message: "unsupported reduction rank: ") << rank;
648
649 // Verify supported reduction kind.
650 Type eltType = getDest().getType();
651 if (!isSupportedCombiningKind(combiningKind: getKind(), elementType: eltType))
652 return emitOpError(message: "unsupported reduction type '")
653 << eltType << "' for kind '" << stringifyCombiningKind(val: getKind())
654 << "'";
655
656 return success();
657}
658
659// MaskableOpInterface methods.
660
661/// Returns the mask type expected by this operation.
662Type ReductionOp::getExpectedMaskType() {
663 auto vecType = getSourceVectorType();
664 return VectorType::get(shape: vecType.getShape(),
665 elementType: IntegerType::get(context: vecType.getContext(), /*width=*/1),
666 scalableDims: vecType.getScalableDims());
667}
668
669Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
670 OpBuilder &builder, Location loc,
671 Value vector) {
672 switch (op) {
673 case arith::AtomicRMWKind::addf:
674 case arith::AtomicRMWKind::addi:
675 return builder.create<vector::ReductionOp>(location: vector.getLoc(),
676 args: CombiningKind::ADD, args&: vector);
677 case arith::AtomicRMWKind::mulf:
678 case arith::AtomicRMWKind::muli:
679 return builder.create<vector::ReductionOp>(location: vector.getLoc(),
680 args: CombiningKind::MUL, args&: vector);
681 case arith::AtomicRMWKind::minimumf:
682 return builder.create<vector::ReductionOp>(location: vector.getLoc(),
683 args: CombiningKind::MINIMUMF, args&: vector);
684 case arith::AtomicRMWKind::mins:
685 return builder.create<vector::ReductionOp>(location: vector.getLoc(),
686 args: CombiningKind::MINSI, args&: vector);
687 case arith::AtomicRMWKind::minu:
688 return builder.create<vector::ReductionOp>(location: vector.getLoc(),
689 args: CombiningKind::MINUI, args&: vector);
690 case arith::AtomicRMWKind::maximumf:
691 return builder.create<vector::ReductionOp>(location: vector.getLoc(),
692 args: CombiningKind::MAXIMUMF, args&: vector);
693 case arith::AtomicRMWKind::maxs:
694 return builder.create<vector::ReductionOp>(location: vector.getLoc(),
695 args: CombiningKind::MAXSI, args&: vector);
696 case arith::AtomicRMWKind::maxu:
697 return builder.create<vector::ReductionOp>(location: vector.getLoc(),
698 args: CombiningKind::MAXUI, args&: vector);
699 case arith::AtomicRMWKind::andi:
700 return builder.create<vector::ReductionOp>(location: vector.getLoc(),
701 args: CombiningKind::AND, args&: vector);
702 case arith::AtomicRMWKind::ori:
703 return builder.create<vector::ReductionOp>(location: vector.getLoc(),
704 args: CombiningKind::OR, args&: vector);
705 // TODO: Add remaining reduction operations.
706 default:
707 (void)emitOptionalError(loc, args: "Reduction operation type not supported");
708 break;
709 }
710 return nullptr;
711}
712
713std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
714 return llvm::to_vector<4>(Range: getSourceVectorType().getShape());
715}
716
717namespace {
718struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
719 using OpRewritePattern::OpRewritePattern;
720
721 LogicalResult matchAndRewrite(ReductionOp reductionOp,
722 PatternRewriter &rewriter) const override {
723 // Vector mask setup.
724 OpBuilder::InsertionGuard guard(rewriter);
725 auto maskableOp =
726 cast<vector::MaskableOpInterface>(Val: reductionOp.getOperation());
727 Operation *rootOp;
728 Value mask;
729 if (maskableOp.isMasked()) {
730 rewriter.setInsertionPoint(maskableOp.getMaskingOp());
731 rootOp = maskableOp.getMaskingOp();
732 mask = maskableOp.getMaskingOp().getMask();
733 } else {
734 rootOp = reductionOp;
735 }
736
737 auto vectorType = reductionOp.getSourceVectorType();
738 if (vectorType.getRank() != 0 && vectorType.getDimSize(idx: 0) != 1)
739 return failure();
740
741 Location loc = reductionOp.getLoc();
742 if (mask)
743 mask = rewriter.create<ExtractOp>(location: loc, args&: mask);
744 Value result = rewriter.create<ExtractOp>(location: loc, args: reductionOp.getVector());
745
746 if (Value acc = reductionOp.getAcc())
747 result = vector::makeArithReduction(b&: rewriter, loc, kind: reductionOp.getKind(),
748 v1: result, acc,
749 fastmath: reductionOp.getFastmathAttr(), mask);
750
751 rewriter.replaceOp(op: rootOp, newValues: result);
752 return success();
753 }
754};
755} // namespace
756
757void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
758 MLIRContext *context) {
759 results.add<ElideSingleElementReduction>(arg&: context);
760}
761
762//===----------------------------------------------------------------------===//
763// ContractionOp
764//===----------------------------------------------------------------------===//
765
766void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
767 Value lhs, Value rhs, Value acc,
768 ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
769 ArrayRef<IteratorType> iteratorTypes) {
770 result.addOperands(newOperands: {lhs, rhs, acc});
771 result.addTypes(newTypes: acc.getType());
772 result.addAttribute(
773 name: getIndexingMapsAttrName(name: result.name),
774 attr: builder.getAffineMapArrayAttr(
775 values: AffineMap::inferFromExprList(exprsList: indexingExprs, context: builder.getContext())));
776 result.addAttribute(
777 name: getIteratorTypesAttrName(name: result.name),
778 attr: builder.getArrayAttr(value: llvm::to_vector(Range: llvm::map_range(
779 C&: iteratorTypes, F: [&](IteratorType t) -> mlir::Attribute {
780 return IteratorTypeAttr::get(context: builder.getContext(), value: t);
781 }))));
782}
783
784void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
785 Value lhs, Value rhs, Value acc,
786 ArrayAttr indexingMaps,
787 ArrayAttr iteratorTypes) {
788 build(odsBuilder&: builder, odsState&: result, lhs, rhs, acc, indexingMaps, iteratorTypes,
789 kind: ContractionOp::getDefaultKind());
790}
791
792void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
793 Value lhs, Value rhs, Value acc,
794 ArrayAttr indexingMaps,
795 ArrayAttr iteratorTypes, CombiningKind kind) {
796 result.addOperands(newOperands: {lhs, rhs, acc});
797 result.addTypes(newTypes: acc.getType());
798 result.addAttribute(name: getIndexingMapsAttrName(name: result.name), attr: indexingMaps);
799 result.addAttribute(name: getIteratorTypesAttrName(name: result.name), attr: iteratorTypes);
800 result.addAttribute(name: getKindAttrName(name: result.name),
801 attr: CombiningKindAttr::get(context: builder.getContext(), value: kind));
802}
803
/// Parses vector.contract in its custom form:
///   `vector.contract` dict-attr lhs `,` rhs `,` acc (`,` masks)?
///     attr-dict `:` lhs-type `,` rhs-type `into` result-type
ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand lhsInfo;
  OpAsmParser::UnresolvedOperand rhsInfo;
  OpAsmParser::UnresolvedOperand accInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
  SmallVector<Type, 2> types;
  Type resultType;
  auto loc = parser.getCurrentLocation();
  DictionaryAttr dictAttr;
  // TODO: Unify linalg op attribute parsing.
  if (parser.parseAttribute(result&: dictAttr) || parser.parseOperand(result&: lhsInfo) ||
      parser.parseComma() || parser.parseOperand(result&: rhsInfo) ||
      parser.parseComma() || parser.parseOperand(result&: accInfo) ||
      parser.parseTrailingOperandList(result&: masksInfo) ||
      parser.parseOptionalAttrDict(result&: result.attributes) ||
      parser.parseColonTypeList(result&: types) ||
      parser.parseKeywordType(keyword: "into", result&: resultType) ||
      parser.resolveOperand(operand: lhsInfo, type: types[0], result&: result.operands) ||
      parser.resolveOperand(operand: rhsInfo, type: types[1], result&: result.operands) ||
      parser.resolveOperand(operand: accInfo, type: resultType, result&: result.operands) ||
      parser.addTypeToList(type: resultType, result&: result.types))
    return failure();
  // Splice the leading dictionary's entries into the op's attribute list.
  result.attributes.append(inStart: dictAttr.getValue().begin(),
                           inEnd: dictAttr.getValue().end());

  // Convert array of string into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      Val: result.attributes.get(name: getIteratorTypesAttrName(name: result.name)));
  if (!iteratorTypes) {
    return parser.emitError(loc)
           << "expected " << getIteratorTypesAttrName(name: result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = symbolizeIteratorType(str: s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        Elt: IteratorTypeAttr::get(context: parser.getContext(), value: maybeIteratorType.value()));
  }
  // Overwrite the string form with the enum-attribute form.
  result.attributes.set(name: getIteratorTypesAttrName(name: result.name),
                        value: parser.getBuilder().getArrayAttr(value: iteratorTypeAttrs));

  // The combining kind is optional in the assembly; default it when absent.
  if (!result.attributes.get(name: getKindAttrName(name: result.name))) {
    result.addAttribute(
        name: getKindAttrName(name: result.name),
        attr: CombiningKindAttr::get(context: result.getContext(),
                                value: ContractionOp::getDefaultKind()));
  }
  if (masksInfo.empty())
    return success();
  if (masksInfo.size() != 2)
    return parser.emitError(loc: parser.getNameLoc(),
                            message: "expected zero or exactly 2 vector mask operands");
  // Mask operands mirror the lhs/rhs shapes with an i1 element type.
  auto lhsType = llvm::cast<VectorType>(Val&: types[0]);
  auto rhsType = llvm::cast<VectorType>(Val&: types[1]);
  auto maskElementType = parser.getBuilder().getI1Type();
  std::array<VectorType, 2> maskTypes = {
      VectorType::Builder(lhsType).setElementType(maskElementType),
      VectorType::Builder(rhsType).setElementType(maskElementType)};
  if (parser.resolveOperands(operands&: masksInfo, types&: maskTypes, loc, result&: result.operands))
    return failure();
  return success();
}
875
/// Prints the op in the custom form expected by the parser above: trait
/// attributes are folded into a leading dictionary attribute, everything else
/// goes into the trailing attr-dict.
void ContractionOp::print(OpAsmPrinter &p) {
  // TODO: Unify printing code with linalg ops.
  auto attrNames = getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert_range(R&: attrNames);
  SmallVector<NamedAttribute, 8> attrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(Val: attr.getValue())
              .getAsValueRange<IteratorTypeAttr, IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
          Range: llvm::map_range(C&: iteratorTypes, F: [&](IteratorType t) -> Attribute {
            return StringAttr::get(context: getContext(), bytes: stringifyIteratorType(val: t));
          }));

      attrs.emplace_back(Args: getIteratorTypesAttrName(),
                         Args: ArrayAttr::get(context: getContext(), value: iteratorTypeNames));
    } else if (traitAttrsSet.count(Key: attr.getName().strref()) > 0)
      // Only trait attributes are collected into the leading dictionary.
      attrs.push_back(Elt: attr);
  }

  auto dictAttr = DictionaryAttr::get(context: getContext(), value: attrs);
  p << " " << dictAttr << " " << getLhs() << ", ";
  p << getRhs() << ", " << getAcc();

  // Trait attributes were already printed in the dictionary; elide them here.
  p.printOptionalAttrDict(attrs: (*this)->getAttrs(), elidedAttrs: attrNames);
  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
    << getResultType();
}
910
911static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
912 const std::vector<std::pair<int64_t, int64_t>> &map) {
913 for (auto &dimPair : map) {
914 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
915 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
916 lhsType.getDimSize(idx: dimPair.first) != rhsType.getDimSize(idx: dimPair.second))
917 return false;
918 }
919 return true;
920}
921
/// Verifies that `accType` and `resType` match the shape implied by the
/// lhs/rhs vector types together with the contracting and batch dimension
/// maps of `op`.
static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(V: dimPair.first);
    rhsContractingDimSet.insert(V: dimPair.second);
  }
  // Batch dims are counted once (on the lhs side); track rhs ones to skip.
  DenseSet<int64_t> rhsBatchDimSet(llvm::from_range,
                                   llvm::make_second_range(c: batchDimMap));

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(V: i) > 0)
      continue;
    expectedResultDims.push_back(Elt: lhsType.getDimSize(idx: i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(V: i) > 0 || rhsBatchDimSet.count(V: i) > 0)
      continue;
    expectedResultDims.push_back(Elt: rhsType.getDimSize(idx: i));
  }

  // Verify 'expectedResultDims'.
  if (expectedResultDims.empty()) {
    // No batch or free dimension implies a scalar result.
    if (llvm::isa<VectorType>(Val: resType) || llvm::isa<VectorType>(Val: accType))
      return op.emitOpError(message: "invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = llvm::dyn_cast<VectorType>(Val&: resType);
    auto accVectorType = llvm::dyn_cast<VectorType>(Val&: accType);
    if (!resVectorType || !accVectorType)
      return op.emitOpError(message: "invalid accumulator/result vector shape");

    // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
    // types fully define the result vector type. This assumes the affine maps
    // are well-formed, which must have been verified already.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMapsArray()[0];
    AffineMap rhsMap = op.getIndexingMapsArray()[1];
    if (getUnusedDimsBitVector(maps: {lhsMap, rhsMap}).any())
      return op.emitOpError(
          message: "expected all dimensions to be either a LHS or a RHS dimension");
    // Collect the constant extent of each iteration dimension from whichever
    // operand first provides it.
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(x&: lhsType, y&: lhsMap), std::make_pair(x&: rhsType, y&: rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(constant: v.getShape()[idx], context: ctx);
      }
    }
    if (!llvm::all_of(Range&: extents, P: [](AffineExpr e) { return e; }))
      return op.emitOpError(message: "expected all dimensions to get an extent as "
                            "either a LHS or a RHS dimension");

    AffineMap resMap = op.getIndexingMapsArray()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symbolCount=*/0, results: extents, context: ctx);
    // Compose the resMap with the extentsMap, which is a constant map.
    AffineMap expectedMap = simplifyAffineMap(map: resMap.compose(map: extentsMap));
    assert(llvm::all_of(expectedMap.getResults(),
                        llvm::IsaPred<AffineConstantExpr>) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape = llvm::to_vector<4>(
        Range: llvm::map_range(C: expectedMap.getResults(), F: [](AffineExpr e) {
          return cast<AffineConstantExpr>(Val&: e).getValue();
        }));
    auto expected =
        VectorType::get(shape: expectedShape, elementType: resVectorType.getElementType(),
                        scalableDims: resVectorType.getScalableDims());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 message: "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}
1010
/// Verifies the structural invariants of vector.contract: element types,
/// indexing maps, dimension maps, output shape, and the combining kind.
LogicalResult ContractionOp::verify() {
  VectorType lhsType = getLhsType();
  VectorType rhsType = getRhsType();
  Type accType = getAccType();
  Type resType = getResultType();

  // Only signless integers are supported as integer element types.
  if (llvm::isa<IntegerType>(Val: lhsType.getElementType())) {
    if (!lhsType.getElementType().isSignlessInteger())
      return emitOpError(message: "only supports signless integer types");
  }

  // Verify that an indexing map was specified for each vector operand.
  if (getIndexingMapsArray().size() != 3)
    return emitOpError(message: "expected an indexing map for each vector operand");

  // Verify that each index map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated
  // vector operand.
  unsigned numIterators = getIteratorTypes().getValue().size();
  for (const auto &it : llvm::enumerate(First: getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    if (map.getNumSymbols() != 0)
      return emitOpError(message: "expected indexing map ")
             << index << " to have no symbols";
    auto vectorType = llvm::dyn_cast<VectorType>(Val: getOperand(i: index).getType());
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    // Verify that the map has the right number of inputs, outputs, and indices.
    // This also correctly accounts for (..) -> () for rank-0 results.
    if (map.getNumDims() != numIterators)
      return emitOpError(message: "expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return emitOpError(message: "expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return emitOpError(message: "expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = getContractingDimMap();
  auto batchDimMap = getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return emitOpError(message: "expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, map: contractingDimMap))
    return emitOpError(message: "invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, map: batchDimMap))
    return emitOpError(message: "invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(Result: verifyOutputShape(op: *this, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  // Verify supported combining kind.
  auto vectorType = llvm::dyn_cast<VectorType>(Val&: resType);
  auto elementType = vectorType ? vectorType.getElementType() : resType;
  if (!isSupportedCombiningKind(combiningKind: getKind(), elementType))
    return emitOpError(message: "unsupported contraction type");

  // Delayed calling of IndexingMapOpInterface::verifyImpl.
  return cast<IndexingMapOpInterface>(Val: this->getOperation()).verifyImpl();
}
1080
1081// MaskableOpInterface methods.
1082
1083/// Returns the mask type expected by this operation. Mostly used for
1084/// verification purposes. It requires the operation to be vectorized."
1085Type ContractionOp::getExpectedMaskType() {
1086 auto indexingMaps = this->getIndexingMapsArray();
1087 AffineMap lhsIdxMap = indexingMaps[0];
1088 AffineMap rhsIdxMap = indexingMaps[1];
1089 VectorType lhsType = this->getLhsType();
1090 VectorType rhsType = this->getRhsType();
1091
1092 unsigned numVecDims = lhsIdxMap.getNumDims();
1093 SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
1094 SmallVector<bool> maskShapeScalableDims(numVecDims, false);
1095
1096 // Using the information in the indexing maps, extract the size of each
1097 // dimension in the vector.contract operation from the two input operands.
1098 for (auto [dimIdx, dimSize] : llvm::enumerate(First: lhsType.getShape())) {
1099 maskShape[lhsIdxMap.getDimPosition(idx: dimIdx)] = dimSize;
1100 maskShapeScalableDims[lhsIdxMap.getDimPosition(idx: dimIdx)] =
1101 lhsType.getScalableDims()[dimIdx];
1102 }
1103 for (auto [dimIdx, dimSize] : llvm::enumerate(First: rhsType.getShape())) {
1104 maskShape[rhsIdxMap.getDimPosition(idx: dimIdx)] = dimSize;
1105 maskShapeScalableDims[rhsIdxMap.getDimPosition(idx: dimIdx)] =
1106 rhsType.getScalableDims()[dimIdx];
1107 }
1108
1109 assert(ShapedType::isStaticShape(maskShape) &&
1110 "Mask shape couldn't be computed");
1111
1112 return VectorType::get(shape: maskShape,
1113 elementType: IntegerType::get(context: lhsType.getContext(), /*width=*/1),
1114 scalableDims: maskShapeScalableDims);
1115}
1116
1117SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
1118 return SmallVector<StringRef>{getIndexingMapsAttrName(),
1119 getIteratorTypesAttrName(), getKindAttrName()};
1120}
1121
1122static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
1123 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
1124 if (targetExpr == map.getResult(idx: i))
1125 return i;
1126 return -1;
1127}
1128
1129static std::vector<std::pair<int64_t, int64_t>>
1130getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
1131 IteratorType targetIteratorType, MLIRContext *context) {
1132 std::vector<std::pair<int64_t, int64_t>> dimMap;
1133 for (const auto &it : llvm::enumerate(First&: iteratorTypes)) {
1134 auto iteratorType = llvm::cast<IteratorTypeAttr>(Val: it.value()).getValue();
1135 if (iteratorType != targetIteratorType)
1136 continue;
1137 // Search lhs/rhs map results for 'targetExpr'.
1138 auto targetExpr = getAffineDimExpr(position: it.index(), context);
1139 int64_t lhsDim = getResultIndex(map: indexingMaps[0], targetExpr);
1140 int64_t rhsDim = getResultIndex(map: indexingMaps[1], targetExpr);
1141 if (lhsDim >= 0 && rhsDim >= 0)
1142 dimMap.emplace_back(args&: lhsDim, args&: rhsDim);
1143 }
1144 return dimMap;
1145}
1146
/// Populates `iterationBounds` with the trip count of every iteration
/// dimension, in iterator order: reduction sizes are read off the lhs shape,
/// parallel sizes off the result vector shape.
void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = llvm::dyn_cast<VectorType>(Val: getResultType());
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  for (const auto &it : llvm::enumerate(First: getIteratorTypes())) {
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(position: it.index(), context: getContext());
    auto iteratorType = llvm::cast<IteratorTypeAttr>(Val: it.value()).getValue();
    if (iteratorType == IteratorType::reduction) {
      // Get reduction dim size from lhs shape (same size in rhsShape).
      int64_t lhsDimIndex = getResultIndex(map: indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(Elt: lhsShape[lhsDimIndex]);
      continue;
    }
    // Get parallel dimension size from result shape.
    int64_t resDimIndex = getResultIndex(map: indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    // A parallel iterator implies a vector result (see verifyOutputShape).
    assert(resVectorType != nullptr);
    iterationBounds.push_back(Elt: resVectorType.getShape()[resDimIndex]);
  }
}
1170
1171void ContractionOp::getIterationIndexMap(
1172 std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
1173 unsigned numMaps = getIndexingMapsArray().size();
1174 iterationIndexMap.resize(new_size: numMaps);
1175 for (const auto &it : llvm::enumerate(First: getIndexingMapsArray())) {
1176 auto index = it.index();
1177 auto map = it.value();
1178 for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1179 auto dim = cast<AffineDimExpr>(Val: map.getResult(idx: i));
1180 iterationIndexMap[index][dim.getPosition()] = i;
1181 }
1182 }
1183}
1184
1185std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1186 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1187 return getDimMap(indexingMaps, iteratorTypes: getIteratorTypes(), targetIteratorType: IteratorType::reduction,
1188 context: getContext());
1189}
1190
1191std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1192 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1193 return getDimMap(indexingMaps, iteratorTypes: getIteratorTypes(), targetIteratorType: IteratorType::parallel,
1194 context: getContext());
1195}
1196
1197std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1198 SmallVector<int64_t, 4> shape;
1199 getIterationBounds(iterationBounds&: shape);
1200 return shape;
1201}
1202
1203/// Return a fused vector::ContractionOp which represents a patterns such as:
1204///
1205/// ```mlir
1206/// %c0 = vector.constant 0: ...
1207/// %c = vector.contract %a, %b, %c0: ...
1208/// %e = add %c, %d: ...
1209/// ```
1210///
1211/// by:
1212///
1213/// ```mlir
1214/// %e = vector.contract %a, %b, %d: ...
1215/// ```
1216///
1217/// Return null if the canonicalization does not apply.
1218// TODO: This should be a folding of Add into Contract in core but while they
1219// live in different dialects, it is not possible without unnatural
1220// dependencies.
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
  using OpRewritePattern<AddOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddOpType addOp,
                                PatternRewriter &rewriter) const override {
    // If `maybeContraction` is a vector.contract whose accumulator is a
    // constant zero, clone it with `otherOperand` substituted as the
    // accumulator and replace the add. Returns a null op when the pattern
    // does not apply.
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
              Val: maybeContraction.getDefiningOp());
      if (!contractionOp)
        return vector::ContractionOp();
      if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
              Val: contractionOp.getAcc().getDefiningOp())) {
        if (maybeZero.getValue() ==
            rewriter.getZeroAttr(type: contractionOp.getAcc().getType())) {
          IRMapping bvm;
          // Remap the zero accumulator to the add's other operand.
          bvm.map(from: contractionOp.getAcc(), to: otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(Val: rewriter.clone(op&: *contractionOp, mapper&: bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;
        }
      }
      return vector::ContractionOp();
    };

    // Addition is commutative: try the contraction on either side of the add.
    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
    contract = contract ? contract : canonicalize(b, a);
    return contract ? success() : failure();
  }
};
1255
1256void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1257 MLIRContext *context) {
1258 results.add<CanonicalizeContractAdd<arith::AddIOp>,
1259 CanonicalizeContractAdd<arith::AddFOp>>(arg&: context);
1260}
1261
1262//===----------------------------------------------------------------------===//
1263// ExtractElementOp
1264//===----------------------------------------------------------------------===//
1265
1266void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1267 SetIntRangeFn setResultRanges) {
1268 setResultRanges(getResult(), argRanges.front());
1269}
1270
1271void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
1272 Value source) {
1273 result.addOperands(newOperands: {source});
1274 result.addTypes(newTypes: llvm::cast<VectorType>(Val: source.getType()).getElementType());
1275}
1276
1277LogicalResult vector::ExtractElementOp::verify() {
1278 VectorType vectorType = getSourceVectorType();
1279 if (vectorType.getRank() == 0) {
1280 if (getPosition())
1281 return emitOpError(message: "expected position to be empty with 0-D vector");
1282 return success();
1283 }
1284 if (vectorType.getRank() != 1)
1285 return emitOpError(message: "unexpected >1 vector rank");
1286 if (!getPosition())
1287 return emitOpError(message: "expected position for 1-D vector");
1288 return success();
1289}
1290
/// Folds vector.extractelement fed by a splat, a scalar broadcast, or a
/// constant vector with a constant in-bounds position.
OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
  // Skip the 0-D vector here now.
  if (!adaptor.getPosition())
    return {};

  // Fold extractelement (splat X) -> X.
  if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
    return splat.getInput();

  // Fold extractelement(broadcast(X)) -> X.
  if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
    if (!llvm::isa<VectorType>(Val: broadcast.getSource().getType()))
      return broadcast.getSource();

  // Constant fold: both the source vector and the position must be constant.
  auto src = dyn_cast_or_null<DenseElementsAttr>(Val: adaptor.getVector());
  auto pos = dyn_cast_or_null<IntegerAttr>(Val: adaptor.getPosition());
  if (!pos || !src)
    return {};

  auto srcElements = src.getValues<Attribute>();

  // Guard against out-of-range positions (e.g. from invalid IR).
  uint64_t posIdx = pos.getInt();
  if (posIdx >= srcElements.size())
    return {};

  return srcElements[posIdx];
}
1318
/// Returns true when `index` lies in [0, maxIndex) or equals `poisonValue`.
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
                                         int64_t maxIndex) {
  bool inBounds = index >= 0 && index < maxIndex;
  return inBounds || index == poisonValue;
}
1325
1326//===----------------------------------------------------------------------===//
1327// ExtractOp
1328//===----------------------------------------------------------------------===//
1329
1330void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1331 SetIntRangeFn setResultRanges) {
1332 setResultRanges(getResult(), argRanges.front());
1333}
1334
1335void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1336 Value source) {
1337 auto vectorTy = cast<VectorType>(Val: source.getType());
1338 build(odsBuilder&: builder, odsState&: result, source, position: SmallVector<int64_t>(vectorTy.getRank(), 0));
1339}
1340
1341void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1342 Value source, int64_t position) {
1343 build(odsBuilder&: builder, odsState&: result, source, position: ArrayRef<int64_t>{position});
1344}
1345
1346void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1347 Value source, OpFoldResult position) {
1348 build(odsBuilder&: builder, odsState&: result, source, position: ArrayRef<OpFoldResult>{position});
1349}
1350
1351void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1352 Value source, ArrayRef<int64_t> position) {
1353 build(odsBuilder&: builder, odsState&: result, vector: source, /*dynamic_position=*/ArrayRef<Value>(),
1354 static_position: builder.getDenseI64ArrayAttr(values: position));
1355}
1356
1357void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1358 Value source, ArrayRef<OpFoldResult> position) {
1359 SmallVector<int64_t> staticPos;
1360 SmallVector<Value> dynamicPos;
1361 dispatchIndexOpFoldResults(ofrs: position, dynamicVec&: dynamicPos, staticVec&: staticPos);
1362 build(odsBuilder&: builder, odsState&: result, vector: source, dynamic_position: dynamicPos,
1363 static_position: builder.getDenseI64ArrayAttr(values: staticPos));
1364}
1365
1366LogicalResult
1367ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
1368 ExtractOp::Adaptor adaptor,
1369 SmallVectorImpl<Type> &inferredReturnTypes) {
1370 auto vectorType = llvm::cast<VectorType>(Val: adaptor.getVector().getType());
1371 if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1372 vectorType.getRank()) {
1373 inferredReturnTypes.push_back(Elt: vectorType.getElementType());
1374 } else {
1375 auto n = std::min<size_t>(a: adaptor.getStaticPosition().size(),
1376 b: vectorType.getRank());
1377 inferredReturnTypes.push_back(Elt: VectorType::get(
1378 shape: vectorType.getShape().drop_front(N: n), elementType: vectorType.getElementType(),
1379 scalableDims: vectorType.getScalableDims().drop_front(N: n)));
1380 }
1381 return success();
1382}
1383
1384bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1385 // Allow extracting 1-element vectors instead of scalars.
1386 auto isCompatible = [](TypeRange l, TypeRange r) {
1387 auto vectorType = llvm::dyn_cast<VectorType>(Val: l.front());
1388 return vectorType && vectorType.getShape().equals(RHS: {1}) &&
1389 vectorType.getElementType() == r.front();
1390 };
1391 if (l.size() == 1 && r.size() == 1 &&
1392 (isCompatible(l, r) || isCompatible(r, l)))
1393 return true;
1394 return l == r;
1395}
1396
/// Verifies the static/dynamic position invariants of vector.extract.
LogicalResult vector::ExtractOp::verify() {
  // 0-d vector results are disallowed; a full extraction yields a scalar.
  if (auto resTy = dyn_cast<VectorType>(Val: getResult().getType()))
    if (resTy.getRank() == 0)
      return emitError(
          message: "expected a scalar instead of a 0-d vector as the result type");

  // Note: This check must come before getMixedPosition() to prevent a crash.
  auto dynamicMarkersCount =
      llvm::count_if(Range: getStaticPosition(), P: ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
    return emitOpError(
        message: "mismatch between dynamic and static positions (kDynamic marker but no "
        "corresponding dynamic position) -- this can only happen due to an "
        "incorrect fold/rewrite");
  auto position = getMixedPosition();
  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
    return emitOpError(
        message: "expected position attribute of rank no greater than vector rank");
  // Static positions must be in-bounds for their dimension or be the poison
  // marker (-1); dynamic positions cannot be checked here.
  for (auto [idx, pos] : llvm::enumerate(First&: position)) {
    if (auto attr = dyn_cast<Attribute>(Val&: pos)) {
      int64_t constIdx = cast<IntegerAttr>(Val&: attr).getInt();
      if (!isValidPositiveIndexOrPoison(
              index: constIdx, poisonValue: kPoisonIndex, maxIndex: getSourceVectorType().getDimSize(idx))) {
        return emitOpError(message: "expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding vector dimension or poison (-1)";
      }
    }
  }
  return success();
}
1429
1430template <typename IntType>
1431static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
1432 return llvm::to_vector<4>(llvm::map_range(
1433 arrayAttr.getAsRange<IntegerAttr>(),
1434 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1435}
1436
1437/// Fold the result of chains of ExtractOp in place by simply concatenating the
1438/// positions.
1439static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1440 if (!extractOp.getVector().getDefiningOp<ExtractOp>())
1441 return failure();
1442
1443 // TODO: Canonicalization for dynamic position not implemented yet.
1444 if (extractOp.hasDynamicPosition())
1445 return failure();
1446
1447 SmallVector<int64_t> globalPosition;
1448 ExtractOp currentOp = extractOp;
1449 ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1450 globalPosition.append(in_start: extrPos.rbegin(), in_end: extrPos.rend());
1451 while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
1452 currentOp = nextOp;
1453 // TODO: Canonicalization for dynamic position not implemented yet.
1454 if (currentOp.hasDynamicPosition())
1455 return failure();
1456 ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1457 globalPosition.append(in_start: extrPos.rbegin(), in_end: extrPos.rend());
1458 }
1459 extractOp.setOperand(i: 0, value: currentOp.getVector());
1460 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1461 OpBuilder b(extractOp.getContext());
1462 std::reverse(first: globalPosition.begin(), last: globalPosition.end());
1463 extractOp.setStaticPosition(globalPosition);
1464 return success();
1465}
1466
1467namespace {
1468/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1469/// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1470/// Compose TransposeOp permutations as we walk back.
1471/// This helper class keeps an updated extraction position `extractPosition`
1472/// with extra trailing sentinels.
1473/// The sentinels encode the internal transposition status of the result vector.
1474/// As we iterate, extractPosition is permuted and updated.
1475class ExtractFromInsertTransposeChainState {
1476public:
1477 ExtractFromInsertTransposeChainState(ExtractOp e);
1478
1479 /// Iterate over producing insert and transpose ops until we find a fold.
1480 Value fold();
1481
1482private:
1483 /// Return true if the vector at position `a` is contained within the vector
1484 /// at position `b`. Under insert/extract semantics, this is the same as `a`
1485 /// is a prefix of `b`.
1486 template <typename ContainerA, typename ContainerB>
1487 bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1488 return a.size() <= b.size() &&
1489 std::equal(a.begin(), a.begin() + a.size(), b.begin());
1490 }
1491
1492 /// Return true if the vector at position `a` intersects the vector at
1493 /// position `b`. Under insert/extract semantics, this is the same as equality
1494 /// of all entries of `a` that are >=0 with the corresponding entries of b.
1495 /// Comparison is on the common prefix (i.e. zip).
1496 template <typename ContainerA, typename ContainerB>
1497 bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1498 for (auto [elemA, elemB] : llvm::zip(a, b)) {
1499 if (elemA < 0 || elemB < 0)
1500 continue;
1501 if (elemA != elemB)
1502 return false;
1503 }
1504 return true;
1505 }
1506
1507 /// Folding is only possible in the absence of an internal permutation in the
1508 /// result vector.
1509 bool canFold() {
1510 return (sentinels == ArrayRef(extractPosition).drop_front(N: extractedRank));
1511 }
1512
1513 // Helper to get the next defining op of interest.
1514 void updateStateForNextIteration(Value v) {
1515 nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1516 nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1517 };
1518
1519 // Case 1. If we hit a transpose, just compose the map and iterate.
1520 // Invariant: insert + transpose do not change rank, we can always compose.
1521 LogicalResult handleTransposeOp();
1522
1523 // Case 2: the insert position matches extractPosition exactly, early return.
1524 LogicalResult handleInsertOpWithMatchingPos(Value &res);
1525
1526 /// Case 3: if the insert position is a prefix of extractPosition, extract a
1527 /// portion of the source of the insert.
  /// Example:
  /// ```
  /// %ins = vector.insert %source, %dest[1] : vector<3x4x5> into vector<2x3x4x5>
  /// // extractPosition == [1, 2, 3]
  /// %ext = vector.extract %ins[1, 2, 3] : vector<5> from vector<2x3x4x5>
  /// // the insert position [1] is a prefix of [1, 2, 3], so this folds to
  /// %ext = vector.extract %source[2, 3] : vector<5> from vector<3x4x5>
  /// ```
1536 /// To traverse through %source, we need to set the leading dims to 0 and
1537 /// drop the extra leading dims.
1538 /// This method updates the internal state.
1539 LogicalResult handleInsertOpWithPrefixPos(Value &res);
1540
1541 /// Try to fold in place to extract(source, extractPosition) and return the
1542 /// folded result. Return null if folding is not possible (e.g. due to an
1543 /// internal transposition in the result).
1544 Value tryToFoldExtractOpInPlace(Value source);
1545
1546 ExtractOp extractOp;
1547 int64_t vectorRank;
1548 int64_t extractedRank;
1549
1550 InsertOp nextInsertOp;
1551 TransposeOp nextTransposeOp;
1552
1553 /// Sentinel values that encode the internal permutation status of the result.
1554 /// They are set to (-1, ... , -k) at the beginning and appended to
1555 /// `extractPosition`.
1556 /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1557 /// ensure that there is no internal transposition.
1558 /// Internal transposition cannot be accounted for with a folding pattern.
1559 // TODO: We could relax the internal transposition with an extra transposition
1560 // operation in a future canonicalizer.
1561 SmallVector<int64_t> sentinels;
1562 SmallVector<int64_t> extractPosition;
1563};
1564} // namespace
1565
1566ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1567 ExtractOp e)
1568 : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1569 extractedRank(extractOp.getNumIndices()) {
1570 assert(vectorRank >= extractedRank && "Extracted position overflow");
1571 sentinels.reserve(N: vectorRank - extractedRank);
1572 for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1573 sentinels.push_back(Elt: -(i + 1));
1574 extractPosition.assign(in_start: extractOp.getStaticPosition().begin(),
1575 in_end: extractOp.getStaticPosition().end());
1576 llvm::append_range(C&: extractPosition, R&: sentinels);
1577}
1578
// Case 1. If we hit a transpose, just compose the map and iterate.
// Invariant: insert + transpose do not change rank, we can always compose.
LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return failure();

  if (!nextTransposeOp)
    return failure();
  // Map `extractPosition` back into the coordinate space of the transpose's
  // input by applying the inverse permutation. The sentinel entries are
  // permuted along with the real indices; canFold() later detects whether
  // they were reordered (i.e. an internal transposition occurred).
  AffineMap m = inversePermutation(map: AffineMap::getPermutationMap(
      permutation: nextTransposeOp.getPermutation(), context: extractOp.getContext()));
  extractPosition = applyPermutationMap(map: m, source: ArrayRef(extractPosition));
  return success();
}
1593
// Case 2: the insert position matches extractPosition exactly, early return.
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
    Value &res) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
    return failure();

  // The insert must write exactly the slice the extract reads: its static
  // position must equal the leading (non-sentinel) part of `extractPosition`.
  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
  if (insertedPos != llvm::ArrayRef(extractPosition).take_front(N: extractedRank))
    return failure();
  // Case 2.a. early-exit fold.
  res = nextInsertOp.getValueToStore();
  // Case 2.b. if internal transposition is present, canFold will be false.
  return success(IsSuccess: canFold());
}
1610
/// Case 3: if inserted position is a prefix of extractPosition,
/// extract a portion of the source of the insertion.
/// This method updates the internal state (`extractPosition`,
/// `extractedRank`) so the extract is re-expressed relative to the inserted
/// source value returned in `res`.
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
    return failure();

  ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
  if (!isContainedWithin(a: insertedPos, b: extractPosition))
    return failure();
  // Set leading dims to zero.
  std::fill_n(first: extractPosition.begin(), n: insertedPos.size(), value: 0);
  // Drop extra leading dims: the remaining entries index into the inserted
  // source value rather than the insert's destination.
  extractPosition.erase(CS: extractPosition.begin(),
                        CE: extractPosition.begin() + insertedPos.size());
  // Recompute the extracted rank, excluding the sentinel tail.
  extractedRank = extractPosition.size() - sentinels.size();
  // Case 3.a. early-exit fold (break and delegate to post-while path).
  res = nextInsertOp.getValueToStore();
  // Case 3.b. if internal transposition is present, canFold will be false.
  return success();
}
1634
1635/// Try to fold in place to extract(source, extractPosition) and return the
1636/// folded result. Return null if folding is not possible (e.g. due to an
1637/// internal transposition in the result).
1638Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1639 Value source) {
1640 // TODO: Canonicalization for dynamic position not implemented yet.
1641 if (extractOp.hasDynamicPosition())
1642 return Value();
1643
1644 // If we can't fold (either internal transposition, or nothing to fold), bail.
1645 bool nothingToFold = (source == extractOp.getVector());
1646 if (nothingToFold || !canFold())
1647 return Value();
1648
1649 // Otherwise, fold by updating the op inplace and return its result.
1650 OpBuilder b(extractOp.getContext());
1651 extractOp.setStaticPosition(
1652 ArrayRef(extractPosition).take_front(N: extractedRank));
1653 extractOp.getVectorMutable().assign(value: source);
1654 return extractOp.getResult();
1655}
1656
/// Iterate over producing insert and transpose ops until we find a fold.
/// Walks up the def-use chain from the extracted vector, updating the
/// internal extract position at every step, until one of the cases below
/// resolves or the chain ends.
Value ExtractFromInsertTransposeChainState::fold() {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  Value valueToExtractFrom = extractOp.getVector();
  updateStateForNextIteration(v: valueToExtractFrom);
  while (nextInsertOp || nextTransposeOp) {
    // Case 1. If we hit a transpose, just compose the map and iterate.
    // Invariant: insert + transpose do not change rank, we can always compose.
    if (succeeded(Result: handleTransposeOp())) {
      valueToExtractFrom = nextTransposeOp.getVector();
      updateStateForNextIteration(v: valueToExtractFrom);
      continue;
    }

    Value result;
    // Case 2: the position match exactly.
    if (succeeded(Result: handleInsertOpWithMatchingPos(res&: result)))
      return result;

    // Case 3: if the inserted position is a prefix of extractPosition, we can
    // just extract a portion of the source of the insert.
    if (succeeded(Result: handleInsertOpWithPrefixPos(res&: result)))
      return tryToFoldExtractOpInPlace(source: result);

    // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
    // values. This is a more difficult case and we bail.
    ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
    if (isContainedWithin(a: extractPosition, b: insertedPos) ||
        intersectsWhereNonNegative(a: extractPosition, b: insertedPos))
      return Value();

    // Case 5: No intersection, we forward the extract to insertOp.dest().
    valueToExtractFrom = nextInsertOp.getDest();
    updateStateForNextIteration(v: valueToExtractFrom);
  }
  // If after all this we can fold, go for it.
  return tryToFoldExtractOpInPlace(source: valueToExtractFrom);
}
1698
1699/// Returns true if the operation has a 0-D vector type operand or result.
1700static bool hasZeroDimVectors(Operation *op) {
1701 auto hasZeroDimVectorType = [](Type type) -> bool {
1702 auto vecType = dyn_cast<VectorType>(Val&: type);
1703 return vecType && vecType.getRank() == 0;
1704 };
1705
1706 return llvm::any_of(Range: op->getOperandTypes(), P: hasZeroDimVectorType) ||
1707 llvm::any_of(Range: op->getResultTypes(), P: hasZeroDimVectorType);
1708}
1709
/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
  Operation *defOp = extractOp.getVector().getDefiningOp();
  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(Val: defOp))
    return Value();

  Value source = defOp->getOperand(idx: 0);
  if (extractOp.getType() == source.getType())
    return source;
  auto getRank = [](Type type) {
    return llvm::isa<VectorType>(Val: type) ? llvm::cast<VectorType>(Val&: type).getRank()
                                        : 0;
  };

  // If splat or broadcast from a scalar, just return the source scalar.
  unsigned broadcastSrcRank = getRank(source.getType());
  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
    return source;

  unsigned extractResultRank = getRank(extractOp.getType());
  if (extractResultRank > broadcastSrcRank)
    return Value();
  // Check that the dimensions of the result haven't been broadcasted.
  auto extractVecType = llvm::dyn_cast<VectorType>(Val: extractOp.getType());
  auto broadcastVecType = llvm::dyn_cast<VectorType>(Val: source.getType());
  if (extractVecType && broadcastVecType &&
      extractVecType.getShape() !=
          broadcastVecType.getShape().take_back(N: extractResultRank))
    return Value();

  // NOTE(review): a scalar source was fully handled above (either returned or
  // bailed because extractResultRank > 0 == broadcastSrcRank), so `defOp` is
  // presumably a BroadcastOp here, not a SplatOp — confirm splat sources are
  // always scalars before relying on this cast.
  auto broadcastOp = cast<vector::BroadcastOp>(Val: defOp);
  int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();

  // Detect all the positions that come from "dim-1" broadcasting.
  // These dimensions correspond to "dim-1" broadcasted dims; set the matching
  // extract position to `0` when extracting from the source operand.
  llvm::SetVector<int64_t> broadcastedUnitDims =
      broadcastOp.computeBroadcastedUnitDims();
  SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition());
  OpBuilder b(extractOp.getContext());
  int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
  for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
    if (broadcastedUnitDims.contains(key: i))
      extractPos[i] = b.getIndexAttr(value: 0);
  // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
  // matching extract position when extracting from the source operand.
  int64_t rankDiff = broadcastSrcRank - extractResultRank;
  extractPos.erase(CS: extractPos.begin(),
                   CE: std::next(x: extractPos.begin(), n: extractPos.size() - rankDiff));
  // Split the mixed position into static indices and dynamic index values and
  // rewrite the extract in place to index directly into the broadcast source.
  auto [staticPos, dynPos] = decomposeMixedValues(mixedValues: extractPos);
  extractOp->setOperands(
      llvm::to_vector(Range: llvm::concat<Value>(Ranges: ValueRange(source), Ranges&: dynPos)));
  extractOp.setStaticPosition(staticPos);
  return extractOp.getResult();
}
1766
1767/// Fold extractOp coming from ShuffleOp.
1768///
1769/// Example:
1770///
1771/// %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
1772/// : vector<8xf32>, vector<8xf32>
1773/// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
1774/// ->
1775/// %extract = vector.extract %b[7] : f32 from vector<8xf32>
1776///
1777static Value foldExtractFromShuffle(ExtractOp extractOp) {
1778 // Dynamic positions are not folded as the resulting code would be more
1779 // complex than the input code.
1780 if (extractOp.hasDynamicPosition())
1781 return Value();
1782
1783 auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
1784 if (!shuffleOp)
1785 return Value();
1786
1787 // TODO: 0-D or multi-dimensional vectors not supported yet.
1788 if (shuffleOp.getResultVectorType().getRank() != 1)
1789 return Value();
1790
1791 int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1792 auto shuffleMask = shuffleOp.getMask();
1793 int64_t extractIdx = extractOp.getStaticPosition()[0];
1794 int64_t shuffleIdx = shuffleMask[extractIdx];
1795
1796 // Find the shuffled vector to extract from based on the shuffle index.
1797 if (shuffleIdx < inputVecSize) {
1798 extractOp.setOperand(i: 0, value: shuffleOp.getV1());
1799 extractOp.setStaticPosition({shuffleIdx});
1800 } else {
1801 extractOp.setOperand(i: 0, value: shuffleOp.getV2());
1802 extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1803 }
1804
1805 return extractOp.getResult();
1806}
1807
1808// Fold extractOp with source coming from ShapeCast op.
1809static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1810 // TODO: Canonicalization for dynamic position not implemented yet.
1811 if (extractOp.hasDynamicPosition())
1812 return Value();
1813
1814 auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
1815 if (!shapeCastOp)
1816 return Value();
1817
1818 // Get the nth dimension size starting from lowest dimension.
1819 auto getDimReverse = [](VectorType type, int64_t n) {
1820 return type.getShape().take_back(N: n + 1).front();
1821 };
1822 int64_t destinationRank =
1823 llvm::isa<VectorType>(Val: extractOp.getType())
1824 ? llvm::cast<VectorType>(Val: extractOp.getType()).getRank()
1825 : 0;
1826 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1827 return Value();
1828 if (destinationRank > 0) {
1829 auto destinationType =
1830 llvm::cast<VectorType>(Val: extractOp.getResult().getType());
1831 for (int64_t i = 0; i < destinationRank; i++) {
1832 // The lowest dimension of the destination must match the lowest
1833 // dimension of the shapecast op source.
1834 // TODO: This case could be support in a canonicalization pattern.
1835 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1836 getDimReverse(destinationType, i))
1837 return Value();
1838 }
1839 }
1840 // Extract the strides associated with the extract op vector source. Then use
1841 // this to calculate a linearized position for the extract.
1842 SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1843 std::reverse(first: extractedPos.begin(), last: extractedPos.end());
1844 SmallVector<int64_t, 4> strides;
1845 int64_t stride = 1;
1846 for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1847 strides.push_back(Elt: stride);
1848 stride *=
1849 getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1850 }
1851
1852 int64_t position = linearize(offsets: extractedPos, basis: strides);
1853 // Then extract the strides associated to the shapeCast op vector source and
1854 // delinearize the position using those strides.
1855 SmallVector<int64_t, 4> newStrides;
1856 int64_t numDimension =
1857 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1858 stride = 1;
1859 for (int64_t i = 0; i < numDimension; i++) {
1860 newStrides.push_back(Elt: stride);
1861 stride *=
1862 getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1863 }
1864 std::reverse(first: newStrides.begin(), last: newStrides.end());
1865 SmallVector<int64_t, 4> newPosition = delinearize(linearIndex: position, strides: newStrides);
1866 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1867 OpBuilder b(extractOp.getContext());
1868 extractOp.setStaticPosition(newPosition);
1869 extractOp.setOperand(i: 0, value: shapeCastOp.getSource());
1870 return extractOp.getResult();
1871}
1872
1873/// Fold an ExtractOp from ExtractStridedSliceOp.
1874static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1875 // TODO: Canonicalization for dynamic position not implemented yet.
1876 if (extractOp.hasDynamicPosition())
1877 return Value();
1878
1879 auto extractStridedSliceOp =
1880 extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
1881 if (!extractStridedSliceOp)
1882 return Value();
1883
1884 // 0-D vectors not supported.
1885 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1886 if (hasZeroDimVectors(op: extractStridedSliceOp))
1887 return Value();
1888
1889 // Return if 'extractStridedSliceOp' has non-unit strides.
1890 if (extractStridedSliceOp.hasNonUnitStrides())
1891 return Value();
1892
1893 // Trim offsets for dimensions fully extracted.
1894 auto sliceOffsets =
1895 extractVector<int64_t>(arrayAttr: extractStridedSliceOp.getOffsets());
1896 while (!sliceOffsets.empty()) {
1897 size_t lastOffset = sliceOffsets.size() - 1;
1898 if (sliceOffsets.back() != 0 ||
1899 extractStridedSliceOp.getType().getDimSize(idx: lastOffset) !=
1900 extractStridedSliceOp.getSourceVectorType().getDimSize(idx: lastOffset))
1901 break;
1902 sliceOffsets.pop_back();
1903 }
1904 unsigned destinationRank = 0;
1905 if (auto vecType = llvm::dyn_cast<VectorType>(Val: extractOp.getType()))
1906 destinationRank = vecType.getRank();
1907 // The dimensions of the result need to be untouched by the
1908 // extractStridedSlice op.
1909 if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1910 sliceOffsets.size())
1911 return Value();
1912
1913 SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1914 assert(extractedPos.size() >= sliceOffsets.size());
1915 for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1916 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1917 extractOp.getVectorMutable().assign(value: extractStridedSliceOp.getVector());
1918
1919 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1920 OpBuilder b(extractOp.getContext());
1921 extractOp.setStaticPosition(extractedPos);
1922 return extractOp.getResult();
1923}
1924
1925/// Fold extract_op fed from a chain of insertStridedSlice ops.
1926static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1927 // TODO: Canonicalization for dynamic position not implemented yet.
1928 if (extractOp.hasDynamicPosition())
1929 return Value();
1930
1931 int64_t destinationRank =
1932 llvm::isa<VectorType>(Val: extractOp.getType())
1933 ? llvm::cast<VectorType>(Val: extractOp.getType()).getRank()
1934 : 0;
1935 auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
1936 if (!insertOp)
1937 return Value();
1938
1939 // 0-D vectors not supported.
1940 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1941 if (hasZeroDimVectors(op: insertOp))
1942 return Value();
1943
1944 while (insertOp) {
1945 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1946 insertOp.getSourceVectorType().getRank();
1947 if (destinationRank > insertOp.getSourceVectorType().getRank())
1948 return Value();
1949 auto insertOffsets = extractVector<int64_t>(arrayAttr: insertOp.getOffsets());
1950 ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1951
1952 if (llvm::any_of(Range: insertOp.getStrides(), P: [](Attribute attr) {
1953 return llvm::cast<IntegerAttr>(Val&: attr).getInt() != 1;
1954 }))
1955 return Value();
1956 bool disjoint = false;
1957 SmallVector<int64_t, 4> offsetDiffs;
1958 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1959 int64_t start = insertOffsets[dim];
1960 int64_t size =
1961 (dim < insertRankDiff)
1962 ? 1
1963 : insertOp.getSourceVectorType().getDimSize(idx: dim - insertRankDiff);
1964 int64_t end = start + size;
1965 int64_t offset = extractOffsets[dim];
1966 // Check if the start of the extract offset is in the interval inserted.
1967 if (start <= offset && offset < end) {
1968 if (dim >= insertRankDiff)
1969 offsetDiffs.push_back(Elt: offset - start);
1970 continue;
1971 }
1972 disjoint = true;
1973 break;
1974 }
1975 // The extract element chunk overlap with the vector inserted.
1976 if (!disjoint) {
1977 // If any of the inner dimensions are only partially inserted we have a
1978 // partial overlap.
1979 int64_t srcRankDiff =
1980 insertOp.getSourceVectorType().getRank() - destinationRank;
1981 for (int64_t i = 0; i < destinationRank; i++) {
1982 if (insertOp.getSourceVectorType().getDimSize(idx: i + srcRankDiff) !=
1983 insertOp.getDestVectorType().getDimSize(idx: i + srcRankDiff +
1984 insertRankDiff))
1985 return Value();
1986 }
1987 extractOp.getVectorMutable().assign(value: insertOp.getValueToStore());
1988 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1989 OpBuilder b(extractOp.getContext());
1990 extractOp.setStaticPosition(offsetDiffs);
1991 return extractOp.getResult();
1992 }
1993 // If the chunk extracted is disjoint from the chunk inserted, keep
1994 // looking in the insert chain.
1995 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1996 }
1997 return Value();
1998}
1999
2000/// Try to fold the extraction of a scalar from a vector defined by
2001/// vector.from_elements. E.g.:
2002///
2003/// %0 = vector.from_elements %a, %b : vector<2xf32>
2004/// %1 = vector.extract %0[0] : f32 from vector<2xf32>
2005/// ==> fold to %a
2006static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
2007 // Dynamic extractions cannot be folded.
2008 if (extractOp.hasDynamicPosition())
2009 return {};
2010
2011 // Look for extract(from_elements).
2012 auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
2013 if (!fromElementsOp)
2014 return {};
2015
2016 // Scalable vectors are not supported.
2017 auto vecType = llvm::cast<VectorType>(Val: fromElementsOp.getType());
2018 if (vecType.isScalable())
2019 return {};
2020
2021 // Only extractions of scalars are supported.
2022 int64_t rank = vecType.getRank();
2023 ArrayRef<int64_t> indices = extractOp.getStaticPosition();
2024 if (extractOp.getType() != vecType.getElementType())
2025 return {};
2026 assert(static_cast<int64_t>(indices.size()) == rank &&
2027 "unexpected number of indices");
2028
2029 // Compute flattened/linearized index and fold to operand.
2030 int flatIndex = 0;
2031 int stride = 1;
2032 for (int i = rank - 1; i >= 0; --i) {
2033 flatIndex += indices[i] * stride;
2034 stride *= vecType.getDimSize(idx: i);
2035 }
2036 return fromElementsOp.getElements()[flatIndex];
2037}
2038
/// If the dynamic indices of `extractOp` or `insertOp` are in fact constants,
/// then fold it.
/// `operands` is seeded by the caller with the non-position operands and, on
/// success, is filled with the remaining (still dynamic) position operands and
/// installed as the op's new operand list.
template <typename OpType, typename AdaptorType>
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
                                         SmallVectorImpl<Value> &operands) {
  std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
  OperandRange dynamicPosition = op.getDynamicPosition();
  ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
  // The bound check below needs the shape of the vector being indexed, which
  // lives in a different accessor for extract vs insert.
  ArrayRef<int64_t> vectorShape;
  if constexpr (std::is_same_v<OpType, ExtractOp>)
    vectorShape = op.getSourceVectorType().getShape();
  else
    vectorShape = op.getDestVectorType().getShape();

  // If the dynamic operands is empty, it is returned directly.
  if (!dynamicPosition.size())
    return {};

  // `index` is used to iterate over the `dynamicPosition`.
  unsigned index = 0;

  // `opChange` is a flag. If it is true, it means to update `op` in place.
  bool opChange = false;
  for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
    // Static entries stay as they are; only dynamic slots are examined.
    if (ShapedType::isStatic(dValue: staticPosition[i]))
      continue;
    Attribute positionAttr = dynamicPositionAttr[index];
    Value position = dynamicPosition[index++];
    if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(Val&: positionAttr)) {
      int64_t value = attr.getInt();
      // Do not fold if the value is out of bounds (-1 signifies a poison
      // value rather than OOB index).
      if (value >= -1 && value < vectorShape[i]) {
        staticPosition[i] = attr.getInt();
        opChange = true;
        continue;
      }
    }
    // Not a foldable constant: keep this position operand dynamic.
    operands.push_back(Elt: position);
  }

  if (opChange) {
    op.setStaticPosition(staticPosition);
    op.getOperation()->setOperands(operands);
    // Return the original result to indicate an in-place folding happened.
    return op.getResult();
  }
  return {};
}
2088
2089/// Fold an insert or extract operation into an poison value when a poison index
2090/// is found at any dimension of the static position.
2091static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
2092 ArrayRef<int64_t> staticPos,
2093 int64_t poisonVal) {
2094 if (!is_contained(Range&: staticPos, Element: poisonVal))
2095 return {};
2096
2097 return ub::PoisonAttr::get(ctx: context);
2098}
2099
2100/// Fold a vector extract from is a poison source.
2101static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
2102 if (isa_and_nonnull<ub::PoisonAttr>(Val: srcAttr))
2103 return srcAttr;
2104
2105 return {};
2106}
2107
/// Fold a vector extract extracting from a DenseElementsAttr.
/// Splat sources fold unconditionally; non-splat sources fold by slicing the
/// contiguous run of elements at the (static) extract position.
static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp,
                                                   Attribute srcAttr) {
  auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(Val&: srcAttr);
  if (!denseAttr) {
    return {};
  }

  // Splat: every element is identical, so the position is irrelevant.
  if (denseAttr.isSplat()) {
    Attribute newAttr = denseAttr.getSplatValue<Attribute>();
    if (auto vecDstType = dyn_cast<VectorType>(Val: extractOp.getType()))
      newAttr = DenseElementsAttr::get(type: vecDstType, values: newAttr);
    return newAttr;
  }

  // A scalable source has no fixed element count to slice from.
  auto vecTy = cast<VectorType>(Val: extractOp.getSourceVectorType());
  if (vecTy.isScalable())
    return {};

  if (extractOp.hasDynamicPosition()) {
    return {};
  }

  // Materializing subsets of a large constant array can generally lead to
  // explosion in IR size because of different combination of subsets that
  // can exist. However, vector.extract is a restricted form of subset
  // extract where you can only extract non-overlapping (or the same) subset for
  // a given rank of the subset. Because of this property, the IR size can only
  // increase at most by `rank * size(array)` from a single constant array being
  // extracted by multiple extracts.

  // Calculate the linearized position of the continuous chunk of elements to
  // extract. Pad the position with zeros up to the source rank so a
  // lower-rank extract addresses the start of its chunk.
  SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
  copy(Range: extractOp.getStaticPosition(), Out: completePositions.begin());
  int64_t startPos =
      linearize(offsets: completePositions, basis: computeStrides(sizes: vecTy.getShape()));
  auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;

  // Vector result: slice out getNumElements() values; scalar result: take one.
  TypedAttr newAttr;
  if (auto resVecTy = dyn_cast<VectorType>(Val: extractOp.getType())) {
    SmallVector<Attribute> elementValues(
        denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
    newAttr = DenseElementsAttr::get(type: resVecTy, values: elementValues);
  } else {
    newAttr = *denseValuesBegin;
  }

  return newAttr;
}
2158
/// Folder for vector.extract. Tries a sequence of increasingly specific
/// folds; the order matters because the constant-index in-place fold may
/// enable the later pattern-based folds.
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
  // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
  // mismatch).
  if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
    return getVector();
  if (auto res = foldPoisonSrcExtractOp(srcAttr: adaptor.getVector()))
    return res;
  // Fold `arith.constant` indices into the `vector.extract` operation.
  // Do not stop here as this fold may enable subsequent folds that require
  // constant indices.
  SmallVector<Value> operands = {getVector()};
  auto inplaceFolded = extractInsertFoldConstantOp(op: *this, adaptor, operands);

  if (auto res = foldPoisonIndexInsertExtractOp(
          context: getContext(), staticPos: adaptor.getStaticPosition(), poisonVal: kPoisonIndex))
    return res;
  if (auto res = foldDenseElementsAttrSrcExtractOp(extractOp: *this, srcAttr: adaptor.getVector()))
    return res;
  if (succeeded(Result: foldExtractOpFromExtractChain(extractOp: *this)))
    return getResult();
  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
    return res;
  if (auto res = foldExtractFromBroadcast(extractOp: *this))
    return res;
  if (auto res = foldExtractFromShuffle(extractOp: *this))
    return res;
  if (auto res = foldExtractFromShapeCast(extractOp: *this))
    return res;
  if (auto val = foldExtractFromExtractStrided(extractOp: *this))
    return val;
  if (auto val = foldExtractStridedOpFromInsertChain(extractOp: *this))
    return val;
  if (auto val = foldScalarExtractFromFromElements(extractOp: *this))
    return val;

  // No specific fold matched; report the constant-index in-place fold (if
  // any) as the overall result.
  return inplaceFolded;
}
2197
2198namespace {
2199
2200// Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast.
2201class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2202public:
2203 using OpRewritePattern::OpRewritePattern;
2204
2205 LogicalResult matchAndRewrite(ExtractOp extractOp,
2206 PatternRewriter &rewriter) const override {
2207 Operation *defOp = extractOp.getVector().getDefiningOp();
2208 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(Val: defOp))
2209 return failure();
2210
2211 Value source = defOp->getOperand(idx: 0);
2212 if (extractOp.getType() == source.getType())
2213 return failure();
2214 auto getRank = [](Type type) {
2215 return llvm::isa<VectorType>(Val: type)
2216 ? llvm::cast<VectorType>(Val&: type).getRank()
2217 : 0;
2218 };
2219 unsigned broadcastSrcRank = getRank(source.getType());
2220 unsigned extractResultRank = getRank(extractOp.getType());
2221 // We only consider the case where the rank of the source is less than or
2222 // equal to the rank of the extract dst. The other cases are handled in the
2223 // folding patterns.
2224 if (extractResultRank < broadcastSrcRank)
2225 return failure();
2226 // For scalar result, the input can only be a rank-0 vector, which will
2227 // be handled by the folder.
2228 if (extractResultRank == 0)
2229 return failure();
2230
2231 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
2232 op: extractOp, args: extractOp.getType(), args&: source);
2233 return success();
2234 }
2235};
2236
// Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
// The extracted sub-mask is either provably all-false (replaced by a constant)
// or another create_mask over the trailing bounds; if any examined dimension
// is indeterminate, the pattern fails.
class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto createMaskOp =
        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return failure();

    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(Val: extractOp.getResult().getType());

    // Scalar extracts are not handled by this pattern.
    if (!extractedMaskType)
      return failure();

    auto maskOperands = createMaskOp.getOperands();
    ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
    VectorType maskType = createMaskOp.getVectorType();

    bool containsUnknownDims = false;
    bool allFalse = getMaskFormat(mask: createMaskOp) == MaskFormat::AllFalse;

    // Classify each extracted (leading) dimension: does the extract land in
    // the all-false region, the all-true region, or is it indeterminate?
    for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
         dimIdx++) {
      int64_t pos = extractOpPos[dimIdx];
      Value operand = maskOperands[dimIdx];
      auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
      if (!constantOp) {
        // Bounds of this dim unknown.
        containsUnknownDims = true;
        continue;
      }

      int64_t createMaskBound =
          llvm::cast<IntegerAttr>(Val: constantOp.getValue()).getInt();

      if (pos != ShapedType::kDynamic) {
        // If any position is outside the range from the `create_mask`, then the
        // extracted mask will be all-false.
        allFalse |= pos >= createMaskBound;
      } else if (createMaskBound < maskType.getDimSize(idx: dimIdx)) {
        // This dim is not all-true and since this is a dynamic index we don't
        // know if the extraction is within the true or false region.
        // Note: Zero dims have already handled via getMaskFormat().
        containsUnknownDims = true;
      }
    }

    if (allFalse) {
      // Provably outside the mask bounds: the result is a constant false mask.
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op: extractOp, args: DenseElementsAttr::get(type: extractedMaskType, value: false));
    } else if (!containsUnknownDims) {
      // All leading dims are known in-bounds: the sub-mask is a create_mask
      // over the remaining (trailing) bounds.
      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          op: extractOp, args&: extractedMaskType,
          args: maskOperands.drop_front(n: extractOpPos.size()));
    } else {
      return failure();
    }
    return success();
  }
};
2301
2302// Folds extract(shape_cast(..)) into shape_cast when the total element count
2303// does not change.
2304LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2305 PatternRewriter &rewriter) {
2306 auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
2307 if (!castOp)
2308 return failure();
2309
2310 VectorType sourceType = castOp.getSourceVectorType();
2311 auto targetType = dyn_cast<VectorType>(Val: extractOp.getResult().getType());
2312 if (!targetType)
2313 return failure();
2314
2315 if (sourceType.getNumElements() != targetType.getNumElements())
2316 return failure();
2317
2318 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op: extractOp, args&: targetType,
2319 args: castOp.getSource());
2320 return success();
2321}
2322
2323/// Try to canonicalize the extraction of a subvector from a vector defined by
2324/// vector.from_elements. E.g.:
2325///
2326/// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
2327/// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
2328/// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
2329LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2330 PatternRewriter &rewriter) {
2331 // Dynamic positions are not supported.
2332 if (extractOp.hasDynamicPosition())
2333 return failure();
2334
2335 // Scalar extracts are handled by the folder.
2336 auto resultType = dyn_cast<VectorType>(Val: extractOp.getType());
2337 if (!resultType)
2338 return failure();
2339
2340 // Look for extracts from a from_elements op.
2341 auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
2342 if (!fromElementsOp)
2343 return failure();
2344 VectorType inputType = fromElementsOp.getType();
2345
2346 // Scalable vectors are not supported.
2347 if (resultType.isScalable() || inputType.isScalable())
2348 return failure();
2349
2350 // Compute the position of first extracted element and flatten/linearize the
2351 // position.
2352 SmallVector<int64_t> firstElementPos =
2353 llvm::to_vector(Range: extractOp.getStaticPosition());
2354 firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
2355 int flatIndex = 0;
2356 int stride = 1;
2357 for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2358 flatIndex += firstElementPos[i] * stride;
2359 stride *= inputType.getDimSize(idx: i);
2360 }
2361
2362 // Replace the op with a smaller from_elements op.
2363 rewriter.replaceOpWithNewOp<FromElementsOp>(
2364 op: extractOp, args&: resultType,
2365 args: fromElementsOp.getElements().slice(n: flatIndex,
2366 m: resultType.getNumElements()));
2367 return success();
2368}
2369
2370} // namespace
2371
2372void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2373 MLIRContext *context) {
2374 results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(arg&: context);
2375 results.add(implFn: foldExtractFromShapeCastToShapeCast);
2376 results.add(implFn: foldExtractFromFromElements);
2377}
2378
2379static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
2380 SmallVectorImpl<int64_t> &results) {
2381 for (auto attr : arrayAttr)
2382 results.push_back(Elt: llvm::cast<IntegerAttr>(Val&: attr).getInt());
2383}
2384
2385//===----------------------------------------------------------------------===//
2386// FmaOp
2387//===----------------------------------------------------------------------===//
2388
2389std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2390 return llvm::to_vector<4>(Range: getVectorType().getShape());
2391}
2392
2393//===----------------------------------------------------------------------===//
2394// ToElementsOp
2395//===----------------------------------------------------------------------===//
2396
2397/// Returns true if all the `operands` are defined by `defOp`.
2398/// Otherwise, returns false.
2399static bool haveSameDefiningOp(OperandRange operands, Operation *defOp) {
2400 if (operands.empty())
2401 return false;
2402
2403 return llvm::all_of(Range&: operands, P: [&](Value operand) {
2404 Operation *currentDef = operand.getDefiningOp();
2405 return currentDef == defOp;
2406 });
2407}
2408
2409/// Folds vector.to_elements(vector.from_elements(%e0, %e1, ...)) into
2410/// (%e0, %e1, ...). For example:
2411///
2412/// %0 = vector.from_elements %a, %b, %c : vector<3xf32>
2413/// %1:3 = vector.to_elements %0 : vector<3xf32>
2414/// user_op %1#0, %1#1, %1#2
2415///
2416/// becomes:
2417///
2418/// user_op %a, %b, %c
2419///
2420static LogicalResult
2421foldToElementsFromElements(ToElementsOp toElementsOp,
2422 SmallVectorImpl<OpFoldResult> &results) {
2423 auto fromElementsOp =
2424 toElementsOp.getSource().getDefiningOp<FromElementsOp>();
2425 if (!fromElementsOp)
2426 return failure();
2427
2428 llvm::append_range(C&: results, R: fromElementsOp.getElements());
2429 return success();
2430}
2431
LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
                                 SmallVectorImpl<OpFoldResult> &results) {
  // Single fold currently implemented: to_elements(from_elements(...))
  // forwards the original scalar operands.
  return foldToElementsFromElements(toElementsOp: *this, results);
}
2436
2437//===----------------------------------------------------------------------===//
2438// FromElementsOp
2439//===----------------------------------------------------------------------===//
2440
2441/// Folds vector.from_elements(vector.to_elements(%vector)) into %vector.
2442///
2443/// Case #1: Input and output vectors are the same.
2444///
2445/// %0:3 = vector.to_elements %a : vector<3xf32>
2446/// %1 = vector.from_elements %0#0, %0#1, %0#2 : vector<3xf32>
2447/// user_op %1
2448///
2449/// becomes:
2450///
2451/// user_op %a
2452///
2453static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
2454 OperandRange fromElemsOperands = fromElementsOp.getElements();
2455 if (fromElemsOperands.empty())
2456 return {};
2457
2458 auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
2459 if (!toElementsOp)
2460 return {};
2461
2462 if (!haveSameDefiningOp(operands: fromElemsOperands, defOp: toElementsOp))
2463 return {};
2464
2465 // Case #1: Input and output vectors are the same. Forward the input vector.
2466 Value toElementsInput = toElementsOp.getSource();
2467 if (fromElementsOp.getType() == toElementsInput.getType() &&
2468 llvm::equal(LRange&: fromElemsOperands, RRange: toElementsOp.getResults())) {
2469 return toElementsInput;
2470 }
2471
2472 // TODO: Support cases with different input and output shapes and different
2473 // number of elements.
2474
2475 return {};
2476}
2477
2478/// Fold vector.from_elements to a constant when all operands are constants.
2479/// Example:
2480/// %c1 = arith.constant 1 : i32
2481/// %c2 = arith.constant 2 : i32
2482/// %v = vector.from_elements %c1, %c2 : vector<2xi32>
2483/// =>
2484/// %v = arith.constant dense<[1, 2]> : vector<2xi32>
2485///
2486static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
2487 ArrayRef<Attribute> elements) {
2488 if (llvm::any_of(Range&: elements, P: [](Attribute attr) { return !attr; }))
2489 return {};
2490
2491 auto destVecType = fromElementsOp.getDest().getType();
2492 auto destEltType = destVecType.getElementType();
2493 // Constant attributes might have a different type than the return type.
2494 // Convert them before creating the dense elements attribute.
2495 auto convertedElements = llvm::map_to_vector(C&: elements, F: [&](Attribute attr) {
2496 return convertIntegerAttr(attr, expectedType: destEltType);
2497 });
2498
2499 return DenseElementsAttr::get(type: destVecType, values: convertedElements);
2500}
2501
2502OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
2503 if (auto res = foldFromElementsToElements(fromElementsOp: *this))
2504 return res;
2505 if (auto res = foldFromElementsToConstant(fromElementsOp: *this, elements: adaptor.getElements()))
2506 return res;
2507
2508 return {};
2509}
2510
2511/// Rewrite a vector.from_elements into a vector.splat if all elements are the
2512/// same SSA value. E.g.:
2513///
2514/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2515/// ==> rewrite to vector.splat %a : vector<3xf32>
2516static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
2517 PatternRewriter &rewriter) {
2518 if (!llvm::all_equal(Range: fromElementsOp.getElements()))
2519 return failure();
2520 rewriter.replaceOpWithNewOp<SplatOp>(op: fromElementsOp, args: fromElementsOp.getType(),
2521 args: fromElementsOp.getElements().front());
2522 return success();
2523}
2524
/// Rewrite from_elements on multiple scalar extracts as a shape_cast
/// on a single extract. Example:
///   %0 = vector.extract %source[0, 0] : i8 from vector<2x2xi8>
///   %1 = vector.extract %source[0, 1] : i8 from vector<2x2xi8>
///   %2 = vector.from_elements %0, %1 : vector<2xi8>
///
/// becomes
///   %1 = vector.extract %source[0] : vector<1x2xi8> from vector<2x2xi8>
///   %2 = vector.shape_cast %1 : vector<1x2xi8> to vector<2xi8>
///
/// The requirements for this to be valid are
///
///   i) The elements are extracted from the same vector (%source).
///
///   ii) The elements form a suffix of %source. Specifically, the number
///       of elements is the same as the product of the last N dimension sizes
///       of %source, for some N.
///
///   iii) The elements are extracted contiguously in ascending order.

class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {

  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(FromElementsOp fromElements,
                                PatternRewriter &rewriter) const override {

    // Handled by `rewriteFromElementsAsSplat`
    if (fromElements.getType().getNumElements() == 1)
      return failure();

    // The common source that all elements are extracted from, if one exists.
    TypedValue<VectorType> source;
    // The position of the combined extract operation, if one is created.
    ArrayRef<int64_t> combinedPosition;
    // The expected index of extraction of the current element in the loop, if
    // elements are extracted contiguously in ascending order.
    SmallVector<int64_t> expectedPosition;

    for (auto [insertIndex, element] :
         llvm::enumerate(First: fromElements.getElements())) {

      // Check that the element is from a vector.extract operation.
      auto extractOp =
          dyn_cast_if_present<vector::ExtractOp>(Val: element.getDefiningOp());
      if (!extractOp) {
        return rewriter.notifyMatchFailure(arg&: fromElements,
                                           msg: "element not from vector.extract");
      }

      // Check condition (i) by checking that all elements have the same source
      // as the first element.
      if (insertIndex == 0) {
        source = extractOp.getVector();
      } else if (extractOp.getVector() != source) {
        return rewriter.notifyMatchFailure(arg&: fromElements,
                                           msg: "element from different vector");
      }

      ArrayRef<int64_t> position = extractOp.getStaticPosition();
      int64_t rank = position.size();
      assert(rank == source.getType().getRank() &&
             "scalar extract must have full rank position");

      // Check condition (ii) by checking that the position that the first
      // element is extracted from has sufficient trailing 0s. For example, in
      //
      //   %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8>
      //   [...]
      //   %elms = vector.from_elements %elm0, [...] : vector<12xi8>
      //
      // The 2 trailing 0s in the position of extraction of %elm0 cover 3*4 = 12
      // elements, which is the number of elements of %elms, so this is valid.
      if (insertIndex == 0) {
        const int64_t numElms = fromElements.getType().getNumElements();
        int64_t numSuffixElms = 1;
        int64_t index = rank;
        // Walk trailing zero positions inwards, accumulating the product of
        // the covered dimension sizes until it reaches the element count.
        while (index > 0 && position[index - 1] == 0 &&
               numSuffixElms < numElms) {
          numSuffixElms *= source.getType().getDimSize(idx: index - 1);
          --index;
        }
        if (numSuffixElms != numElms) {
          return rewriter.notifyMatchFailure(
              arg&: fromElements, msg: "elements do not form a suffix of source");
        }
        // The combined extract is positioned at the leading (non-suffix)
        // coordinates of the first element's position.
        expectedPosition = llvm::to_vector(Range&: position);
        combinedPosition = position.drop_back(N: rank - index);
      }

      // Check condition (iii).
      else if (expectedPosition != position) {
        return rewriter.notifyMatchFailure(
            arg&: fromElements, msg: "elements not in ascending order (static order)");
      }
      increment(indices: expectedPosition, shape: source.getType().getShape());
    }

    // Extract the contiguous suffix block once, then reshape it to the
    // from_elements result type.
    auto extracted = rewriter.createOrFold<vector::ExtractOp>(
        location: fromElements.getLoc(), args&: source, args&: combinedPosition);

    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        op: fromElements, args: fromElements.getType(), args&: extracted);

    return success();
  }

  /// Increments n-D `indices` by 1 starting from the innermost dimension,
  /// wrapping each dimension at its size in `shape` (row-major order).
  static void increment(MutableArrayRef<int64_t> indices,
                        ArrayRef<int64_t> shape) {
    for (int dim : llvm::reverse(C: llvm::seq<int>(Begin: 0, End: indices.size()))) {
      indices[dim] += 1;
      if (indices[dim] < shape[dim])
        break;
      indices[dim] = 0;
    }
  }
};
2643
2644void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2645 MLIRContext *context) {
2646 results.add(implFn: rewriteFromElementsAsSplat);
2647 results.add<FromElementsToShapeCast>(arg&: context);
2648}
2649
2650//===----------------------------------------------------------------------===//
2651// BroadcastOp
2652//===----------------------------------------------------------------------===//
2653
2654void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2655 SetIntRangeFn setResultRanges) {
2656 setResultRanges(getResult(), argRanges.front());
2657}
2658
2659std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
2660 return llvm::to_vector<4>(Range: getResultVectorType().getShape());
2661}
2662
2663/// Return the dimensions of the result vector that were formerly ones in the
2664/// source tensor and thus correspond to "dim-1" broadcasting.
2665static llvm::SetVector<int64_t>
2666computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
2667 ArrayRef<int64_t> dstShape) {
2668 int64_t rankDiff = dstShape.size() - srcShape.size();
2669 int64_t dstDim = rankDiff;
2670 llvm::SetVector<int64_t> res;
2671 for (auto [s1, s2] :
2672 llvm::zip_equal(t&: srcShape, u: dstShape.drop_front(N: rankDiff))) {
2673 if (s1 != s2) {
2674 assert(s1 == 1 && "expected \"dim-1\" broadcasting");
2675 res.insert(X: dstDim);
2676 }
2677 ++dstDim;
2678 }
2679 return res;
2680}
2681
2682llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2683 // Scalar broadcast is without any unit dim broadcast.
2684 auto srcVectorType = llvm::dyn_cast<VectorType>(Val: getSourceType());
2685 if (!srcVectorType)
2686 return {};
2687 return ::computeBroadcastedUnitDims(srcShape: srcVectorType.getShape(),
2688 dstShape: getResultVectorType().getShape());
2689}
2690
/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
/// `broadcastedDims` dimensions in the dstShape are broadcasted.
/// This requires (and asserts) that the broadcast is free of "dim-1"
/// broadcasting.
/// Since vector.broadcast only allows expanding leading dimensions, an extra
/// vector.transpose may be inserted to make the broadcast possible.
/// `value`, `dstShape` and `broadcastedDims` must be properly specified or
/// the helper will assert. This means:
///   1. `dstShape` must not be empty.
///   2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
///   3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
///      must match the `value` shape.
Value BroadcastOp::createOrFoldBroadcastOp(
    OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
    const llvm::SetVector<int64_t> &broadcastedDims) {
  assert(!dstShape.empty() && "unexpected empty dst shape");

  // Step 1. Well-formedness check: the non-broadcasted dst dims must exactly
  // reproduce the source shape (verified below once srcVectorType is known).
  SmallVector<int64_t> checkShape;
  for (int i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(key: i))
      continue;
    checkShape.push_back(Elt: dstShape[i]);
  }
  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
         "ill-formed broadcastedDims contains values not confined to "
         "destVectorShape");

  Location loc = value.getLoc();
  Type elementType = getElementTypeOrSelf(type: value.getType());
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(Val: value.getType());
  VectorType dstVectorType = VectorType::get(shape: dstShape, elementType);

  // Step 2. If scalar -> dstShape broadcast, just do it.
  if (!srcVectorType) {
    assert(checkShape.empty() &&
           "ill-formed createOrFoldBroadcastOp arguments");
    return b.createOrFold<vector::BroadcastOp>(location: loc, args&: dstVectorType, args&: value);
  }

  assert(srcVectorType.getShape().equals(checkShape) &&
         "ill-formed createOrFoldBroadcastOp arguments");

  // Step 3. Since vector.broadcast only allows creating leading dims,
  //   vector -> dstShape broadcast may require a transpose.
  // Traverse the dims in order and construct:
  //   1. The leading entries of the broadcastShape that is guaranteed to be
  //      achievable by a simple broadcast.
  //   2. The induced permutation for the subsequent vector.transpose that will
  //      bring us from `broadcastShape` back to the desired `dstShape`.
  // If the induced permutation is not the identity, create a vector.transpose.
  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
  broadcastShape.reserve(N: dstShape.size());
  // Consider the example:
  //   srcShape     = 2x4
  //   dstShape     = 1x2x3x4x5
  //   broadcastedDims = [0, 2, 4]
  //
  // We want to build:
  //   broadcastShape  = 1x3x5x2x4
  //   permutation     = [0, 2, 4,                 1, 3]
  //                      ---V---           -----V-----
  //             leading broadcast part    src shape part
  //
  // Note that the trailing dims of broadcastShape are exactly the srcShape
  // by construction.
  // nextSrcShapeDim is used to keep track of where in the permutation the
  // "src shape part" occurs.
  int64_t nextSrcShapeDim = broadcastedDims.size();
  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(key: i)) {
      // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
      // bring it to the head of the broadcastShape.
      // It will need to be permuted back from `broadcastShape.size() - 1` into
      // position `i`.
      broadcastShape.push_back(Elt: dstShape[i]);
      permutation[i] = broadcastShape.size() - 1;
    } else {
      // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
      // shape and needs to be permuted into position `i`.
      // Don't touch `broadcastShape` here, the whole srcShape will be
      // appended after.
      permutation[i] = nextSrcShapeDim++;
    }
  }
  // 3.c. Append the srcShape.
  llvm::append_range(C&: broadcastShape, R: srcVectorType.getShape());

  // Ensure there are no "dim-1" broadcasts.
  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
             .empty() &&
         "unexpected \"dim-1\" broadcast");

  VectorType broadcastType = VectorType::get(shape: broadcastShape, elementType);
  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
             vector::BroadcastableToResult::Success &&
         "must be broadcastable");
  Value res = b.createOrFold<vector::BroadcastOp>(location: loc, args&: broadcastType, args&: value);
  // Step 4. If we find any dimension that indeed needs to be permuted,
  // immediately return a new vector.transpose.
  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
    if (permutation[i] != i)
      return b.createOrFold<vector::TransposeOp>(location: loc, args&: res, args&: permutation);
  // Otherwise return res.
  return res;
}
2797
/// Checks whether `srcType` can be broadcast to `dstVectorType` under
/// vector.broadcast semantics. On DimensionMismatch, if `mismatchingDims` is
/// non-null it receives the first offending (src, dst) dimension pair.
BroadcastableToResult mlir::vector::isBroadcastableTo(
    Type srcType, VectorType dstVectorType,
    std::pair<VectorDim, VectorDim> *mismatchingDims) {
  // Broadcast scalar to vector of the same element type.
  if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
      getElementTypeOrSelf(type: srcType) == getElementTypeOrSelf(type: dstVectorType))
    return BroadcastableToResult::Success;
  // From now on, only vectors broadcast.
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(Val&: srcType);
  if (!srcVectorType)
    return BroadcastableToResult::SourceTypeNotAVector;

  // The source rank must not exceed the destination rank.
  int64_t srcRank = srcVectorType.getRank();
  int64_t dstRank = dstVectorType.getRank();
  if (srcRank > dstRank)
    return BroadcastableToResult::SourceRankHigher;
  // Source has an exact match or singleton value for all trailing dimensions
  // (all leading dimensions are simply duplicated).
  int64_t lead = dstRank - srcRank;
  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
    // Have mismatching dims (in the sense of vector.broadcast semantics) been
    // encountered?
    bool foundMismatchingDims = false;

    // Check fixed-width dims: each src dim must equal the aligned dst dim or
    // be a broadcastable unit (1).
    int64_t srcDim = srcVectorType.getDimSize(idx: dimIdx);
    int64_t dstDim = dstVectorType.getDimSize(idx: lead + dimIdx);
    if (srcDim != 1 && srcDim != dstDim)
      foundMismatchingDims = true;

    // Check scalable flags. A scalable unit dim ([1]) may only broadcast to
    // another scalable unit dim; the only allowed fixed/scalable mix is a
    // fixed 1 broadcasting to a scalable dim ([N]).
    bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
    bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
    if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
        // 1 -> [N] is fine, everything else should be rejected when mixing
        // fixed-width and scalable dims
        (srcDimScalableFlag != dstDimScalableFlag &&
         (srcDim != 1 || srcDimScalableFlag)))
      foundMismatchingDims = true;

    if (foundMismatchingDims) {
      if (mismatchingDims != nullptr) {
        mismatchingDims->first.dim = srcDim;
        mismatchingDims->first.isScalable = srcDimScalableFlag;

        mismatchingDims->second.dim = dstDim;
        mismatchingDims->second.isScalable = dstDimScalableFlag;
      }
      return BroadcastableToResult::DimensionMismatch;
    }
  }

  return BroadcastableToResult::Success;
}
2852
2853LogicalResult BroadcastOp::verify() {
2854 std::pair<VectorDim, VectorDim> mismatchingDims;
2855 BroadcastableToResult res = isBroadcastableTo(
2856 srcType: getSourceType(), dstVectorType: getResultVectorType(), mismatchingDims: &mismatchingDims);
2857 if (res == BroadcastableToResult::Success)
2858 return success();
2859 if (res == BroadcastableToResult::SourceRankHigher)
2860 return emitOpError(message: "source rank higher than destination rank");
2861 if (res == BroadcastableToResult::DimensionMismatch) {
2862 return emitOpError(message: "dimension mismatch (")
2863 << (mismatchingDims.first.isScalable ? "[" : "")
2864 << mismatchingDims.first.dim
2865 << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
2866 << (mismatchingDims.second.isScalable ? "[" : "")
2867 << mismatchingDims.second.dim
2868 << (mismatchingDims.second.isScalable ? "]" : "") << ")";
2869 }
2870 if (res == BroadcastableToResult::SourceTypeNotAVector)
2871 return emitOpError(message: "source type is not a vector");
2872 llvm_unreachable("unexpected vector.broadcast op error");
2873}
2874
2875OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
2876 if (getSourceType() == getResultVectorType())
2877 return getSource();
2878 if (!adaptor.getSource())
2879 return {};
2880 auto vectorType = getResultVectorType();
2881 if (auto attr = llvm::dyn_cast<IntegerAttr>(Val: adaptor.getSource())) {
2882 if (vectorType.getElementType() != attr.getType())
2883 return {};
2884 return DenseElementsAttr::get(type: vectorType, values: attr);
2885 }
2886 if (auto attr = llvm::dyn_cast<FloatAttr>(Val: adaptor.getSource())) {
2887 if (vectorType.getElementType() != attr.getType())
2888 return {};
2889 return DenseElementsAttr::get(type: vectorType, values: attr);
2890 }
2891 if (auto attr = llvm::dyn_cast<SplatElementsAttr>(Val: adaptor.getSource()))
2892 return DenseElementsAttr::get(type: vectorType, values: attr.getSplatValue<Attribute>());
2893 if (llvm::dyn_cast<ub::PoisonAttr>(Val: adaptor.getSource()))
2894 return ub::PoisonAttr::get(ctx: getContext());
2895 return {};
2896}
2897
2898namespace {
2899
2900// Fold broadcast1(broadcast2(x)) into broadcast1(x).
2901struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
2902 using OpRewritePattern::OpRewritePattern;
2903
2904 LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
2905 PatternRewriter &rewriter) const override {
2906 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
2907 if (!srcBroadcast)
2908 return failure();
2909 rewriter.replaceOpWithNewOp<BroadcastOp>(op: broadcastOp,
2910 args: broadcastOp.getResultVectorType(),
2911 args: srcBroadcast.getSource());
2912 return success();
2913 }
2914};
2915} // namespace
2916
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
  // calling `populateCastAwayVectorLeadingOneDimPatterns`.
  // BroadcastFolder collapses broadcast-of-broadcast chains.
  results.add<BroadcastFolder>(arg&: context);
}
2923
2924//===----------------------------------------------------------------------===//
2925// ShuffleOp
2926//===----------------------------------------------------------------------===//
2927
LogicalResult ShuffleOp::verify() {
  VectorType resultType = getResultVectorType();
  VectorType v1Type = getV1VectorType();
  VectorType v2Type = getV2VectorType();
  // Verify ranks: either both inputs are 0-D producing a 1-D result, or all
  // three vectors share the same rank.
  int64_t resRank = resultType.getRank();
  int64_t v1Rank = v1Type.getRank();
  int64_t v2Rank = v2Type.getRank();
  bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
  bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
  if (!wellFormed0DCase && !wellFormedNDCase)
    return emitOpError(message: "rank mismatch");

  // Verify all but leading dimension sizes: the shuffle only mixes along the
  // leading dimension, so the trailing dims must agree across all operands.
  for (int64_t r = 1; r < v1Rank; ++r) {
    int64_t resDim = resultType.getDimSize(idx: r);
    int64_t v1Dim = v1Type.getDimSize(idx: r);
    int64_t v2Dim = v2Type.getDimSize(idx: r);
    if (resDim != v1Dim || v1Dim != v2Dim)
      return emitOpError(message: "dimension mismatch");
  }
  // Verify mask length: it determines the result's leading dimension.
  ArrayRef<int64_t> mask = getMask();
  int64_t maskLength = mask.size();
  if (maskLength <= 0)
    return emitOpError(message: "invalid mask length");
  if (maskLength != resultType.getDimSize(idx: 0))
    return emitOpError(message: "mask length mismatch");
  // Verify all indices: each mask entry indexes into the concatenation of v1
  // and v2 along the leading dimension (0-D operands count as size 1), or is
  // the poison sentinel.
  int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(idx: 0)) +
                      (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(idx: 0));
  for (auto [idx, maskPos] : llvm::enumerate(First&: mask)) {
    if (!isValidPositiveIndexOrPoison(index: maskPos, poisonValue: kPoisonIndex, maxIndex: indexSize))
      return emitOpError(message: "mask index #") << (idx + 1) << " out of range";
  }
  return success();
}
2965
2966LogicalResult
2967ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
2968 ShuffleOp::Adaptor adaptor,
2969 SmallVectorImpl<Type> &inferredReturnTypes) {
2970 auto v1Type = llvm::cast<VectorType>(Val: adaptor.getV1().getType());
2971 auto v1Rank = v1Type.getRank();
2972 // Construct resulting type: leading dimension matches mask
2973 // length, all trailing dimensions match the operands.
2974 SmallVector<int64_t, 4> shape;
2975 shape.reserve(N: v1Rank);
2976 shape.push_back(Elt: std::max<size_t>(a: 1, b: adaptor.getMask().size()));
2977 // In the 0-D case there is no trailing shape to append.
2978 if (v1Rank > 0)
2979 llvm::append_range(C&: shape, R: v1Type.getShape().drop_front());
2980 inferredReturnTypes.push_back(
2981 Elt: VectorType::get(shape, elementType: v1Type.getElementType()));
2982 return success();
2983}
2984
2985template <typename T>
2986static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
2987 T expected = begin;
2988 return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
2989 return value == expected++;
2990 });
2991}
2992
OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
  auto v1Type = getV1VectorType();
  auto v2Type = getV2VectorType();

  assert(!v1Type.isScalable() && !v2Type.isScalable() &&
         "Vector shuffle does not support scalable vectors");

  // For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
  // but must be a canonicalization into a vector.broadcast.
  if (v1Type.getRank() == 0)
    return {};

  // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
  auto mask = getMask();
  if (isStepIndexArray(idxArr: mask, begin: 0, width: v1Type.getDimSize(idx: 0)))
    return getV1();
  // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
  if (isStepIndexArray(idxArr: mask, begin: v1Type.getDimSize(idx: 0), width: v2Type.getDimSize(idx: 0)))
    return getV2();

  // Beyond the identity folds above, both inputs must be constants.
  Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
  if (!v1Attr || !v2Attr)
    return {};

  // Fold shuffle poison, poison -> poison.
  bool isV1Poison = isa<ub::PoisonAttr>(Val: v1Attr);
  bool isV2Poison = isa<ub::PoisonAttr>(Val: v2Attr);
  if (isV1Poison && isV2Poison)
    return ub::PoisonAttr::get(ctx: getContext());

  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
  // manipulation.
  if (v1Type.getRank() != 1)
    return {};

  // Poison input attributes need special handling as they are not
  // DenseElementsAttr. If an index is poison, we select the first element of
  // the first non-poison input.
  SmallVector<Attribute> v1Elements, v2Elements;
  Attribute poisonElement;
  if (!isV2Poison) {
    v2Elements =
        to_vector(Range: cast<DenseElementsAttr>(Val&: v2Attr).getValues<Attribute>());
    poisonElement = v2Elements[0];
  }
  if (!isV1Poison) {
    v1Elements =
        to_vector(Range: cast<DenseElementsAttr>(Val&: v1Attr).getValues<Attribute>());
    poisonElement = v1Elements[0];
  }

  // Gather the selected element for each mask entry; indices below v1's size
  // select from V1, the rest select from V2 (offset by v1's size).
  SmallVector<Attribute> results;
  int64_t v1Size = v1Type.getDimSize(idx: 0);
  for (int64_t maskIdx : mask) {
    Attribute indexedElm;
    // TODO: Return a partial poison vector when supported by the UB dialect.
    if (maskIdx == ShuffleOp::kPoisonIndex) {
      indexedElm = poisonElement;
    } else {
      if (maskIdx < v1Size)
        indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
      else
        indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
    }

    results.push_back(Elt: indexedElm);
  }

  return DenseElementsAttr::get(type: getResultVectorType(), values: results);
}
3063
3064namespace {
3065
3066// Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
3067// to a broadcast.
3068struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
3069 using OpRewritePattern::OpRewritePattern;
3070
3071 LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
3072 PatternRewriter &rewriter) const override {
3073 VectorType v1VectorType = shuffleOp.getV1VectorType();
3074 ArrayRef<int64_t> mask = shuffleOp.getMask();
3075 if (v1VectorType.getRank() > 0)
3076 return failure();
3077 if (mask.size() != 1)
3078 return failure();
3079 VectorType resType = VectorType::Builder(v1VectorType).setShape(newShape: {1});
3080 if (mask[0] == 0)
3081 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op: shuffleOp, args&: resType,
3082 args: shuffleOp.getV1());
3083 else
3084 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op: shuffleOp, args&: resType,
3085 args: shuffleOp.getV2());
3086 return success();
3087 }
3088};
3089
3090/// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
3091class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
3092public:
3093 using OpRewritePattern::OpRewritePattern;
3094
3095 LogicalResult matchAndRewrite(ShuffleOp op,
3096 PatternRewriter &rewriter) const override {
3097 auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
3098 auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
3099
3100 if (!v1Splat || !v2Splat)
3101 return failure();
3102
3103 if (v1Splat.getInput() != v2Splat.getInput())
3104 return failure();
3105
3106 rewriter.replaceOpWithNewOp<SplatOp>(op, args: op.getType(), args: v1Splat.getInput());
3107 return success();
3108 }
3109};
3110
3111/// Pattern to rewrite a fixed-size interleave via vector.shuffle to
3112/// vector.interleave.
3113class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
3114public:
3115 using OpRewritePattern::OpRewritePattern;
3116
3117 LogicalResult matchAndRewrite(ShuffleOp op,
3118 PatternRewriter &rewriter) const override {
3119 VectorType resultType = op.getResultVectorType();
3120 if (resultType.isScalable())
3121 return rewriter.notifyMatchFailure(
3122 arg&: op, msg: "ShuffleOp can't represent a scalable interleave");
3123
3124 if (resultType.getRank() != 1)
3125 return rewriter.notifyMatchFailure(
3126 arg&: op, msg: "ShuffleOp can't represent an n-D interleave");
3127
3128 VectorType sourceType = op.getV1VectorType();
3129 if (sourceType != op.getV2VectorType() ||
3130 sourceType.getNumElements() * 2 != resultType.getNumElements()) {
3131 return rewriter.notifyMatchFailure(
3132 arg&: op, msg: "ShuffleOp types don't match an interleave");
3133 }
3134
3135 ArrayRef<int64_t> shuffleMask = op.getMask();
3136 int64_t resultVectorSize = resultType.getNumElements();
3137 for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
3138 int64_t maskValueA = shuffleMask[i * 2];
3139 int64_t maskValueB = shuffleMask[(i * 2) + 1];
3140 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
3141 return rewriter.notifyMatchFailure(arg&: op,
3142 msg: "ShuffleOp mask not interleaving");
3143 }
3144
3145 rewriter.replaceOpWithNewOp<InterleaveOp>(op, args: op.getV1(), args: op.getV2());
3146 return success();
3147 }
3148};
3149
3150} // namespace
3151
void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Register the shuffle canonicalizations defined above: splat merging,
  // interleave-mask recognition, and 0-D shuffle-to-broadcast rewriting.
  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
      arg&: context);
}
3157
3158//===----------------------------------------------------------------------===//
3159// InsertElementOp
3160//===----------------------------------------------------------------------===//
3161
void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRanges) {
  // The result contains elements from both the inserted scalar (argRanges[0])
  // and the destination vector (argRanges[1]), so its range is their union.
  setResultRanges(getResult(), argRanges[0].rangeUnion(other: argRanges[1]));
}
3166
// Builds an insertelement into a 0-D destination: no position operand.
void InsertElementOp::build(OpBuilder &builder, OperationState &result,
                            Value source, Value dest) {
  build(odsBuilder&: builder, odsState&: result, source, dest, position: {});
}
3171
3172LogicalResult InsertElementOp::verify() {
3173 auto dstVectorType = getDestVectorType();
3174 if (dstVectorType.getRank() == 0) {
3175 if (getPosition())
3176 return emitOpError(message: "expected position to be empty with 0-D vector");
3177 return success();
3178 }
3179 if (dstVectorType.getRank() != 1)
3180 return emitOpError(message: "unexpected >1 vector rank");
3181 if (!getPosition())
3182 return emitOpError(message: "expected position for 1-D vector");
3183 return success();
3184}
3185
3186OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
3187 // Skip the 0-D vector here.
3188 if (!adaptor.getPosition())
3189 return {};
3190
3191 auto src = dyn_cast_or_null<TypedAttr>(Val: adaptor.getSource());
3192 auto dst = dyn_cast_or_null<DenseElementsAttr>(Val: adaptor.getDest());
3193 auto pos = dyn_cast_or_null<IntegerAttr>(Val: adaptor.getPosition());
3194 if (!src || !dst || !pos)
3195 return {};
3196
3197 if (src.getType() != getDestVectorType().getElementType())
3198 return {};
3199
3200 auto dstElements = dst.getValues<Attribute>();
3201
3202 SmallVector<Attribute> results(dstElements);
3203
3204 uint64_t posIdx = pos.getInt();
3205 if (posIdx >= results.size())
3206 return {};
3207 results[posIdx] = src;
3208
3209 return DenseElementsAttr::get(type: getDestVectorType(), values: results);
3210}
3211
3212//===----------------------------------------------------------------------===//
3213// InsertOp
3214//===----------------------------------------------------------------------===//
3215
void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                         SetIntRangeFn setResultRanges) {
  // Result values come either from the inserted value (argRanges[0]) or the
  // untouched destination elements (argRanges[1]); union the two ranges.
  setResultRanges(getResult(), argRanges[0].rangeUnion(other: argRanges[1]));
}
3220
3221void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3222 Value source, Value dest) {
3223 auto vectorTy = cast<VectorType>(Val: dest.getType());
3224 build(odsBuilder&: builder, odsState&: result, valueToStore: source, dest,
3225 position: SmallVector<int64_t>(vectorTy.getRank(), 0));
3226}
3227
// Builds an insert at a single static position.
void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest, int64_t position) {
  build(odsBuilder&: builder, odsState&: result, valueToStore: source, dest, position: ArrayRef<int64_t>{position});
}
3232
// Builds an insert at a single position given as an OpFoldResult (attribute
// or SSA value).
void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest, OpFoldResult position) {
  build(odsBuilder&: builder, odsState&: result, valueToStore: source, dest, position: ArrayRef<OpFoldResult>{position});
}
3237
3238void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3239 Value source, Value dest,
3240 ArrayRef<int64_t> position) {
3241 SmallVector<OpFoldResult> posVals;
3242 posVals.reserve(N: position.size());
3243 llvm::transform(Range&: position, d_first: std::back_inserter(x&: posVals),
3244 F: [&](int64_t pos) { return builder.getI64IntegerAttr(value: pos); });
3245 build(odsBuilder&: builder, odsState&: result, valueToStore: source, dest, position: posVals);
3246}
3247
3248void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3249 Value source, Value dest,
3250 ArrayRef<OpFoldResult> position) {
3251 SmallVector<int64_t> staticPos;
3252 SmallVector<Value> dynamicPos;
3253 dispatchIndexOpFoldResults(ofrs: position, dynamicVec&: dynamicPos, staticVec&: staticPos);
3254 build(odsBuilder&: builder, odsState&: result, valueToStore: source, dest, dynamic_position: dynamicPos,
3255 static_position: builder.getDenseI64ArrayAttr(values: staticPos));
3256}
3257
LogicalResult InsertOp::verify() {
  // A 0-D vector source is rejected; callers must pass a plain scalar.
  if (auto srcTy = dyn_cast<VectorType>(Val: getValueToStoreType()))
    if (srcTy.getRank() == 0)
      return emitError(
          message: "expected a scalar instead of a 0-d vector as the source operand");

  SmallVector<OpFoldResult> position = getMixedPosition();
  auto destVectorType = getDestVectorType();
  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        message: "expected position attribute of rank no greater than dest vector rank");
  // Vector source: position rank + source rank must equal the dest rank.
  auto srcVectorType = llvm::dyn_cast<VectorType>(Val: getValueToStoreType());
  if (srcVectorType &&
      (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
       static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError(message: "expected position attribute rank + source rank to "
                       "match dest vector rank");
  // Scalar source: the position must address a single element, i.e. have the
  // same rank as the destination vector.
  if (!srcVectorType &&
      (position.size() != static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError(
        message: "expected position attribute rank to match the dest vector rank");
  // Each static index must be in-bounds for its destination dimension or be
  // the poison sentinel; dynamic (SSA) indices cannot be checked here.
  for (auto [idx, pos] : llvm::enumerate(First&: position)) {
    if (auto attr = dyn_cast<Attribute>(Val&: pos)) {
      int64_t constIdx = cast<IntegerAttr>(Val&: attr).getInt();
      if (!isValidPositiveIndexOrPoison(index: constIdx, poisonValue: kPoisonIndex,
                                        maxIndex: destVectorType.getDimSize(idx))) {
        return emitOpError(message: "expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding "
                  "dest vector dimension";
      }
    }
  }
  return success();
}
3294
3295namespace {
3296
3297// If insertOp is only inserting unit dimensions it can be transformed to a
3298// broadcast.
3299class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
3300public:
3301 using OpRewritePattern::OpRewritePattern;
3302
3303 LogicalResult matchAndRewrite(InsertOp insertOp,
3304 PatternRewriter &rewriter) const override {
3305 auto srcVecType =
3306 llvm::dyn_cast<VectorType>(Val: insertOp.getValueToStoreType());
3307 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3308 srcVecType.getNumElements())
3309 return failure();
3310 rewriter.replaceOpWithNewOp<BroadcastOp>(
3311 op: insertOp, args: insertOp.getDestVectorType(), args: insertOp.getValueToStore());
3312 return success();
3313 }
3314};
3315
3316/// Pattern to rewrite a InsertOp(SplatOp, SplatOp) to SplatOp.
3317class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
3318public:
3319 using OpRewritePattern::OpRewritePattern;
3320
3321 LogicalResult matchAndRewrite(InsertOp op,
3322 PatternRewriter &rewriter) const override {
3323 auto srcSplat = op.getValueToStore().getDefiningOp<SplatOp>();
3324 auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
3325
3326 if (!srcSplat || !dstSplat)
3327 return failure();
3328
3329 if (srcSplat.getInput() != dstSplat.getInput())
3330 return failure();
3331
3332 rewriter.replaceOpWithNewOp<SplatOp>(op, args: op.getType(), args: srcSplat.getInput());
3333 return success();
3334 }
3335};
3336} // namespace
3337
/// Folds an InsertOp with constant source and constant (DenseElementsAttr)
/// destination into a single constant by overwriting the destination's
/// elements at the (static) insertion position. Returns a null Attribute when
/// the fold does not apply (dynamic position, non-constant operands, scalable
/// destination, or oversized multi-use result).
static Attribute
foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
                                  Attribute dstAttr,
                                  int64_t maxVectorSizeFoldThreshold) {
  // Only static positions can be linearized below.
  if (insertOp.hasDynamicPosition())
    return {};

  auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(Val&: dstAttr);
  if (!denseDst)
    return {};

  if (!srcAttr) {
    return {};
  }

  // A scalable destination has no fixed element count to index into.
  VectorType destTy = insertOp.getDestVectorType();
  if (destTy.isScalable())
    return {};

  // Make sure we do not create too many large constants.
  if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
      !insertOp->hasOneUse())
    return {};

  // Calculate the linearized position of the continuous chunk of elements to
  // insert.
  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
  copy(Range: insertOp.getStaticPosition(), Out: completePositions.begin());
  int64_t insertBeginPosition =
      linearize(offsets: completePositions, basis: computeStrides(sizes: destTy.getShape()));

  SmallVector<Attribute> insertedValues;
  Type destEltType = destTy.getElementType();

  /// Collect the elements to write, converting each through
  /// convertIntegerAttr so the expected (destination) element type is used on
  /// a mismatch. A non-dense source is a single scalar element.
  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(Val&: srcAttr)) {
    for (auto value : denseSource.getValues<Attribute>())
      insertedValues.push_back(Elt: convertIntegerAttr(attr: value, expectedType: destEltType));
  } else {
    insertedValues.push_back(Elt: convertIntegerAttr(attr: srcAttr, expectedType: destEltType));
  }

  // Overwrite the contiguous chunk starting at the linearized position.
  auto allValues = llvm::to_vector(Range: denseDst.getValues<Attribute>());
  copy(Range&: insertedValues, Out: allValues.begin() + insertBeginPosition);
  auto newAttr = DenseElementsAttr::get(type: destTy, values: allValues);

  return newAttr;
}
3387
3388/// Folder to replace the `dest` operand of the insert op with the root dest of
3389/// the insert op use chain.
3390static Value foldInsertUseChain(InsertOp insertOp) {
3391 auto destInsert = insertOp.getDest().getDefiningOp<InsertOp>();
3392 if (!destInsert)
3393 return {};
3394
3395 if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
3396 return {};
3397
3398 insertOp.setOperand(i: 1, value: destInsert.getDest());
3399 return insertOp.getResult();
3400}
3401
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Register insert canonicalizations: full-coverage inserts become
  // broadcasts, broadcasts are folded, and splat-into-splat collapses.
  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(arg&: context);
}
3406
OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the source vector constant has a single use.
  constexpr int64_t vectorSizeFoldThreshold = 256;
  // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
  // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
  // (type mismatch).
  if (getNumIndices() == 0 && getValueToStoreType() == getType())
    return getValueToStore();
  // Fold `arith.constant` indices into the `vector.insert` operation.
  // Do not stop here as this fold may enable subsequent folds that require
  // constant indices.
  SmallVector<Value> operands = {getValueToStore(), getDest()};
  auto inplaceFolded = extractInsertFoldConstantOp(op: *this, adaptor, operands);

  // Chained inserts at the identical position: only the last write matters.
  if (auto res = foldInsertUseChain(insertOp: *this))
    return res;
  // A poison static index folds the whole result to poison.
  if (auto res = foldPoisonIndexInsertExtractOp(
          context: getContext(), staticPos: adaptor.getStaticPosition(), poisonVal: kPoisonIndex))
    return res;
  // Constant source inserted into constant destination folds to a constant.
  if (auto res = foldDenseElementsAttrDestInsertOp(
          insertOp: *this, srcAttr: adaptor.getValueToStore(), dstAttr: adaptor.getDest(),
          maxVectorSizeFoldThreshold: vectorSizeFoldThreshold)) {
    return res;
  }

  return inplaceFolded;
}
3435
3436//===----------------------------------------------------------------------===//
3437// InsertStridedSliceOp
3438//===----------------------------------------------------------------------===//
3439
3440void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3441 Value source, Value dest,
3442 ArrayRef<int64_t> offsets,
3443 ArrayRef<int64_t> strides) {
3444 result.addOperands(newOperands: {source, dest});
3445 auto offsetsAttr = getVectorSubscriptAttr(builder, values: offsets);
3446 auto stridesAttr = getVectorSubscriptAttr(builder, values: strides);
3447 result.addTypes(newTypes: dest.getType());
3448 result.addAttribute(name: InsertStridedSliceOp::getOffsetsAttrName(name: result.name),
3449 attr: offsetsAttr);
3450 result.addAttribute(name: InsertStridedSliceOp::getStridesAttrName(name: result.name),
3451 attr: stridesAttr);
3452}
3453
3454// TODO: Should be moved to Tablegen ConfinedAttr attributes.
3455template <typename OpType>
3456static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
3457 ArrayAttr arrayAttr,
3458 ArrayRef<int64_t> shape,
3459 StringRef attrName) {
3460 if (arrayAttr.size() > shape.size())
3461 return op.emitOpError("expected ")
3462 << attrName << " attribute of rank no greater than vector rank";
3463 return success();
3464}
3465
// Verifies that all integers in `arrayAttr` lie in the half-open interval
// [min, max). If `halfOpen` is false, the admissible interval is instead the
// closed interval [min, max].
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
                                  int64_t max, StringRef attrName,
                                  bool halfOpen = true) {
  for (auto attr : arrayAttr) {
    auto val = llvm::cast<IntegerAttr>(Val&: attr).getInt();
    // Normalize a closed interval to a half-open one by bumping the bound.
    auto upper = max;
    if (!halfOpen)
      upper += 1;
    if (val < min || val >= upper)
      return op.emitOpError("expected ") << attrName << " to be confined to ["
                                         << min << ", " << upper << ")";
  }
  return success();
}
3485
// Verifies that, for each position i, `arrayAttr[i]` lies in the half-open
// interval [min, shape[i]). If `halfOpen` is false, the admissible interval
// is instead the closed interval [min, shape[i]]. Attribute entries beyond
// the zipped prefix are ignored (zip_first).
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
                                  ArrayRef<int64_t> shape, StringRef attrName,
                                  bool halfOpen = true, int64_t min = 0) {
  for (auto [index, attrDimPair] :
       llvm::enumerate(First: llvm::zip_first(t&: arrayAttr, u&: shape))) {
    int64_t val = llvm::cast<IntegerAttr>(Val: std::get<0>(t&: attrDimPair)).getInt();
    int64_t max = std::get<1>(t&: attrDimPair);
    // Normalize a closed interval to a half-open one by bumping the bound.
    if (!halfOpen)
      max += 1;
    if (val < min || val >= max)
      return op.emitOpError("expected ")
             << attrName << " dimension " << index << " to be confined to ["
             << min << ", " << max << ")";
  }
  return success();
}
3507
// Verifies that, for each position i, the sum
//   val = `arrayAttr1[i]` + `arrayAttr2[i]`
// lies in the half-open interval [0, shape[i]). If `halfOpen` is false, the
// upper bound becomes inclusive.
// NOTE(review): the lower-bound check uses 0, while the diagnostic prints
// `min` (default 1). Presumably the 0 is intentional — the caller zero-pads
// leading source dims, making a sum of 0 legitimate — but the diagnostic's
// lower bound is then misleading; confirm before "fixing" either side.
template <typename OpType>
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
    bool halfOpen = true, int64_t min = 1) {
  assert(arrayAttr1.size() <= shape.size());
  assert(arrayAttr2.size() <= shape.size());
  for (auto [index, it] :
       llvm::enumerate(First: llvm::zip(t&: arrayAttr1, u&: arrayAttr2, args&: shape))) {
    auto val1 = llvm::cast<IntegerAttr>(Val: std::get<0>(t&: it)).getInt();
    auto val2 = llvm::cast<IntegerAttr>(Val: std::get<1>(t&: it)).getInt();
    int64_t max = std::get<2>(t&: it);
    if (!halfOpen)
      max += 1;
    if (val1 + val2 < 0 || val1 + val2 >= max)
      return op.emitOpError("expected sum(")
             << attrName1 << ", " << attrName2 << ") dimension " << index
             << " to be confined to [" << min << ", " << max << ")";
  }
  return success();
}
3534
3535static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
3536 MLIRContext *context) {
3537 auto attrs = llvm::map_range(C&: values, F: [context](int64_t v) -> Attribute {
3538 return IntegerAttr::get(type: IntegerType::get(context, width: 64), value: APInt(64, v));
3539 });
3540 return ArrayAttr::get(context, value: llvm::to_vector<8>(Range&: attrs));
3541}
3542
LogicalResult InsertStridedSliceOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto destVectorType = getDestVectorType();
  auto offsets = getOffsetsAttr();
  auto strides = getStridesAttr();
  // Offsets are given per destination dimension; strides per source
  // dimension. The source may have lower rank than the destination.
  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        message: "expected offsets of same size as destination vector rank");
  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
    return emitOpError(message: "expected strides of same size as source vector rank");
  if (sourceVectorType.getRank() > destVectorType.getRank())
    return emitOpError(
        message: "expected source rank to be no greater than destination rank");

  // Left-pad the source shape with zeros so it lines up with the destination
  // rank for the offset+extent bound check below.
  auto sourceShape = sourceVectorType.getShape();
  auto destShape = destVectorType.getShape();
  SmallVector<int64_t, 4> sourceShapeAsDestShape(
      destShape.size() - sourceShape.size(), 0);
  sourceShapeAsDestShape.append(in_start: sourceShape.begin(), in_end: sourceShape.end());
  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
  // Offsets must be in-bounds, all strides must equal 1, and each offset plus
  // the corresponding source extent must fit within the destination.
  if (failed(Result: isIntegerArrayAttrConfinedToShape(op: *this, arrayAttr: offsets, shape: destShape,
                                              attrName: offName)) ||
      failed(Result: isIntegerArrayAttrConfinedToRange(op: *this, arrayAttr: strides, /*min=*/1,
                                              /*max=*/1, attrName: stridesName,
                                              /*halfOpen=*/false)) ||
      failed(Result: isSumOfIntegerArrayAttrConfinedToShape(
          op: *this, arrayAttr1: offsets,
          arrayAttr2: makeI64ArrayAttr(values: sourceShapeAsDestShape, context: getContext()), shape: destShape,
          attrName1: offName, attrName2: "source vector shape",
          /*halfOpen=*/false, /*min=*/1)))
    return failure();

  // Scalable flags must match pairwise, and any scalable source dim must have
  // exactly the destination's base size (no partial scalable insertion).
  unsigned rankDiff = destShape.size() - sourceShape.size();
  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
    if (sourceVectorType.getScalableDims()[idx] !=
        destVectorType.getScalableDims()[idx + rankDiff]) {
      return emitOpError(message: "mismatching scalable flags (at source vector idx=")
             << idx << ")";
    }
    if (sourceVectorType.getScalableDims()[idx]) {
      auto sourceSize = sourceShape[idx];
      auto destSize = destShape[idx + rankDiff];
      if (sourceSize != destSize) {
        return emitOpError(message: "expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << sourceSize << (" vs ") << destSize << (")");
      }
    }
  }

  return success();
}
3598
3599namespace {
3600/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
3601/// SplatOp(X):dst_type) to SplatOp(X):dst_type.
3602class FoldInsertStridedSliceSplat final
3603 : public OpRewritePattern<InsertStridedSliceOp> {
3604public:
3605 using OpRewritePattern::OpRewritePattern;
3606
3607 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3608 PatternRewriter &rewriter) const override {
3609 auto srcSplatOp =
3610 insertStridedSliceOp.getValueToStore().getDefiningOp<vector::SplatOp>();
3611 auto destSplatOp =
3612 insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
3613
3614 if (!srcSplatOp || !destSplatOp)
3615 return failure();
3616
3617 if (srcSplatOp.getInput() != destSplatOp.getInput())
3618 return failure();
3619
3620 rewriter.replaceOp(op: insertStridedSliceOp, newValues: insertStridedSliceOp.getDest());
3621 return success();
3622 }
3623};
3624
3625/// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
3626/// to dst.
3627class FoldInsertStridedSliceOfExtract final
3628 : public OpRewritePattern<InsertStridedSliceOp> {
3629public:
3630 using OpRewritePattern::OpRewritePattern;
3631
3632 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3633 PatternRewriter &rewriter) const override {
3634 auto extractStridedSliceOp =
3635 insertStridedSliceOp.getValueToStore()
3636 .getDefiningOp<vector::ExtractStridedSliceOp>();
3637
3638 if (!extractStridedSliceOp)
3639 return failure();
3640
3641 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3642 return failure();
3643
3644 // Check if have the same strides and offsets.
3645 if (extractStridedSliceOp.getStrides() !=
3646 insertStridedSliceOp.getStrides() ||
3647 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3648 return failure();
3649
3650 rewriter.replaceOp(op: insertStridedSliceOp, newValues: insertStridedSliceOp.getDest());
3651 return success();
3652 }
3653};
3654
// Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
// ConstantOp.
class InsertStridedSliceConstantFolder final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the destination vector constant has a single use.
  static constexpr int64_t vectorSizeFoldThreshold = 256;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    // Return if 'InsertOp' operand is not defined by a compatible vector
    // ConstantOp.
    TypedValue<VectorType> destVector = op.getDest();
    Attribute vectorDestCst;
    if (!matchPattern(value: destVector, pattern: m_Constant(bind_value: &vectorDestCst)))
      return failure();

    // Scalable destinations have no fixed element count to enumerate.
    VectorType destTy = destVector.getType();
    if (destTy.isScalable())
      return failure();

    // Make sure we do not create too many large constants.
    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
        !destVector.hasOneUse())
      return failure();

    TypedValue<VectorType> sourceValue = op.getValueToStore();
    Attribute sourceCst;
    if (!matchPattern(value: sourceValue, pattern: m_Constant(bind_value: &sourceCst)))
      return failure();

    // TODO: Support poison.
    if (isa<ub::PoisonAttr>(Val: vectorDestCst) || isa<ub::PoisonAttr>(Val: sourceCst))
      return failure();

    // TODO: Handle non-unit strides when they become available.
    if (op.hasNonUnitStrides())
      return failure();

    VectorType sliceVecTy = sourceValue.getType();
    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
    int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
    SmallVector<int64_t, 4> offsets = getI64SubArray(arrayAttr: op.getOffsets());
    SmallVector<int64_t, 4> destStrides = computeStrides(sizes: destTy.getShape());

    // Calculate the destination element indices by enumerating all slice
    // positions within the destination and linearizing them. The enumeration
    // order is lexicographic which yields a sequence of monotonically
    // increasing linearized position indices.
    // Because the destination may have higher dimensionality than the slice,
    // we keep track of two overlapping sets of positions and offsets.
    auto denseDest = llvm::cast<DenseElementsAttr>(Val&: vectorDestCst);
    auto denseSlice = llvm::cast<DenseElementsAttr>(Val&: sourceCst);
    auto sliceValuesIt = denseSlice.value_begin<Attribute>();
    auto newValues = llvm::to_vector(Range: denseDest.getValues<Attribute>());
    SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
    MutableArrayRef<int64_t> currSlicePosition(
        currDestPosition.begin() + rankDifference, currDestPosition.end());
    ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
                                   offsets.end());
    do {
      // Write one slice element into its linearized destination slot, then
      // advance the trailing (slice-covered) part of the destination position
      // until the whole slice has been copied.
      int64_t linearizedPosition = linearize(offsets: currDestPosition, basis: destStrides);
      assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
      assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
             "Invalid slice element");
      newValues[linearizedPosition] = *sliceValuesIt;
      ++sliceValuesIt;
    } while (succeeded(
        Result: incSlicePosition(position: currSlicePosition, shape: sliceShape, offsets: sliceOffsets)));

    auto newAttr = DenseElementsAttr::get(type: destTy, values: newValues);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, args&: newAttr);
    return success();
  }
};
3733
3734} // namespace
3735
void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // Register the strided-slice insert canonicalizations defined above.
  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
              InsertStridedSliceConstantFolder>(arg&: context);
}
3741
3742OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
3743 if (getSourceVectorType() == getDestVectorType())
3744 return getValueToStore();
3745 return {};
3746}
3747
3748//===----------------------------------------------------------------------===//
3749// OuterProductOp
3750//===----------------------------------------------------------------------===//
3751
/// Build an op without mask, use the type of `acc` as the return type.
void OuterProductOp::build(OpBuilder &builder, OperationState &result,
                           Value lhs, Value rhs, Value acc) {
  // NOTE(review): no `kind` attribute is attached here; presumably the
  // attribute's default applies (the parser sets one explicitly) — confirm
  // against the op's ODS definition.
  result.addOperands(newOperands: {lhs, rhs, acc});
  result.addTypes(newTypes: acc.getType());
}
3758
void OuterProductOp::print(OpAsmPrinter &p) {
  p << " " << getLhs() << ", " << getRhs();
  if (getAcc()) {
    p << ", " << getAcc();
    // NOTE(review): the attribute dictionary is only printed when an
    // accumulator is present, although parse() accepts it either way —
    // confirm attributes can never be non-default without an accumulator.
    p.printOptionalAttrDict(attrs: (*this)->getAttrs());
  }
  p << " : " << getLhs().getType() << ", " << getRhs().getType();
}
3767
// Parses `lhs, rhs[, acc] {attrs} : lhsType, rhsType`; the result type is
// inferred from the operand types rather than spelled in the assembly.
ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
  Type tLHS, tRHS;
  if (parser.parseOperandList(result&: operandsInfo) ||
      parser.parseOptionalAttrDict(result&: result.attributes) ||
      parser.parseColonType(result&: tLHS) || parser.parseComma() ||
      parser.parseType(result&: tRHS))
    return failure();
  if (operandsInfo.size() < 2)
    return parser.emitError(loc: parser.getNameLoc(),
                            message: "expected at least 2 operands");
  VectorType vLHS = llvm::dyn_cast<VectorType>(Val&: tLHS);
  VectorType vRHS = llvm::dyn_cast<VectorType>(Val&: tRHS);
  if (!vLHS)
    return parser.emitError(loc: parser.getNameLoc(),
                            message: "expected vector type for operand #1");

  // Infer the result type: a 2-D outer product when the RHS is a vector, a
  // 1-D AXPY-style result when the RHS is a scalar. Scalability of each
  // result dim follows the corresponding operand's leading dim.
  VectorType resType;
  if (vRHS) {
    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
                                      vRHS.getScalableDims()[0]};
    resType = VectorType::get(shape: {vLHS.getDimSize(idx: 0), vRHS.getDimSize(idx: 0)},
                              elementType: vLHS.getElementType(), scalableDims: scalableDimsRes);
  } else {
    // Scalar RHS operand
    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
    resType = VectorType::get(shape: {vLHS.getDimSize(idx: 0)}, elementType: vLHS.getElementType(),
                              scalableDims: scalableDimsRes);
  }

  // Default the combining kind when the attribute dictionary omitted it.
  if (!result.attributes.get(name: OuterProductOp::getKindAttrName(name: result.name))) {
    result.attributes.append(
        name: OuterProductOp::getKindAttrName(name: result.name),
        attr: CombiningKindAttr::get(context: result.getContext(),
                                value: OuterProductOp::getDefaultKind()));
  }

  // The optional third operand (accumulator) has the inferred result type.
  return failure(
      IsFailure: parser.resolveOperand(operand: operandsInfo[0], type: tLHS, result&: result.operands) ||
      parser.resolveOperand(operand: operandsInfo[1], type: tRHS, result&: result.operands) ||
      (operandsInfo.size() > 2 &&
       parser.resolveOperand(operand: operandsInfo[2], type: resType, result&: result.operands)) ||
      parser.addTypeToList(type: resType, result&: result.types));
}
3812
LogicalResult OuterProductOp::verify() {
  Type tRHS = getOperandTypeRHS();
  VectorType vLHS = getOperandVectorTypeLHS(),
             vRHS = llvm::dyn_cast<VectorType>(Val&: tRHS),
             vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();

  if (vLHS.getRank() != 1)
    return emitOpError(message: "expected 1-d vector for operand #1");

  if (vRHS) {
    // Proper OUTER operation: 1-D x 1-D -> 2-D, dims matching pairwise.
    if (vRHS.getRank() != 1)
      return emitOpError(message: "expected 1-d vector for operand #2");
    if (vRES.getRank() != 2)
      return emitOpError(message: "expected 2-d vector result");
    if (vLHS.getDimSize(idx: 0) != vRES.getDimSize(idx: 0))
      return emitOpError(message: "expected #1 operand dim to match result dim #1");
    if (vRHS.getDimSize(idx: 0) != vRES.getDimSize(idx: 1))
      return emitOpError(message: "expected #2 operand dim to match result dim #2");
    if (vLHS.isScalable() && !vRHS.isScalable()) {
      // This restriction reflects what's currently supported in terms of
      // scalable vectors. However, we could relax this if there's a use case.
      return emitOpError(
          message: "expected either both or only #2 operand dim to be scalable");
    }
  } else {
    // An AXPY operation (scalar RHS): 1-D result matching the LHS dim.
    if (vRES.getRank() != 1)
      return emitOpError(message: "expected 1-d vector result");
    if (vLHS.getDimSize(idx: 0) != vRES.getDimSize(idx: 0))
      return emitOpError(message: "expected #1 operand dim to match result dim #1");
  }

  // The optional accumulator must exactly match the result type.
  if (vACC && vACC != vRES)
    return emitOpError(message: "expected operand #3 of same type as result type");

  // Verify supported combining kind.
  if (!isSupportedCombiningKind(combiningKind: getKind(), elementType: vRES.getElementType()))
    return emitOpError(message: "unsupported outerproduct type");

  return success();
}
3855
3856// MaskableOpInterface methods.
3857
3858/// Returns the mask type expected by this operation. Mostly used for
3859/// verification purposes. It requires the operation to be vectorized."
3860Type OuterProductOp::getExpectedMaskType() {
3861 auto vecType = this->getResultVectorType();
3862 return VectorType::get(shape: vecType.getShape(),
3863 elementType: IntegerType::get(context: vecType.getContext(), /*width=*/1),
3864 scalableDims: vecType.getScalableDims());
3865}
3866
3867//===----------------------------------------------------------------------===//
3868// ExtractStridedSliceOp
3869//===----------------------------------------------------------------------===//
3870
3871// Inference works as follows:
3872// 1. Add 'sizes' from prefix of dims in 'offsets'.
3873// 2. Add sizes from 'vectorType' for remaining dims.
3874// Scalable flags are inherited from 'vectorType'.
3875static Type inferStridedSliceOpResultType(VectorType vectorType,
3876 ArrayAttr offsets, ArrayAttr sizes,
3877 ArrayAttr strides) {
3878 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
3879 SmallVector<int64_t, 4> shape;
3880 shape.reserve(N: vectorType.getRank());
3881 unsigned idx = 0;
3882 for (unsigned e = offsets.size(); idx < e; ++idx)
3883 shape.push_back(Elt: llvm::cast<IntegerAttr>(Val: sizes[idx]).getInt());
3884 for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
3885 shape.push_back(Elt: vectorType.getShape()[idx]);
3886
3887 return VectorType::get(shape, elementType: vectorType.getElementType(),
3888 scalableDims: vectorType.getScalableDims());
3889}
3890
3891void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3892 Value source, ArrayRef<int64_t> offsets,
3893 ArrayRef<int64_t> sizes,
3894 ArrayRef<int64_t> strides) {
3895 result.addOperands(newOperands: source);
3896 auto offsetsAttr = getVectorSubscriptAttr(builder, values: offsets);
3897 auto sizesAttr = getVectorSubscriptAttr(builder, values: sizes);
3898 auto stridesAttr = getVectorSubscriptAttr(builder, values: strides);
3899 result.addTypes(
3900 newTypes: inferStridedSliceOpResultType(vectorType: llvm::cast<VectorType>(Val: source.getType()),
3901 offsets: offsetsAttr, sizes: sizesAttr, strides: stridesAttr));
3902 result.addAttribute(name: ExtractStridedSliceOp::getOffsetsAttrName(name: result.name),
3903 attr: offsetsAttr);
3904 result.addAttribute(name: ExtractStridedSliceOp::getSizesAttrName(name: result.name),
3905 attr: sizesAttr);
3906 result.addAttribute(name: ExtractStridedSliceOp::getStridesAttrName(name: result.name),
3907 attr: stridesAttr);
3908}
3909
LogicalResult ExtractStridedSliceOp::verify() {
  auto type = getSourceVectorType();
  auto offsets = getOffsetsAttr();
  auto sizes = getSizesAttr();
  auto strides = getStridesAttr();
  // The three attributes describe one slice spec per leading dimension, so
  // they must all have the same length.
  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
    return emitOpError(
        message: "expected offsets, sizes and strides attributes of same size");

  auto shape = type.getShape();
  auto offName = getOffsetsAttrName();
  auto sizesName = getSizesAttrName();
  auto stridesName = getStridesAttrName();
  // Check, in order: the attribute arrays are no longer than the source rank;
  // offsets lie in [0, dim) (half-open); sizes lie in [1, dim] (closed, min
  // 1); strides are all exactly 1 (only unit strides are supported); and
  // offset + size stays within the source dimension for each sliced dim.
  if (failed(
          Result: isIntegerArrayAttrSmallerThanShape(op: *this, arrayAttr: offsets, shape, attrName: offName)) ||
      failed(
          Result: isIntegerArrayAttrSmallerThanShape(op: *this, arrayAttr: sizes, shape, attrName: sizesName)) ||
      failed(Result: isIntegerArrayAttrSmallerThanShape(op: *this, arrayAttr: strides, shape,
                                                attrName: stridesName)) ||
      failed(
          Result: isIntegerArrayAttrConfinedToShape(op: *this, arrayAttr: offsets, shape, attrName: offName)) ||
      failed(Result: isIntegerArrayAttrConfinedToShape(op: *this, arrayAttr: sizes, shape, attrName: sizesName,
                                              /*halfOpen=*/false,
                                              /*min=*/1)) ||
      failed(Result: isIntegerArrayAttrConfinedToRange(op: *this, arrayAttr: strides, /*min=*/1,
                                              /*max=*/1, attrName: stridesName,
                                              /*halfOpen=*/false)) ||
      failed(Result: isSumOfIntegerArrayAttrConfinedToShape(op: *this, arrayAttr1: offsets, arrayAttr2: sizes,
                                                   shape, attrName1: offName, attrName2: sizesName,
                                                   /*halfOpen=*/false)))
    return failure();

  // The declared result type must match the type inferred from the source
  // type and the slice parameters.
  auto resultType = inferStridedSliceOpResultType(vectorType: getSourceVectorType(),
                                                  offsets, sizes, strides);
  if (getResult().getType() != resultType)
    return emitOpError(message: "expected result type to be ") << resultType;

  // Scalable dimensions cannot be partially sliced: the requested size must
  // equal the full base size of the corresponding scalable dimension.
  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
    if (type.getScalableDims()[idx]) {
      auto inputDim = type.getShape()[idx];
      auto inputSize = llvm::cast<IntegerAttr>(Val: sizes[idx]).getInt();
      if (inputDim != inputSize)
        return emitOpError(message: "expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << inputSize << (" vs ") << inputDim << (")");
    }
  }

  return success();
}
3962
3963// When the source of ExtractStrided comes from a chain of InsertStrided ops try
3964// to use the source of the InsertStrided ops if we can detect that the
3965// extracted vector is a subset of one of the vector inserted.
3966static LogicalResult
3967foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
3968 // Helper to extract integer out of ArrayAttr.
3969 auto getElement = [](ArrayAttr array, int idx) {
3970 return llvm::cast<IntegerAttr>(Val: array[idx]).getInt();
3971 };
3972 ArrayAttr extractOffsets = op.getOffsets();
3973 ArrayAttr extractStrides = op.getStrides();
3974 ArrayAttr extractSizes = op.getSizes();
3975 auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
3976 while (insertOp) {
3977 if (op.getSourceVectorType().getRank() !=
3978 insertOp.getSourceVectorType().getRank())
3979 return failure();
3980 ArrayAttr insertOffsets = insertOp.getOffsets();
3981 ArrayAttr insertStrides = insertOp.getStrides();
3982 // If the rank of extract is greater than the rank of insert, we are likely
3983 // extracting a partial chunk of the vector inserted.
3984 if (extractOffsets.size() > insertOffsets.size())
3985 return failure();
3986 bool patialoverlap = false;
3987 bool disjoint = false;
3988 SmallVector<int64_t, 4> offsetDiffs;
3989 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
3990 if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
3991 return failure();
3992 int64_t start = getElement(insertOffsets, dim);
3993 int64_t end = start + insertOp.getSourceVectorType().getDimSize(idx: dim);
3994 int64_t offset = getElement(extractOffsets, dim);
3995 int64_t size = getElement(extractSizes, dim);
3996 // Check if the start of the extract offset is in the interval inserted.
3997 if (start <= offset && offset < end) {
3998 // If the extract interval overlaps but is not fully included we may
3999 // have a partial overlap that will prevent any folding.
4000 if (offset + size > end)
4001 patialoverlap = true;
4002 offsetDiffs.push_back(Elt: offset - start);
4003 continue;
4004 }
4005 disjoint = true;
4006 break;
4007 }
4008 // The extract element chunk is a subset of the insert element.
4009 if (!disjoint && !patialoverlap) {
4010 op.setOperand(insertOp.getValueToStore());
4011 // OpBuilder is only used as a helper to build an I64ArrayAttr.
4012 OpBuilder b(op.getContext());
4013 op.setOffsetsAttr(b.getI64ArrayAttr(values: offsetDiffs));
4014 return success();
4015 }
4016 // If the chunk extracted is disjoint from the chunk inserted, keep looking
4017 // in the insert chain.
4018 if (disjoint)
4019 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
4020 else {
4021 // The extracted vector partially overlap the inserted vector, we cannot
4022 // fold.
4023 return failure();
4024 }
4025 }
4026 return failure();
4027}
4028
// ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
// Folds a strided slice of a dense constant by gathering the constant's
// elements at every position covered by the slice.
static OpFoldResult
foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op,
                                        Attribute foldInput) {

  // Only dense constants can be folded element-wise.
  auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(Val&: foldInput);
  if (!dense)
    return {};

  // TODO: Handle non-unit strides when they become available.
  if (op.hasNonUnitStrides())
    return {};

  VectorType sourceVecTy = op.getSourceVectorType();
  ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
  // Row-major strides of the source, used to linearize slice positions.
  SmallVector<int64_t, 4> sourceStrides = computeStrides(sizes: sourceShape);

  VectorType sliceVecTy = op.getType();
  ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
  int64_t rank = sliceVecTy.getRank();

  // Expand offsets and sizes to match the vector rank: dims not covered by
  // the op's attributes default to offset 0 and the full source dim size.
  SmallVector<int64_t, 4> offsets(rank, 0);
  copy(Range: getI64SubArray(arrayAttr: op.getOffsets()), Out: offsets.begin());

  SmallVector<int64_t, 4> sizes(sourceShape);
  copy(Range: getI64SubArray(arrayAttr: op.getSizes()), Out: sizes.begin());

  // Calculate the slice elements by enumerating all slice positions and
  // linearizing them. The enumeration order is lexicographic which yields a
  // sequence of monotonically increasing linearized position indices.
  const auto denseValuesBegin = dense.value_begin<Attribute>();
  SmallVector<Attribute> sliceValues;
  sliceValues.reserve(N: sliceVecTy.getNumElements());
  SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
  do {
    int64_t linearizedPosition = linearize(offsets: currSlicePosition, basis: sourceStrides);
    assert(linearizedPosition < sourceVecTy.getNumElements() &&
           "Invalid index");
    sliceValues.push_back(Elt: *(denseValuesBegin + linearizedPosition));
  } while (succeeded(Result: incSlicePosition(position: currSlicePosition, shape: sliceShape, offsets)));

  assert(static_cast<int64_t>(sliceValues.size()) ==
         sliceVecTy.getNumElements() &&
         "Invalid number of slice elements");
  return DenseElementsAttr::get(type: sliceVecTy, values: sliceValues);
}
4076
4077OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
4078 if (getSourceVectorType() == getResult().getType())
4079 return getVector();
4080 if (succeeded(Result: foldExtractStridedOpFromInsertChain(op: *this)))
4081 return getResult();
4082
4083 // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
4084 if (auto splat =
4085 llvm::dyn_cast_if_present<SplatElementsAttr>(Val: adaptor.getVector()))
4086 DenseElementsAttr::get(type: getType(), values: splat.getSplatValue<Attribute>());
4087
4088 // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
4089 return foldExtractStridedSliceNonSplatConstant(op: *this, foldInput: adaptor.getVector());
4090}
4091
/// Populates `results` with this op's static offsets, converted from the
/// `offsets` I64 array attribute to plain int64_t values.
void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(arrayAttr: getOffsets(), results);
}
4095
4096namespace {
4097
// Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to
// CreateMaskOp.
//
// Example:
//
// %mask = vector.create_mask %ub : vector<16xi1>
// %slice = vector.extract_strided_slice [%offset] [8] [1]
//
// to
//
// %new_ub = arith.subi %ub, %offset
// %mask = vector.create_mask %new_ub : vector<8xi1>
class StridedSliceCreateMaskFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

public:
  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    Location loc = extractStridedSliceOp.getLoc();
    // Return if 'extractStridedSliceOp' operand is not defined by a
    // CreateMaskOp.
    auto createMaskOp =
        extractStridedSliceOp.getVector().getDefiningOp<CreateMaskOp>();
    if (!createMaskOp)
      return failure();
    // Return if 'extractStridedSliceOp' has non-unit strides.
    if (extractStridedSliceOp.hasNonUnitStrides())
      return failure();
    // Gather constant mask dimension sizes.
    SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
    // Gather strided slice offsets and sizes.
    SmallVector<int64_t> sliceOffsets;
    populateFromInt64AttrArray(arrayAttr: extractStridedSliceOp.getOffsets(),
                               results&: sliceOffsets);
    SmallVector<int64_t> sliceSizes;
    populateFromInt64AttrArray(arrayAttr: extractStridedSliceOp.getSizes(), results&: sliceSizes);

    // Compute slice of vector mask region: each sliced dim's bound becomes
    // (original bound - slice offset), computed dynamically with arith ops.
    SmallVector<Value> sliceMaskDimSizes;
    sliceMaskDimSizes.reserve(N: maskDimSizes.size());
    // sliceOffsets.size() <= maskDimSizes.size(), so we use llvm::zip and
    // only iterate on the leading dim sizes. The tail accounts for the
    // remaining dim sizes.
    for (auto [maskDimSize, sliceOffset, sliceSize] :
         llvm::zip(t&: maskDimSizes, u&: sliceOffsets, args&: sliceSizes)) {
      // No need to clamp on min/max values, because create_mask has clamping
      // semantics, i.e. the sliceMaskDimSize is allowed to be negative or
      // greater than the vector dim size.
      IntegerAttr offsetAttr =
          rewriter.getIntegerAttr(type: maskDimSize.getType(), value: sliceOffset);
      Value offset = rewriter.create<arith::ConstantOp>(location: loc, args&: offsetAttr);
      Value sliceMaskDimSize =
          rewriter.create<arith::SubIOp>(location: loc, args&: maskDimSize, args&: offset);
      sliceMaskDimSizes.push_back(Elt: sliceMaskDimSize);
    }
    // Add unchanged dimensions (those beyond the slice's rank).
    llvm::append_range(
        C&: sliceMaskDimSizes,
        R: llvm::drop_begin(RangeOrContainer&: maskDimSizes, N: sliceMaskDimSizes.size()));
    // Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask
    // region.
    rewriter.replaceOpWithNewOp<CreateMaskOp>(
        op: extractStridedSliceOp, args: extractStridedSliceOp.getResult().getType(),
        args&: sliceMaskDimSizes);
    return success();
  }
};
4166
4167// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
4168// ConstantMaskOp.
4169class StridedSliceConstantMaskFolder final
4170 : public OpRewritePattern<ExtractStridedSliceOp> {
4171public:
4172 using OpRewritePattern::OpRewritePattern;
4173
4174 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4175 PatternRewriter &rewriter) const override {
4176 // Return if 'extractStridedSliceOp' operand is not defined by a
4177 // ConstantMaskOp.
4178 auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
4179 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(Val: defOp);
4180 if (!constantMaskOp)
4181 return failure();
4182 // Return if 'extractStridedSliceOp' has non-unit strides.
4183 if (extractStridedSliceOp.hasNonUnitStrides())
4184 return failure();
4185 // Gather constant mask dimension sizes.
4186 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
4187 // Gather strided slice offsets and sizes.
4188 SmallVector<int64_t> sliceOffsets;
4189 populateFromInt64AttrArray(arrayAttr: extractStridedSliceOp.getOffsets(),
4190 results&: sliceOffsets);
4191 SmallVector<int64_t> sliceSizes;
4192 populateFromInt64AttrArray(arrayAttr: extractStridedSliceOp.getSizes(), results&: sliceSizes);
4193
4194 // Compute slice of vector mask region.
4195 SmallVector<int64_t> sliceMaskDimSizes;
4196 sliceMaskDimSizes.reserve(N: maskDimSizes.size());
4197 for (auto [maskDimSize, sliceOffset, sliceSize] :
4198 llvm::zip(t&: maskDimSizes, u&: sliceOffsets, args&: sliceSizes)) {
4199 int64_t sliceMaskDimSize = std::max(
4200 a: static_cast<int64_t>(0),
4201 b: std::min(a: sliceOffset + sliceSize, b: maskDimSize) - sliceOffset);
4202 sliceMaskDimSizes.push_back(Elt: sliceMaskDimSize);
4203 }
4204 // Add unchanged dimensions.
4205 if (sliceMaskDimSizes.size() < maskDimSizes.size())
4206 for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
4207 sliceMaskDimSizes.push_back(Elt: maskDimSizes[i]);
4208 // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
4209 // region is a conjunction of mask dim intervals).
4210 if (llvm::is_contained(Range&: sliceMaskDimSizes, Element: 0))
4211 sliceMaskDimSizes.assign(NumElts: maskDimSizes.size(), Elt: 0);
4212
4213 // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
4214 // region.
4215 rewriter.replaceOpWithNewOp<ConstantMaskOp>(
4216 op: extractStridedSliceOp, args: extractStridedSliceOp.getResult().getType(),
4217 args&: sliceMaskDimSizes);
4218 return success();
4219 }
4220};
4221
// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
// BroadcastOp(ExtractStrideSliceOp): slice the (smaller) broadcast source
// first, then broadcast the result.
class StridedSliceBroadcast final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
    if (!broadcast)
      return failure();
    // The broadcast source may be a scalar, in which case srcRank is 0.
    auto srcVecType =
        llvm::dyn_cast<VectorType>(Val: broadcast.getSource().getType());
    unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
    auto dstVecType = llvm::cast<VectorType>(Val: op.getType());
    unsigned dstRank = dstVecType.getRank();
    // Number of leading dims the broadcast added; source dim i corresponds to
    // destination dim i + rankDiff.
    unsigned rankDiff = dstRank - srcRank;
    // Source dimensions can be broadcasted (1 -> n with n > 1) or sliced
    // (n -> m with n > m). If they are originally both broadcasted *and*
    // sliced, this can be simplified to just broadcasting.
    bool needsSlice = false;
    for (unsigned i = 0; i < srcRank; i++) {
      if (srcVecType.getDimSize(idx: i) != 1 &&
          srcVecType.getDimSize(idx: i) != dstVecType.getDimSize(idx: i + rankDiff)) {
        needsSlice = true;
        break;
      }
    }
    Value source = broadcast.getSource();
    if (needsSlice) {
      // Drop the offsets/sizes for the rankDiff leading broadcast dims; only
      // the source's own dims are sliced.
      SmallVector<int64_t> offsets =
          getI64SubArray(arrayAttr: op.getOffsets(), /*dropFront=*/rankDiff);
      SmallVector<int64_t> sizes =
          getI64SubArray(arrayAttr: op.getSizes(), /*dropFront=*/rankDiff);
      for (unsigned i = 0; i < srcRank; i++) {
        if (srcVecType.getDimSize(idx: i) == 1) {
          // In case this dimension was broadcasted *and* sliced, the offset
          // and size need to be updated now that there is no broadcast before
          // the slice.
          offsets[i] = 0;
          sizes[i] = 1;
        }
      }
      source = rewriter.create<ExtractStridedSliceOp>(
          location: op->getLoc(), args&: source, args&: offsets, args&: sizes,
          args: getI64SubArray(arrayAttr: op.getStrides(), /*dropFront=*/rankDiff));
    }
    // Broadcast the (possibly sliced) source to the original result type.
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, args: op.getType(), args&: source);
    return success();
  }
};
4274
4275/// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
4276class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
4277public:
4278 using OpRewritePattern::OpRewritePattern;
4279
4280 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4281 PatternRewriter &rewriter) const override {
4282 auto splat = op.getVector().getDefiningOp<SplatOp>();
4283 if (!splat)
4284 return failure();
4285 rewriter.replaceOpWithNewOp<SplatOp>(op, args: op.getType(), args: splat.getInput());
4286 return success();
4287 }
4288};
4289
/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
/// slice is contiguous, into extract and shape_cast.
///
/// Example:
///     Before:
///         %1 = vector.extract_strided_slice %arg0 {
///                offsets = [0, 0, 0, 0, 0],
///                sizes = [1, 1, 1, 1, 8],
///                strides = [1, 1, 1, 1, 1]
///              } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
///     After:
///         %0 = vector.extract %arg0[0, 0, 0, 0]
///                : vector<8xi8> from vector<8x1x1x2x8xi8>
///         %1 = vector.shape_cast %0
///                : vector<8xi8> to vector<1x1x1x1x8xi8>
///
class ContiguousExtractStridedSliceToExtract final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    // Only unit strides can form a contiguous slice.
    if (op.hasNonUnitStrides())
      return failure();
    Value source = op.getOperand();
    auto sourceType = cast<VectorType>(Val: source.getType());
    if (sourceType.isScalable() || sourceType.getRank() == 0)
      return failure();

    // Compute the number of offsets to pass to ExtractOp::build. That is the
    // difference between the source rank and the desired slice rank. We walk
    // the dimensions from innermost out, and stop when the next slice dimension
    // is not full-size.
    SmallVector<int64_t> sizes = getI64SubArray(arrayAttr: op.getSizes());
    int numOffsets;
    for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
      if (sizes[numOffsets - 1] != sourceType.getDimSize(idx: numOffsets - 1))
        break;
    }

    // If the created extract op would have no offsets, then this whole
    // extract_strided_slice is the identity and should have been handled by
    // other canonicalizations.
    if (numOffsets == 0)
      return failure();

    // If not even the inner-most dimension is full-size, this op can't be
    // rewritten as an ExtractOp.
    if (numOffsets == sourceType.getRank() &&
        static_cast<int>(sizes.size()) == sourceType.getRank())
      return failure();

    // The outer dimensions must have unit size.
    for (int i = 0; i < numOffsets; ++i) {
      if (sizes[i] != 1)
        return failure();
    }

    // Avoid generating slices that have leading unit dimensions. The shape_cast
    // op that we create below would take bad generic fallback patterns
    // (ShapeCastOpRewritePattern).
    while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
           sizes[numOffsets] == 1) {
      ++numOffsets;
    }

    // Extract the inner (contiguous) chunk, then shape_cast it back to the
    // original slice type (re-adding the leading unit dims).
    SmallVector<int64_t> offsets = getI64SubArray(arrayAttr: op.getOffsets());
    auto extractOffsets = ArrayRef(offsets).take_front(N: numOffsets);
    Value extract = rewriter.create<vector::ExtractOp>(location: op->getLoc(), args&: source,
                                                       args&: extractOffsets);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, args: op.getType(), args&: extract);
    return success();
  }
};
4365
4366} // namespace
4367
void ExtractStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // Patterns that fold extract_strided_slice into the op producing its
  // source: create_mask, constant_mask, broadcast, and splat. Also rewrites
  // contiguous slices into extract + shape_cast.
  results.add<StridedSliceCreateMaskFolder, StridedSliceConstantMaskFolder,
              StridedSliceBroadcast, StridedSliceSplat,
              ContiguousExtractStridedSliceToExtract>(arg&: context);
}
4376
4377//===----------------------------------------------------------------------===//
4378// TransferReadOp
4379//===----------------------------------------------------------------------===//
4380
4381/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
4382void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4383 VectorType vectorType, Value source,
4384 ValueRange indices, std::optional<Value> padding,
4385 AffineMapAttr permutationMapAttr,
4386 /*optional*/ ArrayAttr inBoundsAttr) {
4387
4388 Type elemType = llvm::cast<ShapedType>(Val: source.getType()).getElementType();
4389 if (!padding)
4390 padding = builder.create<ub::PoisonOp>(location: result.location, args&: elemType);
4391 build(odsBuilder&: builder, odsState&: result, vector: vectorType, base: source, indices, permutation_map: permutationMapAttr,
4392 padding: *padding, /*mask=*/Value(), in_bounds: inBoundsAttr);
4393}
4394
4395/// 2. Builder that sets padding to zero an empty mask (variant without attrs).
4396void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4397 VectorType vectorType, Value source,
4398 ValueRange indices, std::optional<Value> padding,
4399 AffineMap permutationMap,
4400 std::optional<ArrayRef<bool>> inBounds) {
4401 auto permutationMapAttr = AffineMapAttr::get(value: permutationMap);
4402 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4403 ? builder.getBoolArrayAttr(values: inBounds.value())
4404 : builder.getBoolArrayAttr(
4405 values: SmallVector<bool>(vectorType.getRank(), false));
4406 Type elemType = llvm::cast<ShapedType>(Val: source.getType()).getElementType();
4407 if (!padding)
4408 padding = builder.create<ub::PoisonOp>(location: result.location, args&: elemType);
4409 build(builder, result, vectorType, source, indices, padding: *padding,
4410 permutationMapAttr, inBoundsAttr);
4411}
4412
4413/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
4414void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4415 VectorType vectorType, Value source,
4416 ValueRange indices, std::optional<Value> padding,
4417 std::optional<ArrayRef<bool>> inBounds) {
4418 AffineMap permutationMap = getTransferMinorIdentityMap(
4419 shapedType: llvm::cast<ShapedType>(Val: source.getType()), vectorType);
4420 auto permutationMapAttr = AffineMapAttr::get(value: permutationMap);
4421 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4422 ? builder.getBoolArrayAttr(values: inBounds.value())
4423 : builder.getBoolArrayAttr(
4424 values: SmallVector<bool>(vectorType.getRank(), false));
4425 Type elemType = llvm::cast<ShapedType>(Val: source.getType()).getElementType();
4426 if (!padding)
4427 padding = builder.create<ub::PoisonOp>(location: result.location, args&: elemType);
4428 build(odsBuilder&: builder, odsState&: result, vector: vectorType, base: source, indices, permutation_map: permutationMapAttr,
4429 padding: *padding,
4430 /*mask=*/Value(), in_bounds: inBoundsAttr);
4431}
4432
4433template <typename EmitFun>
4434static LogicalResult verifyPermutationMap(AffineMap permutationMap,
4435 EmitFun emitOpError) {
4436 SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
4437 for (auto expr : permutationMap.getResults()) {
4438 auto dim = dyn_cast<AffineDimExpr>(Val&: expr);
4439 auto zero = dyn_cast<AffineConstantExpr>(Val&: expr);
4440 if (zero) {
4441 if (zero.getValue() != 0) {
4442 return emitOpError(
4443 "requires a projected permutation_map (at most one dim or the zero "
4444 "constant can appear in each result)");
4445 }
4446 continue;
4447 }
4448 if (!dim) {
4449 return emitOpError("requires a projected permutation_map (at most one "
4450 "dim or the zero constant can appear in each result)");
4451 }
4452 if (seen[dim.getPosition()]) {
4453 return emitOpError(
4454 "requires a permutation_map that is a permutation (found one dim "
4455 "used more than once)");
4456 }
4457 seen[dim.getPosition()] = true;
4458 }
4459 return success();
4460}
4461
/// Common verifier for vector transfer ops (transfer_read / transfer_write):
/// checks the source type, element-type compatibility, permutation map shape,
/// mask type, and in_bounds attribute consistency.
static LogicalResult
verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
                 VectorType vectorType, VectorType maskType,
                 VectorType inferredMaskType, AffineMap permutationMap,
                 ArrayAttr inBounds) {
  // Reject the legacy "masked" attribute with a migration hint.
  if (op->hasAttr(name: "masked")) {
    return op->emitOpError(message: "masked attribute has been removed. "
                           "Use in_bounds instead.");
  }

  if (!llvm::isa<MemRefType, RankedTensorType>(Val: shapedType))
    return op->emitOpError(
        message: "requires source to be a memref or ranked tensor type");

  auto elementType = shapedType.getElementType();
  DataLayout dataLayout = DataLayout::closest(op);
  if (auto vectorElementType = llvm::dyn_cast<VectorType>(Val&: elementType)) {
    // Memref or tensor has vector element type.
    // Compare bit widths of the minor 1-D vectors of the source element type
    // and the result vector type.
    unsigned sourceVecSize =
        dataLayout.getTypeSizeInBits(t: vectorElementType.getElementType()) *
        vectorElementType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(t: vectorType.getElementType()) *
        vectorType.getShape().back();
    if (resultVecSize % sourceVecSize != 0)
      return op->emitOpError(
          message: "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the minor 1-D vector of the source");

    unsigned sourceVecEltRank = vectorElementType.getRank();
    unsigned resultVecRank = vectorType.getRank();
    if (sourceVecEltRank > resultVecRank)
      return op->emitOpError(
          message: "requires source vector element and vector result ranks to match.");
    // The trailing sourceVecEltRank dims of the result come from the element
    // vector; the permutation map only covers the leading rankOffset dims.
    unsigned rankOffset = resultVecRank - sourceVecEltRank;
    // Check that permutation map results match 'rankOffset' of vector type.
    if (permutationMap.getNumResults() != rankOffset)
      return op->emitOpError(message: "requires a permutation_map with result dims of "
                             "the same rank as the vector type");

    if (maskType)
      return op->emitOpError(message: "does not support masks with vector element type");
  } else {
    // Memref or tensor has scalar element type.
    // A 0-D result vector is treated as having a single-element minor dim.
    unsigned minorSize =
        vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(t: vectorType.getElementType()) * minorSize;
    if (resultVecSize % dataLayout.getTypeSizeInBits(t: elementType) != 0)
      return op->emitOpError(
          message: "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the source element type");

    // Check that permutation map results match rank of vector type.
    if (permutationMap.getNumResults() != vectorType.getRank())
      return op->emitOpError(message: "requires a permutation_map with result dims of "
                             "the same rank as the vector type");
  }

  if (permutationMap.getNumSymbols() != 0)
    return op->emitOpError(message: "requires permutation_map without symbols");

  if (permutationMap.getNumInputs() != shapedType.getRank())
    return op->emitOpError(message: "requires a permutation_map with input dims of the "
                           "same rank as the source type");

  // If a mask is provided, it must have exactly the inferred mask type.
  if (maskType && maskType != inferredMaskType)
    return op->emitOpError(message: "inferred mask type (")
           << inferredMaskType << ") and mask operand type (" << maskType
           << ") don't match";

  // One in_bounds flag per permutation map result.
  if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
    return op->emitOpError(message: "expects the in_bounds attr of same rank "
                           "as permutation_map results: ")
           << AffineMapAttr::get(value: permutationMap)
           << " vs inBounds of size: " << inBounds.size();

  return success();
}
4541
4542static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
4543 SmallVector<StringRef, 3> elidedAttrs;
4544 elidedAttrs.push_back(Elt: TransferReadOp::getOperandSegmentSizeAttr());
4545 if (op.getPermutationMap().isMinorIdentity())
4546 elidedAttrs.push_back(Elt: op.getPermutationMapAttrName());
4547 // Elide in_bounds attribute if all dims are out-of-bounds.
4548 if (llvm::none_of(Range: op.getInBoundsValues(), P: [](bool b) { return b; }))
4549 elidedAttrs.push_back(Elt: op.getInBoundsAttrName());
4550 p.printOptionalAttrDict(attrs: op->getAttrs(), elidedAttrs);
4551}
4552
/// Prints the custom form: ` base[indices], padding (, mask)? attr-dict :
/// shaped-type, vector-type`.
void TransferReadOp::print(OpAsmPrinter &p) {
  p << " " << getBase() << "[" << getIndices() << "], " << getPadding();
  // The mask operand is optional and printed only when present.
  if (getMask())
    p << ", " << getMask();
  printTransferAttrs(p, op: *this);
  p << " : " << getShapedType() << ", " << getVectorType();
}
4560
/// Infers the mask type for a transfer op from its vector type and
/// permutation map: an i1 vector whose shape is the vector shape mapped back
/// through the (inverted, compressed) permutation map.
VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
                                                 AffineMap permMap) {
  auto i1Type = IntegerType::get(context: permMap.getContext(), width: 1);
  // Unused dims are compressed away before inversion so the map is
  // invertible; the inverse maps vector dims back to (used) source dims.
  AffineMap invPermMap = inversePermutation(map: compressUnusedDims(map: permMap));
  assert(invPermMap && "Inversed permutation map couldn't be computed");
  SmallVector<int64_t, 8> maskShape = invPermMap.compose(values: vecType.getShape());

  // The MaskOp specification doesn't support 0-D vectors at the moment. Turn a
  // 0-D mask into a single-element 1-D mask.
  if (maskShape.empty())
    maskShape.push_back(Elt: 1);

  // Scalability flags are permuted the same way as the shape.
  SmallVector<bool> scalableDims =
      applyPermutationMap(map: invPermMap, source: vecType.getScalableDims());

  return VectorType::get(shape: maskShape, elementType: i1Type, scalableDims);
}
4578
/// Parses the custom assembly form of transfer_read:
///   vector.transfer_read %source[%indices], %padding (, %mask)?
///     attr-dict : shaped-type, vector-type
ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  SMLoc typesLoc;
  OpAsmParser::UnresolvedOperand sourceInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
  OpAsmParser::UnresolvedOperand paddingInfo;
  SmallVector<Type, 2> types;
  OpAsmParser::UnresolvedOperand maskInfo;
  // Parsing with support for paddingValue.
  if (parser.parseOperand(result&: sourceInfo) ||
      parser.parseOperandList(result&: indexInfo, delimiter: OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(result&: paddingInfo))
    return failure();
  // A trailing comma introduces the optional mask operand.
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded()) {
    if (parser.parseOperand(result&: maskInfo))
      return failure();
  }
  if (parser.parseOptionalAttrDict(result&: result.attributes) ||
      parser.getCurrentLocation(loc: &typesLoc) || parser.parseColonTypeList(result&: types))
    return failure();
  // Expect exactly `shaped-type, vector-type` after the colon.
  if (types.size() != 2)
    return parser.emitError(loc: typesLoc, message: "requires two types");
  auto indexType = builder.getIndexType();
  auto shapedType = llvm::dyn_cast<ShapedType>(Val&: types[0]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(Val: shapedType))
    return parser.emitError(loc: typesLoc, message: "requires memref or ranked tensor type");
  VectorType vectorType = llvm::dyn_cast<VectorType>(Val&: types[1]);
  if (!vectorType)
    return parser.emitError(loc: typesLoc, message: "requires vector type");
  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(name: result.name);
  Attribute permMapAttr = result.attributes.get(name: permMapAttrName);
  AffineMap permMap;
  // When the permutation map is elided, reconstruct the default minor identity
  // map; this is only valid if the source rank can accommodate the vector.
  if (!permMapAttr) {
    if (shapedType.getRank() <
        getEffectiveVectorRankForXferOp(shapedType, vectorType))
      return parser.emitError(loc: typesLoc,
                              message: "expected a custom permutation_map when "
                              "rank(source) != rank(destination)");
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(name: permMapAttrName, value: AffineMapAttr::get(value: permMap));
  } else {
    permMap = llvm::cast<AffineMapAttr>(Val&: permMapAttr).getValue();
  }
  auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(name: result.name);
  Attribute inBoundsAttr = result.attributes.get(name: inBoundsAttrName);
  // An elided in_bounds attribute defaults to all dimensions out-of-bounds.
  if (!inBoundsAttr) {
    result.addAttribute(name: inBoundsAttrName,
                        attr: builder.getBoolArrayAttr(
                            values: SmallVector<bool>(permMap.getNumResults(), false)));
  }
  if (parser.resolveOperand(operand: sourceInfo, type: shapedType, result&: result.operands) ||
      parser.resolveOperands(operands&: indexInfo, type: indexType, result&: result.operands) ||
      parser.resolveOperand(operand: paddingInfo, type: shapedType.getElementType(),
                            result&: result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(Val: shapedType.getElementType()))
      return parser.emitError(
          loc: maskInfo.location, message: "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults()) {
      return parser.emitError(loc: typesLoc,
                              message: "expected the same rank for the vector and the "
                              "results of the permutation map");
    }
    // Instead of adding the mask type as an op type, compute it based on the
    // vector type and the permutation map (to keep the type signature small).
    auto maskType = inferTransferOpMaskType(vecType: vectorType, permMap);
    if (parser.resolveOperand(operand: maskInfo, type: maskType, result&: result.operands))
      return failure();
  }
  // Record the operand segment sizes: source, indices, padding, optional mask.
  result.addAttribute(name: TransferReadOp::getOperandSegmentSizeAttr(),
                      attr: builder.getDenseI32ArrayAttr(
                          values: {1, static_cast<int32_t>(indexInfo.size()), 1,
                           static_cast<int32_t>(hasMask.succeeded())}));
  return parser.addTypeToList(type: vectorType, result&: result.types);
}
4656
LogicalResult TransferReadOp::verify() {
  // Consistency of elemental types in source and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto paddingType = getPadding().getType();
  auto permutationMap = getPermutationMap();
  // Only compute the expected mask type when a mask operand is present.
  VectorType inferredMaskType =
      maskType ? inferTransferOpMaskType(vecType: vectorType, permMap: permutationMap)
               : VectorType();
  auto sourceElementType = shapedType.getElementType();

  // One index operand is required per dimension of the source.
  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
    return emitOpError(message: "requires ") << shapedType.getRank() << " indices";

  // Shared transfer-op checks: element types, mask type, permutation map
  // shape, and in_bounds size.
  if (failed(Result: verifyTransferOp(op: cast<VectorTransferOpInterface>(Val: getOperation()),
                              shapedType, vectorType, maskType,
                              inferredMaskType, permutationMap, inBounds: getInBounds())))
    return failure();

  if (auto sourceVectorElementType =
          llvm::dyn_cast<VectorType>(Val&: sourceElementType)) {
    // Source has vector element type.
    // Check that 'sourceVectorElementType' and 'paddingType' types match.
    if (sourceVectorElementType != paddingType)
      return emitOpError(
          message: "requires source element type and padding type to match.");

  } else {
    // Check that 'paddingType' is valid to store in a vector type.
    if (!VectorType::isValidElementType(t: paddingType))
      return emitOpError(message: "requires valid padding vector elemental type");

    // Check that padding type and vector element types match.
    if (paddingType != sourceElementType)
      return emitOpError(
          message: "requires formal padding and source of the same elemental type");
  }

  // Finally, the permutation map itself must be a valid transfer permutation.
  return verifyPermutationMap(permutationMap,
                              emitOpError: [&](Twine t) { return emitOpError(message: t); });
}
4699
4700// MaskableOpInterface methods.
4701
4702/// Returns the mask type expected by this operation. Mostly used for
4703/// verification purposes. It requires the operation to be vectorized."
4704Type TransferReadOp::getExpectedMaskType() {
4705 return inferTransferOpMaskType(vecType: getVectorType(), permMap: getPermutationMap());
4706}
4707
4708//===----------------------------------------------------------------------===//
4709// TransferReadOp: VectorTransferOpInterface methods.
4710//===----------------------------------------------------------------------===//
4711VectorType TransferReadOp::getVectorType() {
4712 return cast<VectorType>(Val: getVector().getType());
4713}
4714
4715template <typename TransferOp>
4716static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
4717 // TODO: support more aggressive createOrFold on:
4718 // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
4719 if (op.getShapedType().isDynamicDim(indicesIdx))
4720 return false;
4721 Value index = op.getIndices()[indicesIdx];
4722 std::optional<int64_t> cstOp = getConstantIntValue(ofr: index);
4723 if (!cstOp.has_value())
4724 return false;
4725
4726 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
4727 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
4728
4729 return cstOp.value() + vectorSize <= sourceSize;
4730}
4731
/// Tightens the in_bounds attribute of `op`: flips dimensions from
/// out-of-bounds to in-bounds whenever this can be statically proven.
/// Returns success iff the attribute was changed.
template <typename TransferOp>
static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
  // TODO: support 0-d corner case.
  // TODO: Be less conservative.
  if (op.getTransferRank() == 0)
    return failure();
  AffineMap permutationMap = op.getPermutationMap();
  bool changed = false;
  SmallVector<bool, 4> newInBounds;
  newInBounds.reserve(N: op.getTransferRank());
  // Idxs of non-bcast dims - used when analysing bcast dims.
  SmallVector<unsigned> nonBcastDims;

  // 1. Process non-broadcast dims
  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
    // 1.1. Already marked as in-bounds, nothing to see here.
    if (op.isDimInBounds(i)) {
      newInBounds.push_back(Elt: true);
      continue;
    }
    // 1.2. Currently out-of-bounds, check whether we can statically determine
    // it is inBounds.
    bool inBounds = false;
    auto dimExpr = dyn_cast<AffineDimExpr>(Val: permutationMap.getResult(idx: i));
    if (dimExpr) {
      // Only plain dim results can be bounds-checked; broadcast (constant-0)
      // results are handled in step 2 below.
      inBounds = isInBounds(op, /*resultIdx=*/i,
                            /*indicesIdx=*/dimExpr.getPosition());
      nonBcastDims.push_back(Elt: i);
    }

    newInBounds.push_back(Elt: inBounds);
    // We commit the pattern if it is "more inbounds".
    changed |= inBounds;
  }

  // 2. Handle broadcast dims
  // If all non-broadcast dims are "in bounds", then all bcast dims should be
  // "in bounds" as well.
  bool allNonBcastDimsInBounds = llvm::all_of(
      nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
  if (allNonBcastDimsInBounds) {
    for (size_t idx : permutationMap.getBroadcastDims()) {
      changed |= !newInBounds[idx];
      newInBounds[idx] = true;
    }
  }

  if (!changed)
    return failure();
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(op.getContext());
  op.setInBoundsAttr(b.getBoolArrayAttr(values: newInBounds));
  return success();
}
4786
4787template <typename TransferOp>
4788static LogicalResult foldTransferFullMask(TransferOp op) {
4789 auto mask = op.getMask();
4790 if (!mask)
4791 return failure();
4792
4793 if (getMaskFormat(mask) != MaskFormat::AllTrue)
4794 return failure();
4795
4796 op.getMaskMutable().clear();
4797 return success();
4798}
4799
4800/// ```
4801/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4802/// : vector<1x4xf32>, tensor<4x4xf32>
4803/// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
4804/// : tensor<4x4xf32>, vector<1x4xf32>
4805/// ```
4806/// -> Folds into
4807/// ```
4808/// %v0
4809/// ```
4810static Value foldRAW(TransferReadOp readOp) {
4811 if (!llvm::isa<RankedTensorType>(Val: readOp.getShapedType()))
4812 return {};
4813 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
4814 while (defWrite) {
4815 if (checkSameValueRAW(defWrite, read: readOp))
4816 return defWrite.getVector();
4817 if (!isDisjointTransferIndices(
4818 transferA: cast<VectorTransferOpInterface>(Val: defWrite.getOperation()),
4819 transferB: cast<VectorTransferOpInterface>(Val: readOp.getOperation())))
4820 break;
4821 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
4822 }
4823 return {};
4824}
4825
OpFoldResult TransferReadOp::fold(FoldAdaptor) {
  // Forward a previously written vector (read-after-write).
  if (Value vec = foldRAW(readOp: *this))
    return vec;
  // Tighten the in_bounds attribute when statically provable.
  if (succeeded(Result: foldTransferInBoundsAttribute(op: *this)))
    return getResult();
  // Drop an all-true mask.
  if (succeeded(Result: foldTransferFullMask(op: *this)))
    return getResult();
  /// transfer_read(memrefcast) -> transfer_read
  if (succeeded(Result: memref::foldMemRefCast(op: *this)))
    return getResult();
  // transfer_read(tensorcast) -> transfer_read
  if (succeeded(Result: tensor::foldTensorCast(op: *this)))
    return getResult();
  return OpFoldResult();
}
4840
4841std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
4842 return llvm::to_vector<4>(Range: getVectorType().getShape());
4843}
4844
4845void TransferReadOp::getEffects(
4846 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4847 &effects) {
4848 if (llvm::isa<MemRefType>(Val: getShapedType()))
4849 effects.emplace_back(Args: MemoryEffects::Read::get(), Args: &getBaseMutable(),
4850 Args: SideEffects::DefaultResource::get());
4851}
4852
4853Speculation::Speculatability TransferReadOp::getSpeculatability() {
4854 if (hasPureTensorSemantics())
4855 return Speculation::Speculatable;
4856 return Speculation::NotSpeculatable;
4857}
4858
namespace {
/// Store to load forwarding for transfer operations with permutation maps.
/// Even if the permutation maps are different we can still propagate the store
/// into the load if the size of the dimensions read and written match. Then we
/// can replace the transfer_read + transfer_write by vector.broadcast and
/// vector.transpose.
/// Example:
/// ```
/// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
///  {in_bounds = [true, true],
///   permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
///   vector<4x1xf32>, tensor<4x4x4xf32>
///  %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
///   {in_bounds = [true, true, true, true],
///   permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
///   tensor<4x4x4xf32>, vector<1x100x4x5xf32>
/// ```
/// To:
/// ```
/// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
/// %r = vector.transpose %0, [3, 0, 2, 1] :
///   vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
/// ```
struct TransferReadAfterWriteToBroadcast
    : public OpRewritePattern<TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
    if (!defWrite)
      return failure();
    // Bail if we need an alias analysis.
    if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
      return failure();
    // Bail if we need a bounds analysis.
    if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
      return failure();
    // TODO: If the written transfer chunk is a superset of the read transfer
    // chunk we could do an extract_strided_slice.
    if (readOp.getTransferChunkAccessed() !=
        defWrite.getTransferChunkAccessed())
      return failure();
    // TODO: Support cases where a dim is explicitly written but implicitly
    // read (i.e., a unit dim that is rank reduced).
    if (getUnusedDimsBitVector(maps: {readOp.getPermutationMap()}) !=
        getUnusedDimsBitVector(maps: {defWrite.getPermutationMap()}))
      return failure();
    // This pattern should only catch the broadcast case, the non-broadcast case
    // should be done separately to keep application conditions clean and
    // separate.
    AffineMap readMap = compressUnusedDims(map: readOp.getPermutationMap());
    AffineMap writeMap = compressUnusedDims(map: defWrite.getPermutationMap());
    bool bcast = !readMap.getBroadcastDims().empty() ||
                 !writeMap.getBroadcastDims().empty();
    if (!bcast)
      return failure();
    // At this point, we know we have a bcast.
    // Bail in the masked case (too complex atm and needed to properly account
    // for padding).
    if (readOp.getMask() || defWrite.getMask())
      return failure();
    // If indices are not the same a shift may be required, bail.
    if (readOp.getIndices() != defWrite.getIndices())
      return failure();

    Value vec = defWrite.getVector();
    // TODO: loop through the chain of transfer_write if we can prove that they
    // don't overlap with the transfer_read. This requires improving
    // `isDisjointTransferIndices` helper.
    AffineMap map = readMap.compose(map: writeMap);
    if (map.getNumResults() == 0)
      return failure();
    // Calculate the permutation to apply to go from the vector stored to the
    // vector read.
    SmallVector<unsigned> permutation;
    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutedDims&: permutation))
      return failure();

    Location loc = readOp.getLoc();
    // Calculate the broadcast shape by applying the reverse permutation to the
    // final shape we want.
    ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
    SmallVector<int64_t> broadcastShape(destShape.size());
    SmallVector<bool> broadcastScalableFlags(destShape.size());
    for (const auto &pos : llvm::enumerate(First&: permutation)) {
      broadcastShape[pos.value()] = destShape[pos.index()];
      broadcastScalableFlags[pos.value()] =
          readOp.getVectorType().getScalableDims()[pos.index()];
    }
    VectorType broadcastedType = VectorType::get(
        shape: broadcastShape, elementType: defWrite.getVectorType().getElementType(),
        scalableDims: broadcastScalableFlags);
    // Broadcast the stored vector to the permuted shape, then transpose it
    // into the exact shape the read expects.
    vec = rewriter.create<vector::BroadcastOp>(location: loc, args&: broadcastedType, args&: vec);
    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(op: readOp, args&: vec,
                                                     args&: transposePerm);
    return success();
  }
};
} // namespace
4960
void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  // Sole registered canonicalization: forward a matching transfer_write into
  // this read via broadcast + transpose.
  results.add<TransferReadAfterWriteToBroadcast>(arg&: context);
}
4965
4966//===----------------------------------------------------------------------===//
4967// TransferWriteOp
4968//===----------------------------------------------------------------------===//
4969
4970/// 1. Builder with type inference.
4971void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4972 Value vector, Value dest, ValueRange indices,
4973 AffineMapAttr permutationMapAttr,
4974 /*optional*/ Value mask,
4975 /*optional*/ ArrayAttr inBoundsAttr) {
4976 Type resultType = llvm::dyn_cast<RankedTensorType>(Val: dest.getType());
4977 build(odsBuilder&: builder, odsState&: result, result: resultType, valueToStore: vector, base: dest, indices, permutation_map: permutationMapAttr,
4978 mask, in_bounds: inBoundsAttr);
4979}
4980
4981/// 2. Builder with type inference that sets an empty mask (variant with attrs).
4982void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4983 Value vector, Value dest, ValueRange indices,
4984 AffineMapAttr permutationMapAttr,
4985 /*optional*/ ArrayAttr inBoundsAttr) {
4986 build(builder, result, vector, dest, indices, permutationMapAttr,
4987 /*mask=*/Value(), inBoundsAttr);
4988}
4989
4990/// 3. Builder with type inference that sets an empty mask (variant without
4991/// attrs)
4992void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4993 Value vector, Value dest, ValueRange indices,
4994 AffineMap permutationMap,
4995 std::optional<ArrayRef<bool>> inBounds) {
4996 auto permutationMapAttr = AffineMapAttr::get(value: permutationMap);
4997 auto inBoundsAttr =
4998 (inBounds && !inBounds.value().empty())
4999 ? builder.getBoolArrayAttr(values: inBounds.value())
5000 : builder.getBoolArrayAttr(values: SmallVector<bool>(
5001 llvm::cast<VectorType>(Val: vector.getType()).getRank(), false));
5002 build(builder, result, vector, dest, indices, permutationMapAttr,
5003 /*mask=*/Value(), inBoundsAttr);
5004}
5005
5006/// 4. Builder with type inference that sets an empty mask and sets permutation
5007/// map to 'getMinorIdentityMap'.
5008void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5009 Value vector, Value dest, ValueRange indices,
5010 std::optional<ArrayRef<bool>> inBounds) {
5011 auto vectorType = llvm::cast<VectorType>(Val: vector.getType());
5012 AffineMap permutationMap = getTransferMinorIdentityMap(
5013 shapedType: llvm::cast<ShapedType>(Val: dest.getType()), vectorType);
5014 build(builder, result, vector, dest, indices, permutationMap, inBounds);
5015}
5016
/// Parses the custom assembly form of transfer_write:
///   vector.transfer_write %vector, %source[%indices] (, %mask)?
///     attr-dict : vector-type, shaped-type
ParseResult TransferWriteOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  auto &builder = parser.getBuilder();
  SMLoc typesLoc;
  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
  SmallVector<Type, 2> types;
  OpAsmParser::UnresolvedOperand maskInfo;
  if (parser.parseOperand(result&: vectorInfo) || parser.parseComma() ||
      parser.parseOperand(result&: sourceInfo) ||
      parser.parseOperandList(result&: indexInfo, delimiter: OpAsmParser::Delimiter::Square))
    return failure();
  // A trailing comma introduces the optional mask operand.
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded() && parser.parseOperand(result&: maskInfo))
    return failure();
  if (parser.parseOptionalAttrDict(result&: result.attributes) ||
      parser.getCurrentLocation(loc: &typesLoc) || parser.parseColonTypeList(result&: types))
    return failure();
  // Expect exactly `vector-type, shaped-type` after the colon.
  if (types.size() != 2)
    return parser.emitError(loc: typesLoc, message: "requires two types");
  auto indexType = builder.getIndexType();
  VectorType vectorType = llvm::dyn_cast<VectorType>(Val&: types[0]);
  if (!vectorType)
    return parser.emitError(loc: typesLoc, message: "requires vector type");
  ShapedType shapedType = llvm::dyn_cast<ShapedType>(Val&: types[1]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(Val: shapedType))
    return parser.emitError(loc: typesLoc, message: "requires memref or ranked tensor type");
  auto permMapAttrName =
      TransferWriteOp::getPermutationMapAttrName(name: result.name);
  auto permMapAttr = result.attributes.get(name: permMapAttrName);
  AffineMap permMap;
  // When the permutation map is elided, reconstruct the default minor identity
  // map; this is only valid if the destination rank can accommodate the vector.
  if (!permMapAttr) {
    if (shapedType.getRank() <
        getEffectiveVectorRankForXferOp(shapedType, vectorType))
      return parser.emitError(loc: typesLoc,
                              message: "expected a custom permutation_map when "
                              "rank(source) != rank(destination)");
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(name: permMapAttrName, value: AffineMapAttr::get(value: permMap));
  } else {
    permMap = llvm::cast<AffineMapAttr>(Val&: permMapAttr).getValue();
  }
  auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(name: result.name);
  Attribute inBoundsAttr = result.attributes.get(name: inBoundsAttrName);
  // An elided in_bounds attribute defaults to all dimensions out-of-bounds.
  if (!inBoundsAttr) {
    result.addAttribute(name: inBoundsAttrName,
                        attr: builder.getBoolArrayAttr(
                            values: SmallVector<bool>(permMap.getNumResults(), false)));
  }
  if (parser.resolveOperand(operand: vectorInfo, type: vectorType, result&: result.operands) ||
      parser.resolveOperand(operand: sourceInfo, type: shapedType, result&: result.operands) ||
      parser.resolveOperands(operands&: indexInfo, type: indexType, result&: result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(Val: shapedType.getElementType()))
      return parser.emitError(
          loc: maskInfo.location, message: "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults()) {
      return parser.emitError(loc: typesLoc,
                              message: "expected the same rank for the vector and the "
                              "results of the permutation map");
    }
    // The mask type is computed from the vector type and the permutation map
    // rather than spelled out in the assembly.
    auto maskType = inferTransferOpMaskType(vecType: vectorType, permMap);
    if (parser.resolveOperand(operand: maskInfo, type: maskType, result&: result.operands))
      return failure();
  }
  // Record operand segment sizes: vector, source, indices, optional mask.
  result.addAttribute(name: TransferWriteOp::getOperandSegmentSizeAttr(),
                      attr: builder.getDenseI32ArrayAttr(
                          values: {1, 1, static_cast<int32_t>(indexInfo.size()),
                           static_cast<int32_t>(hasMask.succeeded())}));
  // Only a tensor destination produces a result type.
  return failure(IsFailure: llvm::isa<RankedTensorType>(Val: shapedType) &&
                 parser.addTypeToList(type: shapedType, result&: result.types));
}
5090
5091void TransferWriteOp::print(OpAsmPrinter &p) {
5092 p << " " << getVector() << ", " << getBase() << "[" << getIndices() << "]";
5093 if (getMask())
5094 p << ", " << getMask();
5095 printTransferAttrs(p, op: *this);
5096 p << " : " << getVectorType() << ", " << getShapedType();
5097}
5098
LogicalResult TransferWriteOp::verify() {
  // Consistency of elemental types in shape and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto permutationMap = getPermutationMap();
  // Only compute the expected mask type when a mask operand is present.
  VectorType inferredMaskType =
      maskType ? inferTransferOpMaskType(vecType: vectorType, permMap: permutationMap)
               : VectorType();

  // One index operand is required per dimension of the destination.
  if (llvm::size(Range: getIndices()) != shapedType.getRank())
    return emitOpError(message: "requires ") << shapedType.getRank() << " indices";

  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
  // as the semantics is unclear. This can be revisited later if necessary.
  if (hasBroadcastDim())
    return emitOpError(message: "should not have broadcast dimensions");

  // Shared transfer-op checks: element types, mask type, permutation map
  // shape, and in_bounds size.
  if (failed(Result: verifyTransferOp(op: cast<VectorTransferOpInterface>(Val: getOperation()),
                              shapedType, vectorType, maskType,
                              inferredMaskType, permutationMap, inBounds: getInBounds())))
    return failure();

  // Finally, the permutation map itself must be a valid transfer permutation.
  return verifyPermutationMap(permutationMap,
                              emitOpError: [&](Twine t) { return emitOpError(message: t); });
}
5125
5126//===----------------------------------------------------------------------===//
5127// TransferWriteOp: MaskableOpInterface methods.
5128//===----------------------------------------------------------------------===//
5129
5130/// Returns the mask type expected by this operation. Mostly used for
5131/// verification purposes.
5132Type TransferWriteOp::getExpectedMaskType() {
5133 return inferTransferOpMaskType(vecType: getVectorType(), permMap: getPermutationMap());
5134}
5135
5136//===----------------------------------------------------------------------===//
5137// TransferWriteOp: VectorTransferOpInterface methods.
5138//===----------------------------------------------------------------------===//
// The vector being written is always operand #0.
Value TransferWriteOp::getVector() { return getOperand(i: 0); }
5140VectorType TransferWriteOp::getVectorType() {
5141 return cast<VectorType>(Val: getValueToStore().getType());
5142}
5143
5144//===----------------------------------------------------------------------===//
5145// TransferWriteOp: fold methods.
5146//===----------------------------------------------------------------------===//
/// Fold:
/// ```
/// %t1 = ...
/// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
///   tensor<static_sizesxf32>, vector<static_sizesxf32>
/// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
///   vector<static_sizesxf32>, tensor<static_sizesxf32>
/// ```
///
/// into:
///
/// ```
/// %t0
/// ```
///
/// The producer of t1 may or may not be DCE'd depending on whether it is a
/// block argument or has side effects.
static LogicalResult foldReadInitWrite(TransferWriteOp write,
                                       ArrayRef<Attribute>,
                                       SmallVectorImpl<OpFoldResult> &results) {
  // TODO: support 0-d corner case.
  if (write.getTransferRank() == 0)
    return failure();
  auto rankedTensorType =
      llvm::dyn_cast<RankedTensorType>(Val: write.getBase().getType());
  // If not operating on tensors, bail.
  if (!rankedTensorType)
    return failure();
  // If no read, bail.
  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();
  // TODO: support 0-d corner case.
  if (read.getTransferRank() == 0)
    return failure();
  // For now, only accept minor identity. Future: composition is minor identity.
  if (!read.getPermutationMap().isMinorIdentity() ||
      !write.getPermutationMap().isMinorIdentity())
    return failure();
  // Bail on mismatching ranks.
  if (read.getTransferRank() != write.getTransferRank())
    return failure();
  // Bail on potential out-of-bounds accesses.
  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
    return failure();
  // Tensor types must be the same.
  if (read.getBase().getType() != rankedTensorType)
    return failure();
  // Vector types must be the same.
  if (read.getVectorType() != write.getVectorType())
    return failure();
  // Vector and Tensor shapes must match.
  if (read.getVectorType().getShape() != rankedTensorType.getShape())
    return failure();
  // If any index is nonzero.
  auto isNotConstantZero = [](Value v) {
    auto cstOp = getConstantIntValue(ofr: v);
    return !cstOp.has_value() || cstOp.value() != 0;
  };
  if (llvm::any_of(Range: read.getIndices(), P: isNotConstantZero) ||
      llvm::any_of(Range: write.getIndices(), P: isNotConstantZero))
    return failure();
  // Success: the write stores back exactly what was read from the whole
  // tensor, so the result is the read's source tensor.
  results.push_back(Elt: read.getBase());
  return success();
}
5213
5214static bool checkSameValueWAR(vector::TransferReadOp read,
5215 vector::TransferWriteOp write) {
5216 return read.getBase() == write.getBase() &&
5217 read.getIndices() == write.getIndices() &&
5218 read.getPermutationMap() == write.getPermutationMap() &&
5219 read.getVectorType() == write.getVectorType() && !read.getMask() &&
5220 !write.getMask();
5221}
5222/// Fold transfer_write write after read:
5223/// ```
5224/// %t0 = ...
5225/// %v = vector.transfer_read %t0[%c0...] :
5226/// tensor<static_sizesxf32>, vector<static_sizesxf32>
5227/// %t1 = vector.transfer_write %v, %t0[%c0...] :
5228/// vector<static_sizesxf32>, tensor<static_sizesxf32>
5229/// ```
5230///
5231/// into:
5232///
5233/// ```
5234/// %t0
5235/// ```
5236static LogicalResult foldWAR(TransferWriteOp write,
5237 SmallVectorImpl<OpFoldResult> &results) {
5238 if (!llvm::isa<RankedTensorType>(Val: write.getBase().getType()))
5239 return failure();
5240 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5241 if (!read)
5242 return failure();
5243
5244 if (!checkSameValueWAR(read, write))
5245 return failure();
5246 results.push_back(Elt: read.getBase());
5247 return success();
5248}
5249
LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
                                    SmallVectorImpl<OpFoldResult> &results) {
  // Fold a full read-then-write round trip back to the read's source tensor.
  if (succeeded(Result: foldReadInitWrite(write: *this, adaptor.getOperands(), results)))
    return success();
  // Fold a write of the exact value just read from the same location (WAR).
  if (succeeded(Result: foldWAR(write: *this, results)))
    return success();
  // Tighten the in_bounds attribute when statically provable.
  if (succeeded(Result: foldTransferInBoundsAttribute(op: *this)))
    return success();
  // Drop an all-true mask.
  if (succeeded(Result: foldTransferFullMask(op: *this)))
    return success();
  // transfer_write(memrefcast) -> transfer_write
  return memref::foldMemRefCast(op: *this);
}
5262
5263//===----------------------------------------------------------------------===//
5264// TransferWriteOp: other methods.
5265//===----------------------------------------------------------------------===//
5266std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
5267 return llvm::to_vector<4>(Range: getVectorType().getShape());
5268}
5269
5270void TransferWriteOp::getEffects(
5271 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5272 &effects) {
5273 if (llvm::isa<MemRefType>(Val: getShapedType()))
5274 effects.emplace_back(Args: MemoryEffects::Write::get(), Args: &getBaseMutable(),
5275 Args: SideEffects::DefaultResource::get());
5276}
5277
5278Speculation::Speculatability TransferWriteOp::getSpeculatability() {
5279 if (hasPureTensorSemantics())
5280 return Speculation::Speculatable;
5281 return Speculation::NotSpeculatable;
5282}
5283
5284namespace {
/// Remove a dead transfer write from the SSA chain so that it can be
/// eliminated by DCE
5287/// ```
5288/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5289/// : vector<1x4xf32>, tensor<4x4xf32>
5290/// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
5291/// : vector<1x4xf32>, tensor<4x4xf32>
5292/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
5293/// : vector<1x4xf32>, tensor<4x4xf32>
5294/// ```
5295///
5296/// into:
5297///
5298/// ```
5299/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5300/// : vector<1x4xf32>, tensor<4x4xf32>
5301/// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
5302/// : vector<1x4xf32>, tensor<4x4xf32>
5303/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
5304/// : vector<1x4xf32>, tensor<4x4xf32>
5305/// ```
5306///
5307/// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
5308/// any other uses.
5309class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
5310public:
5311 using OpRewritePattern::OpRewritePattern;
5312 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
5313 PatternRewriter &rewriter) const override {
5314 if (!llvm::isa<RankedTensorType>(Val: writeOp.getShapedType()))
5315 return failure();
5316 vector::TransferWriteOp writeToModify = writeOp;
5317
5318 auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5319 while (defWrite) {
5320 if (checkSameValueWAW(write: writeOp, priorWrite: defWrite)) {
5321 rewriter.modifyOpInPlace(root: writeToModify, callable: [&]() {
5322 writeToModify.getBaseMutable().assign(value: defWrite.getBase());
5323 });
5324 return success();
5325 }
5326 if (!isDisjointTransferIndices(
5327 transferA: cast<VectorTransferOpInterface>(Val: defWrite.getOperation()),
5328 transferB: cast<VectorTransferOpInterface>(Val: writeOp.getOperation())))
5329 break;
5330 // If the previous write op doesn't have any other use we an safely look
5331 // at the previous store to see if it can be removed.
5332 if (!defWrite->hasOneUse())
5333 break;
5334 writeToModify = defWrite;
5335 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5336 }
5337 return failure();
5338 }
5339};
5340
5341/// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
5342/// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
5343/// overwritten and inserted into another tensor. After this rewrite, the
5344/// operations bufferize in-place since all of them work on the same slice.
5345///
5346/// For example:
5347/// ```mlir
5348/// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
5349/// : vector<8x16xf32>, tensor<8x16xf32>
5350/// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
5351/// : tensor<8x16xf32> to tensor<?x?xf32>
5352/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5353/// : tensor<?x?xf32> into tensor<27x37xf32>
5354/// ```
5355/// folds to
5356/// ```mlir
5357/// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5358/// : tensor<27x37xf32> to tensor<?x?xf32>
5359/// %1 = vector.transfer_write %vec, %0[%c0, %c0]
5360/// : vector<8x16xf32>, tensor<?x?xf32>
5361/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5362/// : tensor<?x?xf32> into tensor<27x37xf32>
5363/// ```
5364struct SwapExtractSliceOfTransferWrite
5365 : public OpRewritePattern<tensor::InsertSliceOp> {
5366public:
5367 using OpRewritePattern::OpRewritePattern;
5368
5369 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
5370 PatternRewriter &rewriter) const override {
5371 if (!insertOp.hasUnitStride())
5372 return failure();
5373 auto extractOp =
5374 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
5375 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
5376 return failure();
5377 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
5378 if (!transferOp || !transferOp->hasOneUse())
5379 return failure();
5380
5381 // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
5382 // rank-reducing.
5383 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
5384 return rewriter.notifyMatchFailure(arg&: insertOp,
5385 msg: "use-def chain is rank-reducing");
5386 }
5387
5388 // Fail if tensor::ExtractSliceOp has non-zero offset.
5389 if (!extractOp.hasZeroOffset()) {
5390 return rewriter.notifyMatchFailure(arg&: insertOp,
5391 msg: "ExtractSliceOp has non-zero offset");
5392 }
5393
5394 // Fail if tensor::TransferWriteOp has non-zero offset.
5395 if (!llvm::all_of(Range: transferOp.getIndices(), P: [](Value value) {
5396 return getConstantIntValue(ofr: value) == static_cast<int64_t>(0);
5397 })) {
5398 return rewriter.notifyMatchFailure(arg&: insertOp,
5399 msg: "TranferWriteOp has non-zero offset");
5400 }
5401
5402 // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
5403 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
5404 return rewriter.notifyMatchFailure(
5405 arg&: insertOp, msg: "InsertSliceOp and ExtractSliceOp ranks differ");
5406 }
5407
5408 for (auto [insertSize, extractSize] :
5409 llvm::zip_equal(t: insertOp.getMixedSizes(), u: extractOp.getMixedSizes())) {
5410 if (!isEqualConstantIntOrValue(ofr1: insertSize, ofr2: extractSize)) {
5411 return rewriter.notifyMatchFailure(
5412 arg&: insertOp, msg: "InsertSliceOp and ExtractSliceOp sizes differ");
5413 }
5414 }
5415
5416 // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
5417 assert(transferOp.getVectorType().hasStaticShape() &&
5418 "expected vector to have a static shape");
5419 ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
5420 SmallVector<int64_t> resultShape = applyPermutationMap(
5421 map: transferOp.getPermutationMap(), source: transferOp.getShapedType().getShape());
5422 if (transferOp.getMask() || !vectorShape.equals(RHS: resultShape)) {
5423 return rewriter.notifyMatchFailure(
5424 arg&: insertOp, msg: "TransferWriteOp may not write the full tensor.");
5425 }
5426
5427 // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
5428 // Set all in_bounds to false and let the folder infer them.
5429 SmallVector<bool> newInBounds(vectorShape.size(), false);
5430 auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
5431 location: extractOp.getLoc(), args: insertOp.getSourceType(), args: insertOp.getDest(),
5432 args: insertOp.getMixedOffsets(), args: insertOp.getMixedSizes(),
5433 args: insertOp.getMixedStrides());
5434 auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
5435 location: transferOp.getLoc(), args: transferOp.getVector(), args: newExtractOp.getResult(),
5436 args: transferOp.getIndices(), args: transferOp.getPermutationMapAttr(),
5437 args: rewriter.getBoolArrayAttr(values: newInBounds));
5438 rewriter.modifyOpInPlace(root: insertOp, callable: [&]() {
5439 insertOp.getSourceMutable().assign(value: newTransferWriteOp.getResult());
5440 });
5441 return success();
5442 }
5443};
5444
5445} // namespace
5446
5447void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
5448 MLIRContext *context) {
5449 results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(arg&: context);
5450}
5451
5452//===----------------------------------------------------------------------===//
5453// LoadOp
5454//===----------------------------------------------------------------------===//
5455
5456static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
5457 VectorType vecTy,
5458 MemRefType memRefTy) {
5459 // If rank==0 or size==1 it's equivalent to scalar load/store, so we don't
5460 // need any strides limitations.
5461 if (!vecTy.isScalable() &&
5462 (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
5463 return success();
5464
5465 if (!memRefTy.isLastDimUnitStride())
5466 return op->emitOpError(message: "most minor memref dim must have unit stride");
5467 return success();
5468}
5469
5470LogicalResult vector::LoadOp::verify() {
5471 VectorType resVecTy = getVectorType();
5472 MemRefType memRefTy = getMemRefType();
5473
5474 if (failed(Result: verifyLoadStoreMemRefLayout(op: *this, vecTy: resVecTy, memRefTy)))
5475 return failure();
5476
5477 if (memRefTy.getRank() < resVecTy.getRank())
5478 return emitOpError(
5479 message: "destination memref has lower rank than the result vector");
5480
5481 // Checks for vector memrefs.
5482 Type memElemTy = memRefTy.getElementType();
5483 if (auto memVecTy = llvm::dyn_cast<VectorType>(Val&: memElemTy)) {
5484 if (memVecTy != resVecTy)
5485 return emitOpError(message: "base memref and result vector types should match");
5486 memElemTy = memVecTy.getElementType();
5487 }
5488
5489 if (resVecTy.getElementType() != memElemTy)
5490 return emitOpError(message: "base and result element types should match");
5491 if (llvm::size(Range: getIndices()) != memRefTy.getRank())
5492 return emitOpError(message: "requires ") << memRefTy.getRank() << " indices";
5493 return success();
5494}
5495
5496OpFoldResult LoadOp::fold(FoldAdaptor) {
5497 if (succeeded(Result: memref::foldMemRefCast(op: *this)))
5498 return getResult();
5499 return OpFoldResult();
5500}
5501
5502std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
5503 return llvm::to_vector<4>(Range: getVectorType().getShape());
5504}
5505
5506//===----------------------------------------------------------------------===//
5507// StoreOp
5508//===----------------------------------------------------------------------===//
5509
5510LogicalResult vector::StoreOp::verify() {
5511 VectorType valueVecTy = getVectorType();
5512 MemRefType memRefTy = getMemRefType();
5513
5514 if (failed(Result: verifyLoadStoreMemRefLayout(op: *this, vecTy: valueVecTy, memRefTy)))
5515 return failure();
5516
5517 if (memRefTy.getRank() < valueVecTy.getRank())
5518 return emitOpError(message: "source memref has lower rank than the vector to store");
5519
5520 // Checks for vector memrefs.
5521 Type memElemTy = memRefTy.getElementType();
5522 if (auto memVecTy = llvm::dyn_cast<VectorType>(Val&: memElemTy)) {
5523 if (memVecTy != valueVecTy)
5524 return emitOpError(
5525 message: "base memref and valueToStore vector types should match");
5526 memElemTy = memVecTy.getElementType();
5527 }
5528
5529 if (valueVecTy.getElementType() != memElemTy)
5530 return emitOpError(message: "base and valueToStore element type should match");
5531 if (llvm::size(Range: getIndices()) != memRefTy.getRank())
5532 return emitOpError(message: "requires ") << memRefTy.getRank() << " indices";
5533 return success();
5534}
5535
5536LogicalResult StoreOp::fold(FoldAdaptor adaptor,
5537 SmallVectorImpl<OpFoldResult> &results) {
5538 return memref::foldMemRefCast(op: *this);
5539}
5540
5541std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
5542 return llvm::to_vector<4>(Range: getVectorType().getShape());
5543}
5544
5545//===----------------------------------------------------------------------===//
5546// MaskedLoadOp
5547//===----------------------------------------------------------------------===//
5548
5549LogicalResult MaskedLoadOp::verify() {
5550 VectorType maskVType = getMaskVectorType();
5551 VectorType passVType = getPassThruVectorType();
5552 VectorType resVType = getVectorType();
5553 MemRefType memType = getMemRefType();
5554
5555 if (resVType.getElementType() != memType.getElementType())
5556 return emitOpError(message: "base and result element type should match");
5557 if (llvm::size(Range: getIndices()) != memType.getRank())
5558 return emitOpError(message: "requires ") << memType.getRank() << " indices";
5559 if (resVType.getShape() != maskVType.getShape())
5560 return emitOpError(message: "expected result shape to match mask shape");
5561 if (resVType != passVType)
5562 return emitOpError(message: "expected pass_thru of same type as result type");
5563 return success();
5564}
5565
5566namespace {
5567class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
5568public:
5569 using OpRewritePattern::OpRewritePattern;
5570 LogicalResult matchAndRewrite(MaskedLoadOp load,
5571 PatternRewriter &rewriter) const override {
5572 switch (getMaskFormat(mask: load.getMask())) {
5573 case MaskFormat::AllTrue:
5574 rewriter.replaceOpWithNewOp<vector::LoadOp>(
5575 op: load, args: load.getType(), args: load.getBase(), args: load.getIndices());
5576 return success();
5577 case MaskFormat::AllFalse:
5578 rewriter.replaceOp(op: load, newValues: load.getPassThru());
5579 return success();
5580 case MaskFormat::Unknown:
5581 return failure();
5582 }
5583 llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
5584 }
5585};
5586} // namespace
5587
5588void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5589 MLIRContext *context) {
5590 results.add<MaskedLoadFolder>(arg&: context);
5591}
5592
5593OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
5594 if (succeeded(Result: memref::foldMemRefCast(op: *this)))
5595 return getResult();
5596 return OpFoldResult();
5597}
5598
5599//===----------------------------------------------------------------------===//
5600// MaskedStoreOp
5601//===----------------------------------------------------------------------===//
5602
5603LogicalResult MaskedStoreOp::verify() {
5604 VectorType maskVType = getMaskVectorType();
5605 VectorType valueVType = getVectorType();
5606 MemRefType memType = getMemRefType();
5607
5608 if (valueVType.getElementType() != memType.getElementType())
5609 return emitOpError(message: "base and valueToStore element type should match");
5610 if (llvm::size(Range: getIndices()) != memType.getRank())
5611 return emitOpError(message: "requires ") << memType.getRank() << " indices";
5612 if (valueVType.getShape() != maskVType.getShape())
5613 return emitOpError(message: "expected valueToStore shape to match mask shape");
5614 return success();
5615}
5616
5617namespace {
5618class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
5619public:
5620 using OpRewritePattern::OpRewritePattern;
5621 LogicalResult matchAndRewrite(MaskedStoreOp store,
5622 PatternRewriter &rewriter) const override {
5623 switch (getMaskFormat(mask: store.getMask())) {
5624 case MaskFormat::AllTrue:
5625 rewriter.replaceOpWithNewOp<vector::StoreOp>(
5626 op: store, args: store.getValueToStore(), args: store.getBase(), args: store.getIndices());
5627 return success();
5628 case MaskFormat::AllFalse:
5629 rewriter.eraseOp(op: store);
5630 return success();
5631 case MaskFormat::Unknown:
5632 return failure();
5633 }
5634 llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
5635 }
5636};
5637} // namespace
5638
5639void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5640 MLIRContext *context) {
5641 results.add<MaskedStoreFolder>(arg&: context);
5642}
5643
5644LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
5645 SmallVectorImpl<OpFoldResult> &results) {
5646 return memref::foldMemRefCast(op: *this);
5647}
5648
5649//===----------------------------------------------------------------------===//
5650// GatherOp
5651//===----------------------------------------------------------------------===//
5652
5653LogicalResult GatherOp::verify() {
5654 VectorType indVType = getIndexVectorType();
5655 VectorType maskVType = getMaskVectorType();
5656 VectorType resVType = getVectorType();
5657 ShapedType baseType = getBaseType();
5658
5659 if (!llvm::isa<MemRefType, RankedTensorType>(Val: baseType))
5660 return emitOpError(message: "requires base to be a memref or ranked tensor type");
5661
5662 if (resVType.getElementType() != baseType.getElementType())
5663 return emitOpError(message: "base and result element type should match");
5664 if (llvm::size(Range: getIndices()) != baseType.getRank())
5665 return emitOpError(message: "requires ") << baseType.getRank() << " indices";
5666 if (resVType.getShape() != indVType.getShape())
5667 return emitOpError(message: "expected result dim to match indices dim");
5668 if (resVType.getShape() != maskVType.getShape())
5669 return emitOpError(message: "expected result dim to match mask dim");
5670 if (resVType != getPassThruVectorType())
5671 return emitOpError(message: "expected pass_thru of same type as result type");
5672 return success();
5673}
5674
5675// MaskableOpInterface methods.
5676
5677/// Returns the mask type expected by this operation. Mostly used for
5678/// verification purposes. It requires the operation to be vectorized."
5679Type GatherOp::getExpectedMaskType() {
5680 auto vecType = this->getIndexVectorType();
5681 return VectorType::get(shape: vecType.getShape(),
5682 elementType: IntegerType::get(context: vecType.getContext(), /*width=*/1),
5683 scalableDims: vecType.getScalableDims());
5684}
5685
5686std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
5687 return llvm::to_vector<4>(Range: getVectorType().getShape());
5688}
5689
5690/// Cheeck if `indexVec` is constant 1D vec of consecutive values [0, 1, 2, ...]
5691static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
5692 auto vecType = dyn_cast<VectorType>(Val: indexVec.getType());
5693 if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
5694 return failure();
5695
5696 if (indexVec.getDefiningOp<StepOp>())
5697 return success();
5698
5699 DenseIntElementsAttr elements;
5700 if (!matchPattern(value: indexVec, pattern: m_Constant(bind_value: &elements)))
5701 return failure();
5702
5703 return success(
5704 IsSuccess: llvm::equal(LRange&: elements, RRange: llvm::seq<int64_t>(Begin: 0, End: vecType.getNumElements())));
5705}
5706
5707namespace {
5708class GatherFolder final : public OpRewritePattern<GatherOp> {
5709public:
5710 using OpRewritePattern::OpRewritePattern;
5711 LogicalResult matchAndRewrite(GatherOp gather,
5712 PatternRewriter &rewriter) const override {
5713 switch (getMaskFormat(mask: gather.getMask())) {
5714 case MaskFormat::AllTrue:
5715 return failure(); // no unmasked equivalent
5716 case MaskFormat::AllFalse:
5717 rewriter.replaceOp(op: gather, newValues: gather.getPassThru());
5718 return success();
5719 case MaskFormat::Unknown:
5720 return failure();
5721 }
5722 llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
5723 }
5724};
5725
5726/// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
5727/// maskedload. Only 1D fixed vectors are supported for now.
5728class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
5729public:
5730 using OpRewritePattern::OpRewritePattern;
5731 LogicalResult matchAndRewrite(GatherOp op,
5732 PatternRewriter &rewriter) const override {
5733 if (!isa<MemRefType>(Val: op.getBase().getType()))
5734 return rewriter.notifyMatchFailure(arg&: op, msg: "base must be of memref type");
5735
5736 if (failed(Result: isZeroBasedContiguousSeq(indexVec: op.getIndexVec())))
5737 return failure();
5738
5739 rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, args: op.getType(), args: op.getBase(),
5740 args: op.getIndices(), args: op.getMask(),
5741 args: op.getPassThru());
5742 return success();
5743 }
5744};
5745} // namespace
5746
5747void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
5748 MLIRContext *context) {
5749 results.add<GatherFolder, FoldContiguousGather>(arg&: context);
5750}
5751
5752//===----------------------------------------------------------------------===//
5753// ScatterOp
5754//===----------------------------------------------------------------------===//
5755
5756LogicalResult ScatterOp::verify() {
5757 VectorType indVType = getIndexVectorType();
5758 VectorType maskVType = getMaskVectorType();
5759 VectorType valueVType = getVectorType();
5760 MemRefType memType = getMemRefType();
5761
5762 if (valueVType.getElementType() != memType.getElementType())
5763 return emitOpError(message: "base and valueToStore element type should match");
5764 if (llvm::size(Range: getIndices()) != memType.getRank())
5765 return emitOpError(message: "requires ") << memType.getRank() << " indices";
5766 if (valueVType.getShape() != indVType.getShape())
5767 return emitOpError(message: "expected valueToStore dim to match indices dim");
5768 if (valueVType.getShape() != maskVType.getShape())
5769 return emitOpError(message: "expected valueToStore dim to match mask dim");
5770 return success();
5771}
5772
5773namespace {
5774class ScatterFolder final : public OpRewritePattern<ScatterOp> {
5775public:
5776 using OpRewritePattern::OpRewritePattern;
5777 LogicalResult matchAndRewrite(ScatterOp scatter,
5778 PatternRewriter &rewriter) const override {
5779 switch (getMaskFormat(mask: scatter.getMask())) {
5780 case MaskFormat::AllTrue:
5781 return failure(); // no unmasked equivalent
5782 case MaskFormat::AllFalse:
5783 rewriter.eraseOp(op: scatter);
5784 return success();
5785 case MaskFormat::Unknown:
5786 return failure();
5787 }
5788 llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
5789 }
5790};
5791
5792/// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
5793/// maskedstore. Only 1D fixed vectors are supported for now.
5794class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
5795public:
5796 using OpRewritePattern::OpRewritePattern;
5797 LogicalResult matchAndRewrite(ScatterOp op,
5798 PatternRewriter &rewriter) const override {
5799 if (failed(Result: isZeroBasedContiguousSeq(indexVec: op.getIndexVec())))
5800 return failure();
5801
5802 rewriter.replaceOpWithNewOp<MaskedStoreOp>(
5803 op, args: op.getBase(), args: op.getIndices(), args: op.getMask(), args: op.getValueToStore());
5804 return success();
5805 }
5806};
5807} // namespace
5808
5809void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
5810 MLIRContext *context) {
5811 results.add<ScatterFolder, FoldContiguousScatter>(arg&: context);
5812}
5813
5814//===----------------------------------------------------------------------===//
5815// ExpandLoadOp
5816//===----------------------------------------------------------------------===//
5817
5818LogicalResult ExpandLoadOp::verify() {
5819 VectorType maskVType = getMaskVectorType();
5820 VectorType passVType = getPassThruVectorType();
5821 VectorType resVType = getVectorType();
5822 MemRefType memType = getMemRefType();
5823
5824 if (resVType.getElementType() != memType.getElementType())
5825 return emitOpError(message: "base and result element type should match");
5826 if (llvm::size(Range: getIndices()) != memType.getRank())
5827 return emitOpError(message: "requires ") << memType.getRank() << " indices";
5828 if (resVType.getDimSize(idx: 0) != maskVType.getDimSize(idx: 0))
5829 return emitOpError(message: "expected result dim to match mask dim");
5830 if (resVType != passVType)
5831 return emitOpError(message: "expected pass_thru of same type as result type");
5832 return success();
5833}
5834
5835namespace {
5836class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
5837public:
5838 using OpRewritePattern::OpRewritePattern;
5839 LogicalResult matchAndRewrite(ExpandLoadOp expand,
5840 PatternRewriter &rewriter) const override {
5841 switch (getMaskFormat(mask: expand.getMask())) {
5842 case MaskFormat::AllTrue:
5843 rewriter.replaceOpWithNewOp<vector::LoadOp>(
5844 op: expand, args: expand.getType(), args: expand.getBase(), args: expand.getIndices());
5845 return success();
5846 case MaskFormat::AllFalse:
5847 rewriter.replaceOp(op: expand, newValues: expand.getPassThru());
5848 return success();
5849 case MaskFormat::Unknown:
5850 return failure();
5851 }
5852 llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
5853 }
5854};
5855} // namespace
5856
5857void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5858 MLIRContext *context) {
5859 results.add<ExpandLoadFolder>(arg&: context);
5860}
5861
5862//===----------------------------------------------------------------------===//
5863// CompressStoreOp
5864//===----------------------------------------------------------------------===//
5865
5866LogicalResult CompressStoreOp::verify() {
5867 VectorType maskVType = getMaskVectorType();
5868 VectorType valueVType = getVectorType();
5869 MemRefType memType = getMemRefType();
5870
5871 if (valueVType.getElementType() != memType.getElementType())
5872 return emitOpError(message: "base and valueToStore element type should match");
5873 if (llvm::size(Range: getIndices()) != memType.getRank())
5874 return emitOpError(message: "requires ") << memType.getRank() << " indices";
5875 if (valueVType.getDimSize(idx: 0) != maskVType.getDimSize(idx: 0))
5876 return emitOpError(message: "expected valueToStore dim to match mask dim");
5877 return success();
5878}
5879
5880namespace {
5881class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
5882public:
5883 using OpRewritePattern::OpRewritePattern;
5884 LogicalResult matchAndRewrite(CompressStoreOp compress,
5885 PatternRewriter &rewriter) const override {
5886 switch (getMaskFormat(mask: compress.getMask())) {
5887 case MaskFormat::AllTrue:
5888 rewriter.replaceOpWithNewOp<vector::StoreOp>(
5889 op: compress, args: compress.getValueToStore(), args: compress.getBase(),
5890 args: compress.getIndices());
5891 return success();
5892 case MaskFormat::AllFalse:
5893 rewriter.eraseOp(op: compress);
5894 return success();
5895 case MaskFormat::Unknown:
5896 return failure();
5897 }
5898 llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
5899 }
5900};
5901} // namespace
5902
5903void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5904 MLIRContext *context) {
5905 results.add<CompressStoreFolder>(arg&: context);
5906}
5907
5908//===----------------------------------------------------------------------===//
5909// ShapeCastOp
5910//===----------------------------------------------------------------------===//
5911
5912void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
5913 SetIntRangeFn setResultRanges) {
5914 setResultRanges(getResult(), argRanges.front());
5915}
5916
5917LogicalResult ShapeCastOp::verify() {
5918
5919 VectorType sourceType = getSourceVectorType();
5920 VectorType resultType = getResultVectorType();
5921
5922 // Check that element type is preserved
5923 if (sourceType.getElementType() != resultType.getElementType())
5924 return emitOpError(message: "has different source and result element types");
5925
5926 // Check that number of elements is preserved
5927 int64_t sourceNElms = sourceType.getNumElements();
5928 int64_t resultNElms = resultType.getNumElements();
5929 if (sourceNElms != resultNElms) {
5930 return emitOpError() << "has different number of elements at source ("
5931 << sourceNElms << ") and result (" << resultNElms
5932 << ")";
5933 }
5934
5935 // Check that (non-)scalability is preserved
5936 int64_t sourceNScalableDims = sourceType.getNumScalableDims();
5937 int64_t resultNScalableDims = resultType.getNumScalableDims();
5938 if (sourceNScalableDims != resultNScalableDims)
5939 return emitOpError() << "has different number of scalable dims at source ("
5940 << sourceNScalableDims << ") and result ("
5941 << resultNScalableDims << ")";
5942
5943 return success();
5944}
5945
5946/// Return true if `transpose` does not permute a pair of non-unit dims.
5947/// By `order preserving` we mean that the flattened versions of the input and
5948/// output vectors are (numerically) identical. In other words `transpose` is
5949/// effectively a shape cast.
5950static bool isOrderPreserving(TransposeOp transpose) {
5951 ArrayRef<int64_t> permutation = transpose.getPermutation();
5952 VectorType sourceType = transpose.getSourceVectorType();
5953 ArrayRef<int64_t> inShape = sourceType.getShape();
5954 ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
5955 auto isNonScalableUnitDim = [&](int64_t dim) {
5956 return inShape[dim] == 1 && !inDimIsScalable[dim];
5957 };
5958 int64_t current = 0;
5959 for (auto p : permutation) {
5960 if (!isNonScalableUnitDim(p)) {
5961 if (p < current) {
5962 return false;
5963 }
5964 current = p;
5965 }
5966 }
5967 return true;
5968}
5969
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {

  VectorType resultType = getType();

  // No-op shape cast: the source already has the result type.
  if (getSource().getType() == resultType)
    return getSource();

  // shape_cast(shape_cast(x)) -> shape_cast(x)
  // Folded in place by re-pointing this op's operand at the inner source;
  // the result type of this op is unchanged.
  if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
    setOperand(precedingShapeCast.getSource());
    return getResult();
  }

  // shape_cast(transpose(x)) -> shape_cast(x)
  // Only valid when the transpose never swaps a pair of non-unit dims, i.e.
  // when the transpose is itself effectively a shape cast (see
  // isOrderPreserving).
  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
    if (isOrderPreserving(transpose)) {
      setOperand(transpose.getVector());
      return getResult();
    }
    return {};
  }

  // Y = shape_cast(broadcast(X))
  //      -> X, if X and Y have same type
  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
    if (bcastOp.getSourceType() == resultType)
      return bcastOp.getSource();
  }

  // shape_cast(constant) -> constant: a splat can be reshaped to any type
  // with the same element count.
  if (auto splatAttr =
          llvm::dyn_cast_if_present<SplatElementsAttr>(Val: adaptor.getSource()))
    return splatAttr.reshape(newType: getType());

  // shape_cast(poison) -> poison
  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(Val: adaptor.getSource())) {
    return ub::PoisonAttr::get(ctx: getContext());
  }

  return {};
}
6012
6013namespace {
6014
6015/// Helper function that computes a new vector type based on the input vector
6016/// type by removing the trailing one dims:
6017///
6018/// vector<4x1x1xi1> --> vector<4x1xi1>
6019///
6020static VectorType trimTrailingOneDims(VectorType oldType) {
6021 ArrayRef<int64_t> oldShape = oldType.getShape();
6022 ArrayRef<int64_t> newShape = oldShape;
6023
6024 ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
6025 ArrayRef<bool> newScalableDims = oldScalableDims;
6026
6027 while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
6028 newShape = newShape.drop_back(N: 1);
6029 newScalableDims = newScalableDims.drop_back(N: 1);
6030 }
6031
6032 // Make sure we have at least 1 dimension.
6033 // TODO: Add support for 0-D vectors.
6034 if (newShape.empty()) {
6035 newShape = oldShape.take_back();
6036 newScalableDims = oldScalableDims.take_back();
6037 }
6038
6039 return VectorType::get(shape: newShape, elementType: oldType.getElementType(), scalableDims: newScalableDims);
6040}
6041
/// Folds qualifying shape_cast(create_mask) into a new create_mask
///
/// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
/// dimension. If the input vector comes from `vector.create_mask` for which
/// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
/// to fold shape_cast into create_mask.
///
/// BEFORE:
///    %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
///    %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
/// AFTER:
///    %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
class ShapeCastCreateMaskFolderTrailingOneDim final
    : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
                                PatternRewriter &rewriter) const override {
    Value shapeOpSrc = shapeOp->getOperand(idx: 0);
    // The source must come from one of the two mask-creation ops.
    auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    VectorType shapeOpResTy = shapeOp.getResultVectorType();
    VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();

    // The shape_cast must do exactly one thing: drop trailing (non-scalable)
    // unit dims. Anything else disqualifies the fold.
    VectorType newVecType = trimTrailingOneDims(oldType: shapeOpSrcTy);
    if (newVecType != shapeOpResTy)
      return failure();

    auto numDimsToDrop =
        shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();

    // No unit dims to drop
    if (!numDimsToDrop)
      return failure();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      auto numMaskOperands = maskOperands.size();

      // Check every mask dim size to see whether it can be dropped: each
      // dropped trailing dim must have a constant mask bound of exactly 1,
      // otherwise dropping it would change the set of enabled lanes.
      for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
           --i) {
        auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
        if (!constant || (constant.value() != 1))
          return failure();
      }
      SmallVector<Value> newMaskOperands =
          maskOperands.drop_back(n: numDimsToDrop);

      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(op: shapeOp, args&: shapeOpResTy,
                                                        args&: newMaskOperands);
      return success();
    }

    if (constantMaskOp) {
      auto maskDimSizes = constantMaskOp.getMaskDimSizes();
      auto numMaskOperands = maskDimSizes.size();

      // Check every mask dim size to see whether it can be dropped: each
      // dropped trailing dim must have a static mask size of exactly 1.
      for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
           --i) {
        if (maskDimSizes[i] != 1)
          return failure();
      }

      auto newMaskOperands = maskDimSizes.drop_back(N: numDimsToDrop);
      rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(op: shapeOp, args&: shapeOpResTy,
                                                          args&: newMaskOperands);
      return success();
    }

    return failure();
  }
};
6120
6121/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
6122/// i) Y = ShapeCast(X), or
6123/// ii) Y = Broadcast(X)
6124/// If both (i) and (ii) are possible, (i) is chosen.
6125class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
6126public:
6127 using OpRewritePattern::OpRewritePattern;
6128
6129 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
6130 PatternRewriter &rewriter) const override {
6131 auto broadcastOp =
6132 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
6133 if (!broadcastOp)
6134 return failure();
6135
6136 auto srcVectorType = dyn_cast<VectorType>(Val: broadcastOp.getSourceType());
6137 bool srcIsScalar = !srcVectorType;
6138
6139 // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
6140 // Example:
6141 // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
6142 // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
6143 // to
6144 // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
6145 if (srcVectorType) {
6146 if (srcVectorType.getNumElements() ==
6147 shapeCastOp.getResultVectorType().getNumElements()) {
6148 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
6149 op: shapeCastOp, args: shapeCastOp.getResultVectorType(),
6150 args: broadcastOp.getSource());
6151 return success();
6152 }
6153 }
6154
6155 // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
6156 // Example
6157 // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
6158 // %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
6159 // to
6160 // %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
6161 VectorType dstVectorType = shapeCastOp.getResultVectorType();
6162 if (srcIsScalar || isBroadcastableTo(srcType: srcVectorType, dstVectorType) ==
6163 BroadcastableToResult::Success) {
6164 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6165 op: shapeCastOp, args&: dstVectorType, args: broadcastOp.getSource());
6166 return success();
6167 }
6168 return failure();
6169 }
6170};
6171
6172} // namespace
6173
6174void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
6175 MLIRContext *context) {
6176 results
6177 .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
6178 arg&: context);
6179}
6180
6181//===----------------------------------------------------------------------===//
6182// VectorBitCastOp
6183//===----------------------------------------------------------------------===//
6184
6185LogicalResult BitCastOp::verify() {
6186 auto sourceVectorType = getSourceVectorType();
6187 auto resultVectorType = getResultVectorType();
6188
6189 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
6190 if (sourceVectorType.getDimSize(idx: i) != resultVectorType.getDimSize(idx: i))
6191 return emitOpError(message: "dimension size mismatch at: ") << i;
6192 }
6193
6194 DataLayout dataLayout = DataLayout::closest(op: *this);
6195 auto sourceElementBits =
6196 dataLayout.getTypeSizeInBits(t: sourceVectorType.getElementType());
6197 auto resultElementBits =
6198 dataLayout.getTypeSizeInBits(t: resultVectorType.getElementType());
6199
6200 if (sourceVectorType.getRank() == 0) {
6201 if (sourceElementBits != resultElementBits)
6202 return emitOpError(message: "source/result bitwidth of the 0-D vector element "
6203 "types must be equal");
6204 } else if (sourceElementBits * sourceVectorType.getShape().back() !=
6205 resultElementBits * resultVectorType.getShape().back()) {
6206 return emitOpError(
6207 message: "source/result bitwidth of the minor 1-D vectors must be equal");
6208 }
6209
6210 return success();
6211}
6212
/// Folds: identity casts, chains of bitcasts, and splat constants
/// (fp16 -> fp32 splats and integer widening splats).
OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
  // Nop cast.
  if (getSource().getType() == getResult().getType())
    return getSource();

  // Canceling bitcasts.
  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
    // Exact round-trip (A -> B -> A): fold to the original value.
    if (getResult().getType() == otherOp.getSource().getType())
      return otherOp.getSource();

    // Otherwise collapse the chain (A -> B -> C becomes A -> C) by rewriting
    // this op in place to consume the first cast's source.
    setOperand(otherOp.getSource());
    return getResult();
  }

  Attribute sourceConstant = adaptor.getSource();
  if (!sourceConstant)
    return {};

  Type srcElemType = getSourceVectorType().getElementType();
  Type dstElemType = getResultVectorType().getElementType();

  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(Val&: sourceConstant)) {
    if (floatPack.isSplat()) {
      auto splat = floatPack.getSplatValue<FloatAttr>();

      // Casting fp16 into fp32.
      if (srcElemType.isF16() && dstElemType.isF32()) {
        uint32_t bits = static_cast<uint32_t>(
            splat.getValue().bitcastToAPInt().getZExtValue());
        // Duplicate the 16-bit pattern: each f32 result element is two
        // adjacent f16 bit patterns concatenated.
        bits = (bits << 16) | (bits & 0xffff);
        APInt intBits(32, bits);
        APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
        return DenseElementsAttr::get(type: getResultVectorType(), values: floatBits);
      }
    }
  }

  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(Val&: sourceConstant)) {
    if (intPack.isSplat()) {
      auto splat = intPack.getSplatValue<IntegerAttr>();

      if (llvm::isa<IntegerType>(Val: dstElemType)) {
        uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
        uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();

        // Casting to a larger integer bit width.
        if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
          APInt intBits = splat.getValue().zext(width: dstBitWidth);

          // Duplicate the lower width element.
          for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
            intBits = (intBits << srcBitWidth) | intBits;
          return DenseElementsAttr::get(type: getResultVectorType(), values: intBits);
        }
      }
    }
  }

  return {};
}
6274
6275//===----------------------------------------------------------------------===//
6276// TypeCastOp
6277//===----------------------------------------------------------------------===//
6278
6279static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
6280 auto vectorType = llvm::dyn_cast<VectorType>(Val: memRefType.getElementType());
6281 SmallVector<int64_t, 8> res(memRefType.getShape());
6282 if (vectorType)
6283 res.append(in_start: vectorType.getShape().begin(), in_end: vectorType.getShape().end());
6284 return res;
6285}
6286
6287/// Build the canonical memRefType with a single vector.
6288/// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
6289void TypeCastOp::build(OpBuilder &builder, OperationState &result,
6290 Value source) {
6291 result.addOperands(newOperands: source);
6292 MemRefType memRefType = llvm::cast<MemRefType>(Val: source.getType());
6293 VectorType vectorType =
6294 VectorType::get(shape: extractShape(memRefType),
6295 elementType: getElementTypeOrSelf(type: getElementTypeOrSelf(type: memRefType)));
6296 result.addTypes(newTypes: MemRefType::get(shape: {}, elementType: vectorType, layout: MemRefLayoutAttrInterface(),
6297 memorySpace: memRefType.getMemorySpace()));
6298}
6299
/// Verifies that operand and result are identity-layout memrefs in the same
/// memory space, with the same underlying scalar type and the same
/// fully-expanded (memref dims + element-vector dims) shape.
LogicalResult TypeCastOp::verify() {
  // Canonicalize before checking so equivalent strided spellings of the
  // identity layout also pass.
  MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
  if (!canonicalType.getLayout().isIdentity())
    return emitOpError(message: "expects operand to be a memref with identity layout");
  if (!getResultMemRefType().getLayout().isIdentity())
    return emitOpError(message: "expects result to be a memref with identity layout");
  if (getResultMemRefType().getMemorySpace() !=
      getMemRefType().getMemorySpace())
    return emitOpError(message: "expects result in same memory space");

  auto sourceType = getMemRefType();
  auto resultType = getResultMemRefType();
  // The double getElementTypeOrSelf unwraps memref-of-vector to the scalar.
  if (getElementTypeOrSelf(type: getElementTypeOrSelf(type: sourceType)) !=
      getElementTypeOrSelf(type: getElementTypeOrSelf(type: resultType)))
    return emitOpError(
               message: "expects result and operand with same underlying scalar type: ")
           << resultType;
  // extractShape concatenates memref dims with element-vector dims, so this
  // compares the fully-expanded shapes of both sides.
  if (extractShape(memRefType: sourceType) != extractShape(memRefType: resultType))
    return emitOpError(
               message: "expects concatenated result and operand shapes to be equal: ")
           << resultType;
  return success();
}
6323
6324//===----------------------------------------------------------------------===//
6325// TransposeOp
6326//===----------------------------------------------------------------------===//
6327
6328void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
6329 Value vector, ArrayRef<int64_t> permutation) {
6330 VectorType vt = llvm::cast<VectorType>(Val: vector.getType());
6331 SmallVector<int64_t, 4> transposedShape(vt.getRank());
6332 SmallVector<bool, 4> transposedScalableDims(vt.getRank());
6333 for (unsigned i = 0; i < permutation.size(); ++i) {
6334 transposedShape[i] = vt.getShape()[permutation[i]];
6335 transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
6336 }
6337
6338 result.addOperands(newOperands: vector);
6339 result.addTypes(newTypes: VectorType::get(shape: transposedShape, elementType: vt.getElementType(),
6340 scalableDims: transposedScalableDims));
6341 result.addAttribute(name: TransposeOp::getPermutationAttrName(name: result.name),
6342 attr: builder.getDenseI64ArrayAttr(values: permutation));
6343}
6344
/// Folds transpose of: a splat constant (reshaped to the result type), a
/// poison value, and permutations that preserve the flattened element order.
OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
  // Eliminate splat constant transpose ops.
  if (auto splat =
          llvm::dyn_cast_if_present<SplatElementsAttr>(Val: adaptor.getVector()))
    return splat.reshape(newType: getResultVectorType());

  // Eliminate poison transpose ops.
  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(Val: adaptor.getVector()))
    return ub::PoisonAttr::get(ctx: getContext());

  // Eliminate identity transposes, and more generally any transposes that
  // preserves the shape without permuting elements.
  //
  // Examples of what to fold:
  // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
  // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
  // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
  //
  // Example of what NOT to fold:
  // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
  //
  // Both conditions are needed: equal types alone do not imply the values
  // are untouched (see the NOT-to-fold example above).
  if (getSourceVectorType() == getResultVectorType() &&
      isOrderPreserving(transpose: *this))
    return getVector();

  return {};
}
6372
6373LogicalResult vector::TransposeOp::verify() {
6374 VectorType vectorType = getSourceVectorType();
6375 VectorType resultType = getResultVectorType();
6376 int64_t rank = resultType.getRank();
6377 if (vectorType.getRank() != rank)
6378 return emitOpError(message: "vector result rank mismatch: ") << rank;
6379 // Verify transposition array.
6380 ArrayRef<int64_t> perm = getPermutation();
6381 int64_t size = perm.size();
6382 if (rank != size)
6383 return emitOpError(message: "transposition length mismatch: ") << size;
6384 SmallVector<bool, 8> seen(rank, false);
6385 for (const auto &ta : llvm::enumerate(First&: perm)) {
6386 if (ta.value() < 0 || ta.value() >= rank)
6387 return emitOpError(message: "transposition index out of range: ") << ta.value();
6388 if (seen[ta.value()])
6389 return emitOpError(message: "duplicate position index: ") << ta.value();
6390 seen[ta.value()] = true;
6391 if (resultType.getDimSize(idx: ta.index()) != vectorType.getDimSize(idx: ta.value()))
6392 return emitOpError(message: "dimension size mismatch at: ") << ta.value();
6393 }
6394 return success();
6395}
6396
6397std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
6398 return llvm::to_vector<4>(Range: getResultVectorType().getShape());
6399}
6400
6401namespace {
6402
6403// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
6404class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
6405public:
6406 using OpRewritePattern::OpRewritePattern;
6407
6408 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
6409 PatternRewriter &rewriter) const override {
6410 // Composes two permutations: result[i] = permutation1[permutation2[i]].
6411 auto composePermutations = [](ArrayRef<int64_t> permutation1,
6412 ArrayRef<int64_t> permutation2) {
6413 SmallVector<int64_t, 4> result;
6414 for (auto index : permutation2)
6415 result.push_back(Elt: permutation1[index]);
6416 return result;
6417 };
6418
6419 // Return if the input of 'transposeOp' is not defined by another transpose.
6420 vector::TransposeOp parentTransposeOp =
6421 transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
6422 if (!parentTransposeOp)
6423 return failure();
6424
6425 SmallVector<int64_t, 4> permutation = composePermutations(
6426 parentTransposeOp.getPermutation(), transposeOp.getPermutation());
6427 // Replace 'transposeOp' with a new transpose operation.
6428 rewriter.replaceOpWithNewOp<vector::TransposeOp>(
6429 op: transposeOp, args: transposeOp.getResult().getType(),
6430 args: parentTransposeOp.getVector(), args&: permutation);
6431 return success();
6432 }
6433};
6434
6435// Folds transpose(splat x : src_type) : res_type into splat x : res_type.
6436class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
6437public:
6438 using OpRewritePattern::OpRewritePattern;
6439
6440 LogicalResult matchAndRewrite(TransposeOp transposeOp,
6441 PatternRewriter &rewriter) const override {
6442 auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
6443 if (!splatOp)
6444 return failure();
6445
6446 rewriter.replaceOpWithNewOp<vector::SplatOp>(
6447 op: transposeOp, args: transposeOp.getResultVectorType(), args: splatOp.getInput());
6448 return success();
6449 }
6450};
6451
/// Folds transpose(create_mask) into a new transposed create_mask.
class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp transpOp,
                                PatternRewriter &rewriter) const override {
    Value transposeSrc = transpOp.getVector();
    // Handles both the dynamic (create_mask) and static (constant_mask)
    // mask-producing ops.
    auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    // Get the transpose permutation and apply it to the vector.create_mask or
    // vector.constant_mask operands.
    ArrayRef<int64_t> permutation = transpOp.getPermutation();

    if (createMaskOp) {
      // Permute the dynamic mask sizes exactly the way the transpose permutes
      // the vector dims, then rebuild the mask with the transposed type.
      auto maskOperands = createMaskOp.getOperands();
      SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
      applyPermutationToVector(inVec&: newOperands, permutation);

      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          op: transpOp, args: transpOp.getResultVectorType(), args&: newOperands);
      return success();
    }

    // ConstantMaskOp case.
    auto maskDimSizes = constantMaskOp.getMaskDimSizes();
    auto newMaskDimSizes = applyPermutation(input: maskDimSizes, permutation);

    rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
        op: transpOp, args: transpOp.getResultVectorType(), args&: newMaskDimSizes);
    return success();
  }
};
6488
6489/// Folds transpose(shape_cast) into a new shape_cast.
6490class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
6491public:
6492 using OpRewritePattern::OpRewritePattern;
6493
6494 LogicalResult matchAndRewrite(TransposeOp transposeOp,
6495 PatternRewriter &rewriter) const override {
6496 auto shapeCastOp =
6497 transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
6498 if (!shapeCastOp)
6499 return failure();
6500 if (!isOrderPreserving(transpose: transposeOp))
6501 return failure();
6502
6503 VectorType resultType = transposeOp.getType();
6504
6505 // We don't need to check isValidShapeCast at this point, because it is
6506 // guaranteed that merging the transpose into the the shape_cast is a valid
6507 // shape_cast, because the transpose just inserts/removes ones.
6508
6509 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op: transposeOp, args&: resultType,
6510 args: shapeCastOp.getSource());
6511 return success();
6512 }
6513};
6514
/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
/// 'order preserving', where 'order preserving' means the flattened
/// inputs and outputs of the transpose have identical (numerical) values.
///
/// Example:
/// ```
///  %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
///  %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
///                                                 to vector<8x1xi32>
/// ```
/// can be rewritten as the equivalent
/// ```
///  %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
/// ```
/// The algorithm works by partitioning dimensions into groups that can be
/// locally permuted while preserving order, and checks that the transpose
/// only permutes within these groups.
///
/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
/// broadcasting from 1x1x4x1x1x7.
///                    ^^^ ^ ^^^ ^
///          groups:      0 1    2 3
/// Order preserving permutations for this example are ones that only permute
/// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
                                PatternRewriter &rewriter) const override {

    vector::BroadcastOp broadcast =
        transpose.getVector().getDefiningOp<vector::BroadcastOp>();
    if (!broadcast) {
      return rewriter.notifyMatchFailure(arg&: transpose,
                                         msg: "not preceded by a broadcast");
    }

    auto inputType = dyn_cast<VectorType>(Val: broadcast.getSourceType());
    VectorType outputType = transpose.getResultVectorType();

    // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
    bool inputIsScalar = !inputType;
    if (inputIsScalar) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op: transpose, args&: outputType,
                                                       args: broadcast.getSource());
      return success();
    }

    ArrayRef<int64_t> permutation = transpose.getPermutation();
    ArrayRef<int64_t> inputShape = inputType.getShape();
    int64_t inputRank = inputType.getRank();
    int64_t outputRank = transpose.getType().getRank();
    // The broadcast prepends `deltaRank` dims, so input dim i corresponds to
    // output dim (i + deltaRank).
    int64_t deltaRank = outputRank - inputRank;

    int low = 0;
    for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
      // A group ends just before a non-1 dim, and just after one (each non-1
      // dim is a group of its own; runs of 1s form a single group).
      bool notOne = inputShape[inputIndex] != 1;
      bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
      bool groupEndFound = notOne || prevNotOne;
      if (groupEndFound) {
        // [low, high) is the current group, in output-dim indices.
        int high = inputIndex + deltaRank;
        // Return failure if not all permutation destinations for indices in
        // [low, high) are in [low, high), i.e. the permutation is not local to
        // the group.
        for (int i = low; i < high; ++i) {
          if (permutation[i] < low || permutation[i] >= high) {
            return rewriter.notifyMatchFailure(
                arg&: transpose, msg: "permutation not local to group");
          }
        }
        low = high;
      }
    }

    // We don't need to check the final group [low, outputRank) because if it is
    // not locally bound, there must be a preceding group that already failed
    // the check (impossible to have just 1 non-locally bound group).

    // The preceding logic also ensures that at this point, the output of the
    // transpose is definitely broadcastable from the input shape, assert so:
    assert(vector::isBroadcastableTo(inputType, outputType) ==
               vector::BroadcastableToResult::Success &&
           "not broadcastable directly to transpose output");

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op: transpose, args&: outputType,
                                                     args: broadcast.getSource());

    return success();
  }
};
6609
6610} // namespace
6611
6612void vector::TransposeOp::getCanonicalizationPatterns(
6613 RewritePatternSet &results, MLIRContext *context) {
6614 results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
6615 FoldTransposeSplat, FoldTransposeBroadcast>(arg&: context);
6616}
6617
6618//===----------------------------------------------------------------------===//
6619// ConstantMaskOp
6620//===----------------------------------------------------------------------===//
6621
6622void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
6623 VectorType type, ConstantMaskKind kind) {
6624 assert(kind == ConstantMaskKind::AllTrue ||
6625 kind == ConstantMaskKind::AllFalse);
6626 build(odsBuilder&: builder, odsState&: result, resultType0: type,
6627 mask_dim_sizes: kind == ConstantMaskKind::AllTrue
6628 ? type.getShape()
6629 : SmallVector<int64_t>(type.getRank(), 0));
6630}
6631
LogicalResult ConstantMaskOp::verify() {
  auto resultType = llvm::cast<VectorType>(Val: getResult().getType());
  // Check the corner case of 0-D vectors first.
  if (resultType.getRank() == 0) {
    if (getMaskDimSizes().size() != 1)
      return emitError(message: "array attr must have length 1 for 0-D vectors");
    auto dim = getMaskDimSizes()[0];
    // For 0-D the single size acts as a boolean: 0 = all-false, 1 = all-true.
    if (dim != 0 && dim != 1)
      return emitError(message: "mask dim size must be either 0 or 1 for 0-D vectors");
    return success();
  }

  // Verify that array attr size matches the rank of the vector result.
  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
    return emitOpError(
        message: "must specify array attr of size equal vector result rank");
  // Verify that each array attr element is in bounds of corresponding vector
  // result dimension size.
  auto resultShape = resultType.getShape();
  auto resultScalableDims = resultType.getScalableDims();
  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
  for (const auto [index, maskDimSize] : llvm::enumerate(First&: maskDimSizes)) {
    if (maskDimSize < 0 || maskDimSize > resultShape[index])
      return emitOpError(
          message: "array attr of size out of bounds of vector result dimension size");
    // Scalable dims cannot be partially set: only 0 (none) or the full
    // static size (all) is representable.
    if (resultScalableDims[index] && maskDimSize != 0 &&
        maskDimSize != resultShape[index])
      return emitOpError(
          message: "only supports 'none set' or 'all set' scalable dimensions");
  }
  // Verify that if one mask dim size is zero, they all should be zero (because
  // the mask region is a conjunction of each mask dimension interval).
  bool anyZeros = llvm::is_contained(Range&: maskDimSizes, Element: 0);
  bool allZeros = llvm::all_of(Range&: maskDimSizes, P: [](int64_t s) { return s == 0; });
  if (anyZeros && !allZeros)
    return emitOpError(message: "expected all mask dim sizes to be zeros, "
                       "as a result of conjunction with zero mask dim");
  return success();
}
6671
6672bool ConstantMaskOp::isAllOnesMask() {
6673 auto resultType = getVectorType();
6674 // Check the corner case of 0-D vectors first.
6675 if (resultType.getRank() == 0) {
6676 assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
6677 return getMaskDimSizes()[0] == 1;
6678 }
6679 for (const auto [resultSize, maskDimSize] :
6680 llvm::zip_equal(t: resultType.getShape(), u: getMaskDimSizes())) {
6681 if (maskDimSize < resultSize)
6682 return false;
6683 }
6684 return true;
6685}
6686
6687OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
6688 ArrayRef<int64_t> bounds = getMaskDimSizes();
6689 ArrayRef<int64_t> vectorSizes = getVectorType().getShape();
6690
6691 auto createBoolSplat = [&](bool x) {
6692 return SplatElementsAttr::get(type: getVectorType(),
6693 values: BoolAttr::get(context: getContext(), value: x));
6694 };
6695
6696 // Check the corner case of 0-D vectors first.
6697 if (vectorSizes.empty()) {
6698 assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
6699 return createBoolSplat(bounds[0] == 1);
6700 }
6701 // Fold vector.constant_mask to splat if possible.
6702 if (bounds == vectorSizes)
6703 return createBoolSplat(true);
6704 if (llvm::all_of(Range&: bounds, P: [](int64_t x) { return x == 0; }))
6705 return createBoolSplat(false);
6706 return OpFoldResult();
6707}
6708
6709//===----------------------------------------------------------------------===//
6710// CreateMaskOp
6711//===----------------------------------------------------------------------===//
6712
6713void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
6714 VectorType type,
6715 ArrayRef<OpFoldResult> mixedOperands) {
6716 SmallVector<Value> operands =
6717 getValueOrCreateConstantIndexOp(b&: builder, loc: result.location, valueOrAttrVec: mixedOperands);
6718 build(odsBuilder&: builder, odsState&: result, resultType0: type, operands);
6719}
6720
6721LogicalResult CreateMaskOp::verify() {
6722 auto vectorType = llvm::cast<VectorType>(Val: getResult().getType());
6723 // Verify that an operand was specified for each result vector each dimension.
6724 if (vectorType.getRank() == 0) {
6725 if (getNumOperands() != 1)
6726 return emitOpError(
6727 message: "must specify exactly one operand for 0-D create_mask");
6728 } else if (getNumOperands() !=
6729 llvm::cast<VectorType>(Val: getResult().getType()).getRank()) {
6730 return emitOpError(
6731 message: "must specify an operand for each result vector dimension");
6732 }
6733 return success();
6734}
6735
6736namespace {
6737
/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
///
/// Ex 1:
///   %c2 = arith.constant 2 : index
///   %c3 = arith.constant 3 : index
///   %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
/// Becomes:
///   vector.constant_mask [3, 2] : vector<4x3xi1>
///
/// Ex 2:
///   %c_neg_1 = arith.constant -1 : index
///   %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
/// becomes:
///   vector.constant_mask [0] : vector<[8]xi1>
///
/// Ex 3:
///   %c8 = arith.constant 8 : index
///   %c16 = arith.constant 16 : index
///   %0 = vector.vscale
///   %1 = arith.muli %0, %c16 : index
///   %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
/// becomes:
///   %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    VectorType maskType = createMaskOp.getVectorType();
    ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
    ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();

    // Special case: Rank zero shape. Treat it as a single non-scalable
    // unit dim so the loop below handles it uniformly.
    constexpr std::array<int64_t, 1> rankZeroShape{1};
    constexpr std::array<bool, 1> rankZeroScalableDims{false};
    if (maskType.getRank() == 0) {
      maskTypeDimSizes = rankZeroShape;
      maskTypeDimScalableFlags = rankZeroScalableDims;
    }

    // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
    // collect the `constantDims` (for the ConstantMaskOp).
    SmallVector<int64_t, 4> constantDims;
    for (auto [i, dimSize] : llvm::enumerate(First: createMaskOp.getOperands())) {
      if (auto intSize = getConstantIntValue(ofr: dimSize)) {
        // Constant value.
        // If the mask dim is non-scalable this can be any value.
        // If the mask dim is scalable only zero (all-false) is supported;
        // a non-negative fixed constant is rejected here (negative values
        // are clamped to zero, i.e. all-false, further below).
        if (maskTypeDimScalableFlags[i] && intSize >= 0)
          return failure();
        constantDims.push_back(Elt: *intSize);
      } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(value: dimSize)) {
        // Constant vscale multiple (e.g. 4 x vscale).
        // Must be all-true to fold to a ConstantMask.
        if (vscaleMultiplier < maskTypeDimSizes[i])
          return failure();
        constantDims.push_back(Elt: *vscaleMultiplier);
      } else {
        // Truly dynamic size: cannot fold.
        return failure();
      }
    }

    // Clamp values to constant_mask bounds.
    for (auto [value, maskDimSize] : llvm::zip(t&: constantDims, u&: maskTypeDimSizes))
      value = std::clamp<int64_t>(val: value, lo: 0, hi: maskDimSize);

    // If one of dim sizes is zero, set all dims to zero (the mask region is
    // the conjunction of the per-dim intervals, so it is empty).
    if (llvm::is_contained(Range&: constantDims, Element: 0))
      constantDims.assign(NumElts: constantDims.size(), Elt: 0);

    // Replace 'createMaskOp' with ConstantMaskOp.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(op: createMaskOp, args&: maskType,
                                                args&: constantDims);
    return success();
  }
};
6815
6816} // namespace
6817
6818void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
6819 MLIRContext *context) {
6820 results.add<CreateMaskFolder>(arg&: context);
6821}
6822
6823//===----------------------------------------------------------------------===//
6824// MaskOp
6825//===----------------------------------------------------------------------===//
6826
/// Builds a MaskOp with no results and no passthru; the mask region content
/// is produced by the mandatory `maskRegionBuilder` callback.
void MaskOp::build(
    OpBuilder &builder, OperationState &result, Value mask,
    Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  assert(maskRegionBuilder &&
         "builder callback for 'maskRegion' must be present");

  result.addOperands(newOperands: mask);
  // Restore the caller's insertion point after populating the region block.
  OpBuilder::InsertionGuard guard(builder);
  Region *maskRegion = result.addRegion();
  builder.createBlock(parent: maskRegion);
  maskRegionBuilder(builder, maskableOp);
}
6840
/// Builds a MaskOp with results but no passthru value; delegates to the
/// passthru overload with a null passthru.
void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  build(odsBuilder&: builder, odsState&: result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
        maskRegion: maskRegionBuilder);
}
6848
/// Builds a MaskOp with results and an optional passthru value.
void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Value passthru, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  // Delegate first so the mask operand and region are added before the
  // optional passthru operand (operand order matters).
  build(builder, result, mask, maskableOp, maskRegionBuilder);
  if (passthru)
    result.addOperands(newOperands: passthru);
  result.addTypes(newTypes&: resultTypes);
}
6858
/// Parses the custom form:
///   vector.mask %mask [, %passthru] { ...region... } [attr-dict]
///       : mask-type [-> result-types]
ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create the op region.
  result.regions.reserve(N: 1);
  Region &maskRegion = *result.addRegion();

  auto &builder = parser.getBuilder();

  // Parse all the operands.
  OpAsmParser::UnresolvedOperand mask;
  if (parser.parseOperand(result&: mask))
    return failure();

  // Optional passthru operand, introduced by a comma after the mask.
  OpAsmParser::UnresolvedOperand passthru;
  ParseResult parsePassthru = parser.parseOptionalComma();
  if (parsePassthru.succeeded() && parser.parseOperand(result&: passthru))
    return failure();

  // Parse op region.
  if (parser.parseRegion(region&: maskRegion, /*arguments=*/{}, /*argTypes=*/enableNameShadowing: {}))
    return failure();

  // Add the implicit terminator if the parsed region did not spell it out.
  MaskOp::ensureTerminator(region&: maskRegion, builder, loc: result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result&: result.attributes))
    return failure();

  // Parse all the types.
  Type maskType;
  if (parser.parseColonType(result&: maskType))
    return failure();

  SmallVector<Type> resultTypes;
  if (parser.parseOptionalArrowTypeList(result&: resultTypes))
    return failure();
  result.types.append(RHS: resultTypes);

  // Resolve operands.
  if (parser.resolveOperand(operand: mask, type: maskType, result&: result.operands))
    return failure();

  if (parsePassthru.succeeded()) {
    // A passthru only makes sense when the op produces a result; its type is
    // the first result type.
    if (resultTypes.empty())
      return parser.emitError(
          loc: parser.getNameLoc(),
          message: "expects a result if passthru operand is provided");

    if (parser.resolveOperand(operand: passthru, type: resultTypes[0], result&: result.operands))
      return failure();
  }

  return success();
}
6913
6914void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
6915 p << " " << getMask();
6916 if (getPassthru())
6917 p << ", " << getPassthru();
6918
6919 // Print single masked operation and skip terminator.
6920 p << " { ";
6921 Block *singleBlock = &getMaskRegion().getBlocks().front();
6922 if (singleBlock && !singleBlock->getOperations().empty())
6923 p.printCustomOrGenericOp(op: &singleBlock->front());
6924 p << " }";
6925
6926 p.printOptionalAttrDict(attrs: getOperation()->getAttrs());
6927
6928 p << " : " << getMask().getType();
6929 if (getNumResults() > 0)
6930 p << " -> " << getResultTypes();
6931}
6932
/// Ensures the mask region ends with a 'vector.yield' terminator:
///  1. empty region/block: create a default (empty) yield;
///  2. an explicit yield is already present: nothing to do;
///  3. exactly one masked op and no yield: yield that op's results;
///  otherwise create a default yield (the verifier will then reject the op).
void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
  // 1. For an empty `vector.mask`, create a default terminator.
  if (region.empty() || region.front().empty()) {
    OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
        MaskOp>::ensureTerminator(region, builder, loc);
    return;
  }

  // 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
  Block &block = region.front();
  if (isa<vector::YieldOp>(Val: block.back()))
    return;

  // 3. For a non-empty `vector.mask` without an explicit terminator:

  // Create default terminator if the number of masked operations is not
  // one. This case will trigger a verification failure.
  if (block.getOperations().size() != 1) {
    OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
        MaskOp>::ensureTerminator(region, builder, loc);
    return;
  }

  // Create a terminator that yields the results from the masked operation.
  // A fresh OpBuilder is used because `builder` is only a Builder (no
  // insertion point support).
  OpBuilder opBuilder(builder.getContext());
  Operation *maskedOp = &block.front();
  opBuilder.setInsertionPointToEnd(&block);
  opBuilder.create<vector::YieldOp>(location: loc, args: maskedOp->getResults());
}
6962
/// Verifies a 'vector.mask': the region must hold at most one maskable
/// operation plus a 'vector.yield' terminator whose operands/types mirror the
/// op's results; the mask type must match the maskable op's expected mask
/// type; a passthru is only allowed when the maskable op supports it.
LogicalResult MaskOp::verify() {
  // Structural checks.
  Block &block = getMaskRegion().getBlocks().front();
  if (block.getOperations().empty())
    return emitOpError(message: "expects a terminator within the mask region");

  // At most one masked operation plus the terminator.
  unsigned numMaskRegionOps = block.getOperations().size();
  if (numMaskRegionOps > 2)
    return emitOpError(message: "expects only one operation to mask");

  // Terminator checks.
  auto terminator = dyn_cast<vector::YieldOp>(Val&: block.back());
  if (!terminator)
    return emitOpError(message: "expects a terminator within the mask region");

  if (terminator->getNumOperands() != getNumResults())
    return emitOpError(
        message: "expects number of results to match mask region yielded values");

  // Empty vector.mask. Nothing else to check.
  // NOTE: the passthru checks below are skipped in this case.
  if (numMaskRegionOps == 1)
    return success();

  auto maskableOp = dyn_cast<MaskableOpInterface>(Val&: block.front());
  if (!maskableOp)
    return emitOpError(message: "expects a MaskableOpInterface within the mask region");

  // Result checks.
  if (maskableOp->getNumResults() != getNumResults())
    return emitOpError(message: "expects number of results to match maskable operation "
                       "number of results");

  // The yield must forward exactly the maskable op's results, in order.
  if (!llvm::equal(LRange: maskableOp->getResults(), RRange: terminator.getOperands()))
    return emitOpError(message: "expects all the results from the MaskableOpInterface "
                       "to match all the values returned by the terminator");

  if (!llvm::equal(LRange: maskableOp->getResultTypes(), RRange: getResultTypes()))
    return emitOpError(
        message: "expects result type to match maskable operation result type");

  if (llvm::count_if(Range: maskableOp->getResultTypes(),
                     P: [](Type t) { return llvm::isa<VectorType>(Val: t); }) > 1)
    return emitOpError(message: "multiple vector results not supported");

  // Mask checks.
  Type expectedMaskType = maskableOp.getExpectedMaskType();
  if (getMask().getType() != expectedMaskType)
    return emitOpError(message: "expects a ")
           << expectedMaskType << " mask for the maskable operation";

  // Passthru checks.
  Value passthru = getPassthru();
  if (passthru) {
    if (!maskableOp.supportsPassthru())
      return emitOpError(
          message: "doesn't expect a passthru argument for this maskable operation");

    if (maskableOp->getNumResults() != 1)
      return emitOpError(message: "expects result when passthru argument is provided");

    if (passthru.getType() != maskableOp->getResultTypes()[0])
      return emitOpError(message: "expects passthru type to match result type");
  }

  return success();
}
7029
7030/// Folds empty `vector.mask` with no passthru operand and with or without
7031/// return values. For example:
7032///
7033/// %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
7034/// vector<8xi1> -> vector<8xf32>
7035/// %1 = user_op %0 : vector<8xf32>
7036///
7037/// becomes:
7038///
7039/// %0 = user_op %a : vector<8xf32>
7040///
7041/// Empty `vector.mask` with passthru operand are handled by the canonicalizer
7042/// as it requires creating new operations.
7043
7044static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
7045 SmallVectorImpl<OpFoldResult> &results) {
7046 if (!maskOp.isEmpty() || maskOp.hasPassthru())
7047 return failure();
7048
7049 Block *block = maskOp.getMaskBlock();
7050 auto terminator = cast<vector::YieldOp>(Val&: block->front());
7051 if (terminator.getNumOperands() == 0) {
7052 // `vector.mask` has no results, just remove the `vector.mask`.
7053 return success();
7054 }
7055
7056 // `vector.mask` has results, propagate the results.
7057 llvm::append_range(C&: results, R: terminator.getOperands());
7058 return success();
7059}
7060
7061LogicalResult MaskOp::fold(FoldAdaptor adaptor,
7062 SmallVectorImpl<OpFoldResult> &results) {
7063 if (succeeded(Result: foldEmptyMaskOp(maskOp: *this, adaptor, results)))
7064 return success();
7065
7066 MaskFormat maskFormat = getMaskFormat(mask: getMask());
7067 if (maskFormat != MaskFormat::AllTrue)
7068 return failure();
7069
7070 // Move maskable operation outside of the `vector.mask` region.
7071 Operation *maskableOp = getMaskableOp();
7072 maskableOp->dropAllUses();
7073 maskableOp->moveBefore(existingOp: getOperation());
7074
7075 llvm::append_range(C&: results, R: maskableOp->getResults());
7076 return success();
7077}
7078
/// Canonicalize empty `vector.mask` operations that can't be handled in
/// `MaskOp::fold` as they require creating new operations.
7081///
7082/// Example 1: Empty `vector.mask` with passthru operand.
7083///
7084/// %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
7085/// vector<8xi1> -> vector<8xf32>
7086///
7087/// becomes:
7088///
7089/// %0 = arith.select %mask, %a, %passthru : vector<8xf32>
7090///
// Rewrites an empty `vector.mask` that carries a passthru operand into an
// `arith.select %mask, %yielded, %passthru`, which `MaskOp::fold` cannot do
// because folds may not create new operations.
class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MaskOp maskOp,
                                PatternRewriter &rewriter) const override {
    // Only empty masks (region holds just the terminator)...
    if (!maskOp.isEmpty())
      return failure();

    // ...with a passthru; the no-passthru case is handled by the folder.
    if (!maskOp.hasPassthru())
      return failure();

    // The single op in an empty mask region is the vector.yield terminator.
    Block *block = maskOp.getMaskBlock();
    auto terminator = cast<vector::YieldOp>(Val&: block->front());
    assert(terminator.getNumOperands() == 1 &&
           "expected one result when passthru is provided");

    rewriter.replaceOpWithNewOp<arith::SelectOp>(
        op: maskOp, args: maskOp.getResultTypes(), args: maskOp.getMask(),
        args: terminator.getOperand(i: 0), args: maskOp.getPassthru());

    return success();
  }
};
7114
/// Registers canonicalization patterns for 'vector.mask' (currently only the
/// empty-mask-with-passthru rewrite to arith.select).
void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<CanonializeEmptyMaskOp>(arg&: context);
}
7119
7120// MaskingOpInterface definitions.
7121
/// Returns the operation masked by this 'vector.mask', or nullptr for an
/// empty mask whose region contains only the terminator.
Operation *MaskOp::getMaskableOp() {
  Block *block = getMaskBlock();
  // Fewer than two ops means there is only the terminator: no maskable op.
  if (block->getOperations().size() < 2)
    return nullptr;

  return &block->front();
}
7130
7131/// Returns true if 'vector.mask' has a passthru value.
7132bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
7133
7134//===----------------------------------------------------------------------===//
7135// ScanOp
7136//===----------------------------------------------------------------------===//
7137
/// Verifies a 'vector.scan': the reduction dimension must be in range, the
/// initial value must have the source shape with the reduction dimension
/// dropped, and the combining kind must be supported for the element type.
LogicalResult ScanOp::verify() {
  VectorType srcType = getSourceType();
  VectorType initialType = getInitialValueType();
  // Check reduction dimension < rank.
  int64_t srcRank = srcType.getRank();
  int64_t reductionDim = getReductionDim();
  if (reductionDim >= srcRank)
    return emitOpError(message: "reduction dimension ")
           << reductionDim << " has to be less than " << srcRank;

  // Check that rank(initial_value) = rank(src) - 1.
  int64_t initialValueRank = initialType.getRank();
  if (initialValueRank != srcRank - 1)
    return emitOpError(message: "initial value rank ")
           << initialValueRank << " has to be equal to " << srcRank - 1;

  // Check shapes of initial value and src: the expected shape is the source
  // shape with the reduction dimension removed.
  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int64_t> initialValueShapes = initialType.getShape();
  SmallVector<int64_t> expectedShape;
  for (int i = 0; i < srcRank; i++) {
    if (i != reductionDim)
      expectedShape.push_back(Elt: srcShape[i]);
  }
  if (!llvm::equal(LRange&: initialValueShapes, RRange&: expectedShape)) {
    return emitOpError(message: "incompatible input/initial value shapes");
  }

  // Verify supported reduction kind.
  Type eltType = getDestType().getElementType();
  if (!isSupportedCombiningKind(combiningKind: getKind(), elementType: eltType))
    return emitOpError(message: "unsupported reduction type ")
           << eltType << " for kind '" << stringifyCombiningKind(val: getKind())
           << "'";

  return success();
}
7175
/// Collects the vector-to-vector canonicalization patterns defined in this
/// file (mask/load/store/gather/scatter/slice/transpose folders) into
/// `patterns` with the given benefit.
void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
           ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
           StridedSliceConstantMaskFolder, TransposeFolder>(
          arg: patterns.getContext(), args&: benefit);
}
7184
7185//===----------------------------------------------------------------------===//
7186// SplatOp
7187//===----------------------------------------------------------------------===//
7188
/// Folds a 'vector.splat' of a constant scalar into a SplatElementsAttr
/// (a dense constant with a single repeated value).
OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
  auto constOperand = adaptor.getInput();
  // Only integer/float scalar constants are foldable.
  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(Val: constOperand))
    return {};

  // SplatElementsAttr::get treats single value for second arg as being a splat.
  return SplatElementsAttr::get(type: getType(), list: {constOperand});
}
7197
/// Integer range inference: a splat's result range equals the range of its
/// single scalar input.
void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}
7202
/// Materializes a single step of an arithmetic reduction combining `v1` into
/// `acc` with the given CombiningKind, dispatching to the matching arith op
/// (integer vs. float variants chosen from the element types; `fastmath` is
/// only forwarded to the float ops). When `mask` is non-null, masked-out
/// lanes keep the accumulator value via a trailing arith.select.
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                       CombiningKind kind, Value v1, Value acc,
                                       arith::FastMathFlagsAttr fastmath,
                                       Value mask) {
  Type t1 = getElementTypeOrSelf(type: v1.getType());
  Type tAcc = getElementTypeOrSelf(type: acc.getType());
  Value result;

  switch (kind) {
  case CombiningKind::ADD:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::AddIOp>(location: loc, args&: v1, args&: acc);
    else if (llvm::isa<FloatType>(Val: t1) && llvm::isa<FloatType>(Val: tAcc))
      result = b.createOrFold<arith::AddFOp>(location: loc, args&: v1, args&: acc, args&: fastmath);
    else
      llvm_unreachable("invalid value types for ADD reduction");
    break;
  case CombiningKind::AND:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::AndIOp>(location: loc, args&: v1, args&: acc);
    break;
  case CombiningKind::MAXNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaxNumFOp>(location: loc, args&: v1, args&: acc, args&: fastmath);
    break;
  case CombiningKind::MAXIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaximumFOp>(location: loc, args&: v1, args&: acc, args&: fastmath);
    break;
  case CombiningKind::MINNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinNumFOp>(location: loc, args&: v1, args&: acc, args&: fastmath);
    break;
  case CombiningKind::MINIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinimumFOp>(location: loc, args&: v1, args&: acc, args&: fastmath);
    break;
  case CombiningKind::MAXSI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MaxSIOp>(location: loc, args&: v1, args&: acc);
    break;
  case CombiningKind::MINSI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MinSIOp>(location: loc, args&: v1, args&: acc);
    break;
  case CombiningKind::MAXUI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MaxUIOp>(location: loc, args&: v1, args&: acc);
    break;
  case CombiningKind::MINUI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MinUIOp>(location: loc, args&: v1, args&: acc);
    break;
  case CombiningKind::MUL:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::MulIOp>(location: loc, args&: v1, args&: acc);
    else if (llvm::isa<FloatType>(Val: t1) && llvm::isa<FloatType>(Val: tAcc))
      result = b.createOrFold<arith::MulFOp>(location: loc, args&: v1, args&: acc, args&: fastmath);
    else
      llvm_unreachable("invalid value types for MUL reduction");
    break;
  case CombiningKind::OR:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::OrIOp>(location: loc, args&: v1, args&: acc);
    break;
  case CombiningKind::XOR:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::XOrIOp>(location: loc, args&: v1, args&: acc);
    break;
  };

  // Every enum case assigns `result`; a null result means an unhandled kind.
  assert(result && "unknown CombiningKind");
  return selectPassthru(builder&: b, mask, newValue: result, passthru: acc);
}
7281
7282//===----------------------------------------------------------------------===//
7283// Vector Masking Utilities
7284//===----------------------------------------------------------------------===//
7285
7286/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
7287/// as masked operation.
/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
/// as masked operation. The op is moved (not cloned) from its current block
/// into the builder's insertion block.
void mlir::vector::createMaskOpRegion(OpBuilder &builder,
                                      Operation *maskableOp) {
  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
  Block *insBlock = builder.getInsertionBlock();
  // Create a block and move the op to that block. splice() transfers the op
  // in place, preserving its operands and uses.
  insBlock->getOperations().splice(
      where: insBlock->begin(), L2&: maskableOp->getBlock()->getOperations(), N: maskableOp);
  // Terminate the region by yielding the masked op's results.
  builder.create<YieldOp>(location: maskableOp->getLoc(), args: maskableOp->getResults());
}
7297
7298/// Creates a vector.mask operation around a maskable operation. Returns the
7299/// vector.mask operation if the mask provided is valid. Otherwise, returns
7300/// the maskable operation itself.
/// Creates a vector.mask operation around a maskable operation. Returns the
/// vector.mask operation if the mask provided is valid (non-null). Otherwise,
/// returns the maskable operation itself unchanged. The passthru operand is
/// only forwarded when present.
Operation *mlir::vector::maskOperation(OpBuilder &builder,
                                       Operation *maskableOp, Value mask,
                                       Value passthru) {
  // No mask: nothing to wrap.
  if (!mask)
    return maskableOp;
  if (passthru)
    return builder.create<MaskOp>(location: maskableOp->getLoc(),
                                  args: maskableOp->getResultTypes(), args&: mask, args&: passthru,
                                  args&: maskableOp, args&: createMaskOpRegion);
  return builder.create<MaskOp>(location: maskableOp->getLoc(),
                                args: maskableOp->getResultTypes(), args&: mask, args&: maskableOp,
                                args&: createMaskOpRegion);
}
7314
7315/// Creates a vector select operation that picks values from `newValue` or
7316/// `passthru` for each result vector lane based on `mask`. This utility is used
7317/// to propagate the pass-thru value of vector.mask or for cases where only the
7318/// pass-thru value propagation is needed. VP intrinsics do not support
7319/// pass-thru values and every mask-out lane is set to poison. LLVM backends are
7320/// usually able to match op + select patterns and fold them into a native
7321/// target instructions.
/// Returns `newValue` unchanged when `mask` is null; otherwise builds an
/// arith.select that keeps `newValue` in active lanes and `passthru` in
/// masked-out lanes. See the comment above for when this is used.
Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
                                   Value newValue, Value passthru) {
  if (!mask)
    return newValue;

  return builder.create<arith::SelectOp>(location: newValue.getLoc(), args: newValue.getType(),
                                         args&: mask, args&: newValue, args&: passthru);
}
7330
7331//===----------------------------------------------------------------------===//
7332// TableGen'd op method definitions
7333//===----------------------------------------------------------------------===//
7334
7335#define GET_ATTRDEF_CLASSES
7336#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
7337
7338#define GET_OP_CLASSES
7339#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
7340

source code of mlir/lib/Dialect/Vector/IR/VectorOps.cpp