//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements convenience types for working with super-vectorization
// operations, in particular super-vector loads and stores.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/IR/VectorOps.h"

#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"

#include <cassert>
#include <cstdint>
#include <numeric>

#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
// Pull in all enum type and utility function definitions.
#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"

using namespace mlir;
using namespace mlir::vector;

/// Helper enum to classify mask value.
enum class MaskFormat {
  AllTrue = 0,
  AllFalse = 1,
  Unknown = 2,
};

/// Helper method to classify a mask value. Currently, the method
/// looks "under the hood" of a constant value with dense attributes
/// and a constant mask operation (since the client may be called at
/// various stages during progressive lowering).
static MaskFormat getMaskFormat(Value mask) {
  if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
    // Inspect constant dense values. We count up for bits that
    // are set, count down for bits that are cleared, and bail
    // when a mix is detected.
    if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
      int64_t val = 0;
      for (bool b : denseElts.getValues<bool>())
        if (b && val >= 0)
          val++;
        else if (!b && val <= 0)
          val--;
        else
          return MaskFormat::Unknown;
      if (val > 0)
        return MaskFormat::AllTrue;
      if (val < 0)
        return MaskFormat::AllFalse;
    }
  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
    // Inspect constant mask index. If the index exceeds the
    // dimension size, all bits are set. If the index is zero
    // or less, no bits are set.
    ArrayAttr masks = m.getMaskDimSizes();
    auto shape = m.getType().getShape();
    bool allTrue = true;
    bool allFalse = true;
    for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
      int64_t i = llvm::cast<IntegerAttr>(maskIdx).getInt();
      if (i < dimSize)
        allTrue = false;
      if (i > 0)
        allFalse = false;
    }
    if (allTrue)
      return MaskFormat::AllTrue;
    if (allFalse)
      return MaskFormat::AllFalse;
  } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
    // Finds all-false create_masks. An all-true create_mask requires all
    // dims to be constants, so that'll be folded to a constant_mask, then
    // detected in the constant_mask case.
    auto maskOperands = m.getOperands();
    for (Value operand : maskOperands) {
      if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
        int64_t dimSize =
            llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
        if (dimSize <= 0)
          return MaskFormat::AllFalse;
      }
    }
    return MaskFormat::Unknown;
  }
  return MaskFormat::Unknown;
}
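
// For illustration (not exercised by this file directly): a mask such as
//
//   %m = vector.constant_mask [4] : vector<4xi1>
//
// classifies as AllTrue, a `vector.create_mask` whose operand is a constant 0
// classifies as AllFalse, and anything data-dependent remains Unknown.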

/// Default callback to build a region with a 'vector.yield' terminator with no
/// arguments.
void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
  builder.create<vector::YieldOp>(loc);
}

// Helper for verifying combining kinds in contractions and reductions.
static bool isSupportedCombiningKind(CombiningKind combiningKind,
                                     Type elementType) {
  switch (combiningKind) {
  case CombiningKind::ADD:
  case CombiningKind::MUL:
    return elementType.isIntOrIndexOrFloat();
  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    return elementType.isIntOrIndex();
  case CombiningKind::MINNUMF:
  case CombiningKind::MAXNUMF:
  case CombiningKind::MINIMUMF:
  case CombiningKind::MAXIMUMF:
    return llvm::isa<FloatType>(elementType);
  }
  return false;
}
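
// For example (illustrative): MINIMUMF/MAXIMUMF are only accepted for float
// element types such as f32, AND/OR/XOR only for integer or index element
// types, and ADD/MUL for any int, index, or float element type.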

AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
                                                    VectorType vectorType) {
  int64_t elementVectorRank = 0;
  VectorType elementVectorType =
      llvm::dyn_cast<VectorType>(shapedType.getElementType());
  if (elementVectorType)
    elementVectorRank += elementVectorType.getRank();
  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
  // TODO: replace once we have 0-d vectors.
  if (shapedType.getRank() == 0 &&
      vectorType.getShape() == ArrayRef<int64_t>{1})
    return AffineMap::get(
        /*numDims=*/0, /*numSymbols=*/0,
        getAffineConstantExpr(0, shapedType.getContext()));
  return AffineMap::getMinorIdentityMap(
      shapedType.getRank(), vectorType.getRank() - elementVectorRank,
      shapedType.getContext());
}
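
// For illustration (hypothetical types): for a memref<?x?x?xf32> source and a
// vector<4x8xf32> result, the minor identity map returned above is
// (d0, d1, d2) -> (d1, d2), i.e. the transfer addresses the two innermost
// dimensions of the memref.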

bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
                                     vector::TransferReadOp read) {
  return !defWrite.hasOutOfBoundsDim() && !defWrite.getMask() &&
         !read.getMask() && defWrite.getIndices() == read.getIndices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.getPermutationMap() == read.getPermutationMap();
}

bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
                                     vector::TransferWriteOp priorWrite) {
  return priorWrite.getIndices() == write.getIndices() &&
         priorWrite.getMask() == write.getMask() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.getPermutationMap() == write.getPermutationMap();
}

bool mlir::vector::isDisjointTransferIndices(
    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
    bool testDynamicValueUsingBounds) {
  // For simplicity only look at transfer of same type.
  if (transferA.getVectorType() != transferB.getVectorType())
    return false;
  unsigned rankOffset = transferA.getLeadingShapedRank();
  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
    Value indexA = transferA.getIndices()[i];
    Value indexB = transferB.getIndices()[i];
    std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
    std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);

    if (i < rankOffset) {
      // For leading dimensions, if we can prove that the indices are
      // different, we know we are accessing disjoint slices.
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        if (*cstIndexA != *cstIndexB)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        // First try to see if we can fully compose and simplify the affine
        // expression as a fast track.
        FailureOr<uint64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && *delta != 0)
          return true;

        FailureOr<bool> testEqual =
            ValueBoundsConstraintSet::areEqual(indexA, indexB);
        if (succeeded(testEqual) && !testEqual.value())
          return true;
      }
    } else {
      // For this dimension, we slice a part of the memref, so we need to make
      // sure the intervals accessed don't overlap.
      int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        int64_t distance = std::abs(*cstIndexA - *cstIndexB);
        if (distance >= vectorDim)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        // First try to see if we can fully compose and simplify the affine
        // expression as a fast track.
        FailureOr<int64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && std::abs(*delta) >= vectorDim)
          return true;

        FailureOr<int64_t> computeDelta =
            ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
        if (succeeded(computeDelta)) {
          if (std::abs(computeDelta.value()) >= vectorDim)
            return true;
        }
      }
    }
  }
  return false;
}

bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
                                         VectorTransferOpInterface transferB,
                                         bool testDynamicValueUsingBounds) {
  if (transferA.getSource() != transferB.getSource())
    return false;
  return isDisjointTransferIndices(transferA, transferB,
                                   testDynamicValueUsingBounds);
}
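
// For illustration (hypothetical IR): two reads from the same memref such as
//
//   %0 = vector.transfer_read %mem[%c0], %f0 : memref<16xf32>, vector<4xf32>
//   %1 = vector.transfer_read %mem[%c8], %f0 : memref<16xf32>, vector<4xf32>
//
// are recognized as disjoint because the constant index distance (8) is at
// least the vector size (4) along that dimension.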

// Helper to iterate over n-D vector slice elements. Calculate the next
// `position` in the n-D vector of size `shape`, applying an offset `offsets`.
// Modifies the `position` in place. Returns a failure when `position` becomes
// the end position.
static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
                                      ArrayRef<int64_t> shape,
                                      ArrayRef<int64_t> offsets) {
  for (auto [posInDim, dimSize, offsetInDim] :
       llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
    ++posInDim;
    if (posInDim < dimSize + offsetInDim)
      return success();

    // Carry the overflow to the next loop iteration.
    posInDim = offsetInDim;
  }

  return failure();
}
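
// For illustration: with shape = [2, 3] and offsets = [0, 0], successive calls
// advance `position` through (0,0) -> (0,1) -> (0,2) -> (1,0) -> (1,1) ->
// (1,2); the call after that returns failure().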

/// Returns the integer numbers in `values`. `values` are expected to be
/// constant operations.
SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
  SmallVector<int64_t> ints;
  llvm::transform(values, std::back_inserter(ints), [](Value value) {
    auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
    assert(constOp && "Unexpected non-constant index");
    return constOp.value();
  });
  return ints;
}

/// Returns the integer numbers in `foldResults`. `foldResults` are expected to
/// be constant operations.
SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
  SmallVector<int64_t> ints;
  llvm::transform(
      foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
        assert(foldResult.is<Attribute>() && "Unexpected non-constant index");
        return cast<IntegerAttr>(foldResult.get<Attribute>()).getInt();
      });
  return ints;
}

/// Convert `foldResults` into Values. Integer attributes are converted to
/// constant ops.
SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
                                       ArrayRef<OpFoldResult> foldResults) {
  SmallVector<Value> values;
  llvm::transform(foldResults, std::back_inserter(values),
                  [&](OpFoldResult foldResult) {
                    if (auto attr = foldResult.dyn_cast<Attribute>())
                      return builder
                          .create<arith::ConstantIndexOp>(
                              loc, cast<IntegerAttr>(attr).getInt())
                          .getResult();

                    return foldResult.get<Value>();
                  });
  return values;
}

//===----------------------------------------------------------------------===//
// CombiningKindAttr
//===----------------------------------------------------------------------===//

namespace mlir {
namespace vector {
namespace detail {
struct BitmaskEnumStorage : public AttributeStorage {
  using KeyTy = uint64_t;

  BitmaskEnumStorage(KeyTy val) : value(val) {}

  bool operator==(const KeyTy &key) const { return value == key; }

  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
                                       const KeyTy &key) {
    return new (allocator.allocate<BitmaskEnumStorage>())
        BitmaskEnumStorage(key);
  }

  KeyTy value = 0;
};
} // namespace detail
} // namespace vector
} // namespace mlir

//===----------------------------------------------------------------------===//
// VectorDialect
//===----------------------------------------------------------------------===//

namespace {
/// This class defines the interface for handling inlining with vector dialect
/// operations.
struct VectorInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All vector dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace

void VectorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
      >();

  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
      >();

  addInterfaces<VectorInlinerInterface>();

  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
                            TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
                            YieldOp>();
  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
                            TransferWriteOp>();
  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}

IntegerType vector::getVectorSubscriptType(Builder &builder) {
  return builder.getIntegerType(64);
}

ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
                                         ArrayRef<int64_t> values) {
  return builder.getI64ArrayAttr(values);
}

//===----------------------------------------------------------------------===//
// MultiDimReductionOp
//===----------------------------------------------------------------------===//

void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        OperationState &result, Value source,
                                        Value acc, ArrayRef<bool> reductionMask,
                                        CombiningKind kind) {
  SmallVector<int64_t> reductionDims;
  for (const auto &en : llvm::enumerate(reductionMask))
    if (en.value())
      reductionDims.push_back(en.index());
  build(builder, result, kind, source, acc,
        builder.getI64ArrayAttr(reductionDims));
}

OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
  // Single parallel dim, this is a noop.
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
    return getSource();
  return {};
}

std::optional<SmallVector<int64_t, 4>>
MultiDimReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}

LogicalResult MultiDimReductionOp::verify() {
  SmallVector<int64_t> targetShape;
  SmallVector<bool> scalableDims;
  Type inferredReturnType;
  auto sourceScalableDims = getSourceVectorType().getScalableDims();
  for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
    if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
          return llvm::cast<IntegerAttr>(attr).getValue() == it.index();
        })) {
      targetShape.push_back(it.value());
      scalableDims.push_back(sourceScalableDims[it.index()]);
    }
  // TODO: update to also allow 0-d vectors when available.
  if (targetShape.empty())
    inferredReturnType = getSourceVectorType().getElementType();
  else
    inferredReturnType = VectorType::get(
        targetShape, getSourceVectorType().getElementType(), scalableDims);
  if (getType() != inferredReturnType)
    return emitOpError() << "destination type " << getType()
                         << " is incompatible with source type "
                         << getSourceVectorType();

  return success();
}

/// Returns the mask type expected by this operation.
Type MultiDimReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}

namespace {
// Only unit dimensions that are being reduced are folded. If the dimension is
// unit, but not reduced, it is not folded, thereby keeping the output type the
// same. If not all dimensions which are reduced are of unit dimension, this
// transformation does nothing. This is just a generalization of
// ElideSingleElementReduction for ReduceOp.
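//
// For example (illustrative):
//
//   %0 = vector.multi_reduction <add>, %v, %acc [0, 1]
//          : vector<1x1x8xf32> to vector<8xf32>
//
// can be rewritten to a vector.shape_cast of %v to vector<8xf32> combined
// with %acc through the reduction kind (an elementwise add here), because
// every reduced dimension has unit size.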
struct ElideUnitDimsInMultiDimReduction
    : public OpRewritePattern<MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
    for (const auto &dim : enumerate(shape)) {
      if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
        return failure();
    }

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    Operation *rootOp;
    Value mask;
    if (reductionOp.isMasked()) {
      rewriter.setInsertionPoint(reductionOp.getMaskingOp());
      rootOp = reductionOp.getMaskingOp();
      mask = reductionOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    Location loc = reductionOp.getLoc();
    Value acc = reductionOp.getAcc();
    Value cast;
    if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
      if (mask) {
        VectorType newMaskType =
            VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
                            dstVecType.getScalableDims());
        mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
      }
      cast = rewriter.create<vector::ShapeCastOp>(
          loc, reductionOp.getDestType(), reductionOp.getSource());
    } else {
      // This means we are reducing all the dimensions, and all reduction
      // dimensions are of size 1. So a simple extraction would do.
      SmallVector<int64_t> zeroIdx(shape.size(), 0);
      if (mask)
        mask = rewriter.create<vector::ExtractOp>(loc, mask, zeroIdx);
      cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource(),
                                                zeroIdx);
    }

    Value result =
        vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
                                   cast, /*fastmath=*/nullptr, mask);
    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
} // namespace

void MultiDimReductionOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<ElideUnitDimsInMultiDimReduction>(context);
}

//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//

void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector,
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
}

void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector, Value acc,
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result,
        llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
        acc, fastMathFlags);
}

LogicalResult ReductionOp::verify() {
  // Verify for 0-D and 1-D vector.
  int64_t rank = getSourceVectorType().getRank();
  if (rank > 1)
    return emitOpError("unsupported reduction rank: ") << rank;

  // Verify supported reduction kind.
  Type eltType = getDest().getType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type '")
           << eltType << "' for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation.
Type ReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}

Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
                                         OpBuilder &builder, Location loc,
                                         Value vector) {
  switch (op) {
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::addi:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::ADD, vector);
  case arith::AtomicRMWKind::mulf:
  case arith::AtomicRMWKind::muli:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MUL, vector);
  case arith::AtomicRMWKind::minimumf:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINIMUMF, vector);
  case arith::AtomicRMWKind::mins:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINSI, vector);
  case arith::AtomicRMWKind::minu:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINUI, vector);
  case arith::AtomicRMWKind::maximumf:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXIMUMF, vector);
  case arith::AtomicRMWKind::maxs:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXSI, vector);
  case arith::AtomicRMWKind::maxu:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXUI, vector);
  case arith::AtomicRMWKind::andi:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::AND, vector);
  case arith::AtomicRMWKind::ori:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::OR, vector);
  // TODO: Add remaining reduction operations.
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}
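
// For example (illustrative): arith::AtomicRMWKind::addf applied to a value
// %v of type vector<8xf32> produces
//
//   %r = vector.reduction <add>, %v : vector<8xf32> into f32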

std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}

namespace {
struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(reductionOp.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    auto vectorType = reductionOp.getSourceVectorType();
    if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
      return failure();

    Location loc = reductionOp.getLoc();
    Value result;
    if (vectorType.getRank() == 0) {
      if (mask)
        mask = rewriter.create<ExtractElementOp>(loc, mask);
      result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
    } else {
      if (mask)
        mask = rewriter.create<ExtractOp>(loc, mask, 0);
      result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(), 0);
    }

    if (Value acc = reductionOp.getAcc())
      result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                          result, acc,
                                          reductionOp.getFastmathAttr(), mask);

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
} // namespace

void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ElideSingleElementReduction>(context);
}

//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
                                  ArrayRef<IteratorType> iteratorTypes) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(
      getIndexingMapsAttrName(result.name),
      builder.getAffineMapArrayAttr(
          AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
  result.addAttribute(
      getIteratorTypesAttrName(result.name),
      builder.getArrayAttr(llvm::to_vector(llvm::map_range(
          iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
            return IteratorTypeAttr::get(builder.getContext(), t);
          }))));
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes) {
  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
        ContractionOp::getDefaultKind());
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes, CombiningKind kind) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
  result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
  result.addAttribute(getKindAttrName(result.name),
                      CombiningKindAttr::get(builder.getContext(), kind));
}

ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand lhsInfo;
  OpAsmParser::UnresolvedOperand rhsInfo;
  OpAsmParser::UnresolvedOperand accInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
  SmallVector<Type, 2> types;
  Type resultType;
  auto loc = parser.getCurrentLocation();
  DictionaryAttr dictAttr;
  // TODO: Unify linalg op attribute parsing.
  if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
      parser.parseComma() || parser.parseOperand(rhsInfo) ||
      parser.parseComma() || parser.parseOperand(accInfo) ||
      parser.parseTrailingOperandList(masksInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonTypeList(types) ||
      parser.parseKeywordType("into", resultType) ||
      parser.resolveOperand(lhsInfo, types[0], result.operands) ||
      parser.resolveOperand(rhsInfo, types[1], result.operands) ||
      parser.resolveOperand(accInfo, resultType, result.operands) ||
      parser.addTypeToList(resultType, result.types))
    return failure();
  result.attributes.append(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert an array of strings into an array of IteratorType enums. This is
  // needed, because tests still use the old format when the 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  if (!result.attributes.get(getKindAttrName(result.name))) {
    result.addAttribute(
        getKindAttrName(result.name),
        CombiningKindAttr::get(result.getContext(),
                               ContractionOp::getDefaultKind()));
  }
  if (masksInfo.empty())
    return success();
  if (masksInfo.size() != 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = llvm::cast<VectorType>(types[0]);
  auto rhsType = llvm::cast<VectorType>(types[1]);
  auto maskElementType = parser.getBuilder().getI1Type();
  std::array<VectorType, 2> maskTypes = {
      VectorType::Builder(lhsType).setElementType(maskElementType),
      VectorType::Builder(rhsType).setElementType(maskElementType)};
  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
    return failure();
  return success();
}

void ContractionOp::print(OpAsmPrinter &p) {
  // TODO: Unify printing code with linalg ops.
  auto attrNames = getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
  SmallVector<NamedAttribute, 8> attrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
          llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
            return StringAttr::get(getContext(), stringifyIteratorType(t));
          }));

      attrs.emplace_back(getIteratorTypesAttrName(),
                         ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
      attrs.push_back(attr);
  }

  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
  p << " " << dictAttr << " " << getLhs() << ", ";
  p << getRhs() << ", " << getAcc();

  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
    << getResultType();
}

static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
                         const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
      return false;
  }
  return true;
}

static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet;
  for (auto &dimPair : batchDimMap)
    rhsBatchDimSet.insert(dimPair.second);

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  // Verify 'expectedResultDims'.
  if (expectedResultDims.empty()) {
    // No batch or free dimension implies a scalar result.
    if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = llvm::dyn_cast<VectorType>(resType);
    auto accVectorType = llvm::dyn_cast<VectorType>(accType);
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
    // types fully define the result vector type. This assumes the affine maps
    // are well-formed, which must have been verified already.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMapsArray()[0];
    AffineMap rhsMap = op.getIndexingMapsArray()[1];
    if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
      return op.emitOpError(
          "expected all dimensions to be either a LHS or a RHS dimension");
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
      return op.emitOpError("expected all dimensions to get an extent as "
                            "either a LHS or a RHS dimension");

    AffineMap resMap = op.getIndexingMapsArray()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symbolCount=*/0, extents, ctx);
    // Compose the resMap with the extentsMap, which is a constant map.
    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
    assert(llvm::all_of(expectedMap.getResults(),
                        llvm::IsaPred<AffineConstantExpr>) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape = llvm::to_vector<4>(
        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
          return cast<AffineConstantExpr>(e).getValue();
        }));
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType(),
                        resVectorType.getScalableDims());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}

LogicalResult ContractionOp::verify() {
  VectorType lhsType = getLhsType();
  VectorType rhsType = getRhsType();
  Type accType = getAccType();
  Type resType = getResultType();

  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
    if (!lhsType.getElementType().isSignlessInteger())
      return emitOpError("only supports signless integer types");
  }

  // Verify that an indexing map was specified for each vector operand.
  if (getIndexingMapsArray().size() != 3)
    return emitOpError("expected an indexing map for each vector operand");

  // Verify that each index map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated
  // vector operand.
  unsigned numIterators = getIteratorTypes().getValue().size();
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    if (map.getNumSymbols() != 0)
      return emitOpError("expected indexing map ")
             << index << " to have no symbols";
    auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    // Verify that the map has the right number of inputs, outputs, and
    // indices. This also correctly accounts for (..) -> () for rank-0 results.
    if (map.getNumDims() != numIterators)
      return emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = getContractingDimMap();
  auto batchDimMap = getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  // Verify supported combining kind.
  auto vectorType = llvm::dyn_cast<VectorType>(resType);
  auto elementType = vectorType ? vectorType.getElementType() : resType;
  if (!isSupportedCombiningKind(getKind(), elementType))
    return emitOpError("unsupported contraction type");

  return success();
}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized.
Type ContractionOp::getExpectedMaskType() {
  auto indexingMaps = this->getIndexingMapsArray();
  AffineMap lhsIdxMap = indexingMaps[0];
  AffineMap rhsIdxMap = indexingMaps[1];
  VectorType lhsType = this->getLhsType();
  VectorType rhsType = this->getRhsType();

  unsigned numVecDims = lhsIdxMap.getNumDims();
  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
  SmallVector<bool> maskShapeScalableDims(numVecDims, false);

  // Using the information in the indexing maps, extract the size of each
  // dimension in the vector.contract operation from the two input operands.
  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
    maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
    maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
        lhsType.getScalableDims()[dimIdx];
  }
  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
    maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
    maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
        rhsType.getScalableDims()[dimIdx];
  }

  assert(!ShapedType::isDynamicShape(maskShape) &&
         "Mask shape couldn't be computed");

  return VectorType::get(maskShape,
                         IntegerType::get(lhsType.getContext(), /*width=*/1),
                         maskShapeScalableDims);
}
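
// For illustration (hypothetical shapes): for a matmul-style contraction with
// indexing maps (d0, d2), (d2, d1), (d0, d1), lhs vector<4x8xf32>, and rhs
// vector<8x16xf32>, the expected mask type is vector<4x16x8xi1>, i.e. the
// iteration-space shape [d0 = 4, d1 = 16, d2 = 8] with an i1 element type.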

SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
  return SmallVector<StringRef>{getIndexingMapsAttrName(),
                                getIteratorTypesAttrName(), getKindAttrName()};
}

static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
    if (targetExpr == map.getResult(i))
      return i;
  return -1;
}

static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
          IteratorType targetIteratorType, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType != targetIteratorType)
      continue;
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), context);
    int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
    int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.emplace_back(lhsDim, rhsDim);
  }
  return dimMap;
}

void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  SmallVector<int64_t, 2> iterationShape;
  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), getContext());
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType == IteratorType::reduction) {
      // Get reduction dim size from lhs shape (same size in rhsShape).
      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
      continue;
    }
    // Get parallel dimension size from result shape.
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
  }
}

void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  unsigned numMaps = getIndexingMapsArray().size();
  iterationIndexMap.resize(numMaps);
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = cast<AffineDimExpr>(map.getResult(i));
      iterationIndexMap[index][dim.getPosition()] = i;
    }
  }
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
                   getContext());
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
                   getContext());
}

std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  SmallVector<int64_t, 4> shape;
  getIterationBounds(shape);
  return shape;
}

/// Return a fused vector::ContractionOp which represents a pattern such as:
///
/// ```mlir
///   %c0 = vector.constant 0: ...
///   %c = vector.contract %a, %b, %c0: ...
///   %e = add %c, %d: ...
/// ```
///
/// by:
///
/// ```mlir
///   %e = vector.contract %a, %b, %d: ...
/// ```
///
/// Return null if the canonicalization does not apply.
// TODO: This should be a folding of Add into Contract in core but while they
// live in different dialects, it is not possible without unnatural
// dependencies.
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
  using OpRewritePattern<AddOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddOpType addOp,
                                PatternRewriter &rewriter) const override {
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
              maybeContraction.getDefiningOp());
      if (!contractionOp)
        return vector::ContractionOp();
      if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
              contractionOp.getAcc().getDefiningOp())) {
        if (maybeZero.getValue() ==
            rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
          IRMapping bvm;
          bvm.map(contractionOp.getAcc(), otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;
        }
      }
      return vector::ContractionOp();
    };

    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
    contract = contract ? contract : canonicalize(b, a);
    return contract ? success() : failure();
  }
};

void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<CanonicalizeContractAdd<arith::AddIOp>,
              CanonicalizeContractAdd<arith::AddFOp>>(context);
}

//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//

void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
                                     Value source) {
  result.addOperands({source});
  result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType());
}

LogicalResult vector::ExtractElementOp::verify() {
  VectorType vectorType = getSourceVectorType();
  if (vectorType.getRank() == 0) {
    if (getPosition())
      return emitOpError("expected position to be empty with 0-D vector");
    return success();
  }
  if (vectorType.getRank() != 1)
    return emitOpError("unexpected >1 vector rank");
  if (!getPosition())
    return emitOpError("expected position for 1-D vector");
  return success();
}

OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
  // Skip the 0-D vector here now.
  if (!adaptor.getPosition())
    return {};

  // Fold extractelement (splat X) -> X.
  if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
    return splat.getInput();

  // Fold extractelement(broadcast(X)) -> X.
  if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
    if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
      return broadcast.getSource();

  auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
  auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
  if (!pos || !src)
    return {};

  auto srcElements = src.getValues<Attribute>();

  uint64_t posIdx = pos.getInt();
  if (posIdx >= srcElements.size())
    return {};

  return srcElements[posIdx];
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, int64_t position) {
  build(builder, result, source, ArrayRef<int64_t>{position});
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, OpFoldResult position) {
  build(builder, result, source, ArrayRef<OpFoldResult>{position});
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<int64_t> position) {
  build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
        builder.getDenseI64ArrayAttr(position));
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<OpFoldResult> position) {
  SmallVector<int64_t> staticPos;
  SmallVector<Value> dynamicPos;
  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
  build(builder, result, source, dynamicPos,
        builder.getDenseI64ArrayAttr(staticPos));
}

LogicalResult
ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ExtractOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
      vectorType.getRank()) {
    inferredReturnTypes.push_back(vectorType.getElementType());
  } else {
    auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
                              vectorType.getRank());
    inferredReturnTypes.push_back(VectorType::get(
        vectorType.getShape().drop_front(n), vectorType.getElementType(),
        vectorType.getScalableDims().drop_front(n)));
  }
  return success();
}
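
// For illustration: extracting at position [0] from a vector<4x8x16xf32>
// source infers a vector<8x16xf32> result (the leading dimension is dropped),
// while extracting at [0, 1, 2] infers the scalar element type f32.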

bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // Allow extracting 1-element vectors instead of scalars.
  auto isCompatible = [](TypeRange l, TypeRange r) {
    auto vectorType = llvm::dyn_cast<VectorType>(l.front());
    return vectorType && vectorType.getShape().equals({1}) &&
           vectorType.getElementType() == r.front();
  };
  if (l.size() == 1 && r.size() == 1 &&
      (isCompatible(l, r) || isCompatible(r, l)))
    return true;
  return l == r;
}

LogicalResult vector::ExtractOp::verify() {
  // Note: This check must come before getMixedPosition() to prevent a crash.
  auto dynamicMarkersCount =
      llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
    return emitOpError(
        "mismatch between dynamic and static positions (kDynamic marker but no "
        "corresponding dynamic position) -- this can only happen due to an "
        "incorrect fold/rewrite");
  auto position = getMixedPosition();
  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than vector rank");
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (pos.is<Attribute>()) {
      int64_t constIdx = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
      if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding vector dimension";
      }
    }
  }
  return success();
}

template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

/// Fold the result of chains of ExtractOp in place by simply concatenating the
/// positions.
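/// For example (illustrative):
///
/// ```mlir
///   %0 = vector.extract %arg[1] : vector<3x4xf32> from vector<2x3x4xf32>
///   %1 = vector.extract %0[2] : vector<4xf32> from vector<3x4xf32>
/// ```
///
/// folds in place to the equivalent of `vector.extract %arg[1, 2]`.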
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
  if (!extractOp.getVector().getDefiningOp<ExtractOp>())
    return failure();

  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return failure();

  SmallVector<int64_t> globalPosition;
  ExtractOp currentOp = extractOp;
  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
  globalPosition.append(extrPos.rbegin(), extrPos.rend());
  while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
    currentOp = nextOp;
    // TODO: Canonicalization for dynamic position not implemented yet.
    if (currentOp.hasDynamicPosition())
      return failure();
    ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
    globalPosition.append(extrPos.rbegin(), extrPos.rend());
  }
  extractOp.setOperand(0, currentOp.getVector());
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp.setStaticPosition(globalPosition);
  return success();
}

namespace {
/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
/// Walk back a chain of InsertOp/TransposeOp until we hit a match.
/// Compose TransposeOp permutations as we walk back.
/// This helper class keeps an updated extraction position `extractPosition`
/// with extra trailing sentinels.
/// The sentinels encode the internal transposition status of the result
/// vector. As we iterate, extractPosition is permuted and updated.
class ExtractFromInsertTransposeChainState {
public:
  ExtractFromInsertTransposeChainState(ExtractOp e);

  /// Iterate over producing insert and transpose ops until we find a fold.
  Value fold();

private:
  /// Return true if the vector at position `a` is contained within the vector
  /// at position `b`. Under insert/extract semantics, this is the same as `a`
  /// is a prefix of `b`.
  template <typename ContainerA, typename ContainerB>
  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
    return a.size() <= b.size() &&
           std::equal(a.begin(), a.begin() + a.size(), b.begin());
  }

  /// Return true if the vector at position `a` intersects the vector at
  /// position `b`. Under insert/extract semantics, this is the same as
  /// equality of all entries of `a` that are >= 0 with the corresponding
  /// entries of `b`. Comparison is on the common prefix (i.e. zip).
  template <typename ContainerA, typename ContainerB>
  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
    for (auto [elemA, elemB] : llvm::zip(a, b)) {
      if (elemA < 0 || elemB < 0)
        continue;
      if (elemA != elemB)
        return false;
    }
    return true;
  }

  /// Folding is only possible in the absence of an internal permutation in the
  /// result vector.
  bool canFold() {
    return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
  }

  // Helper to get the next defining op of interest.
  void updateStateForNextIteration(Value v) {
    nextInsertOp = v.getDefiningOp<vector::InsertOp>();
    nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
  };

  // Case 1. If we hit a transpose, just compose the map and iterate.
  // Invariant: insert + transpose do not change rank, we can always compose.
  LogicalResult handleTransposeOp();

  // Case 2: the insert position matches extractPosition exactly, early return.
  LogicalResult handleInsertOpWithMatchingPos(Value &res);

  /// Case 3: if the insert position is a prefix of extractPosition, extract a
  /// portion of the source of the insert.
  /// Example:
  /// ```
  /// %ins = vector.insert %source, %dest[1]: vector<3x4> into vector<2x3x4x5>
  /// // extractPosition == [1, 2, 3]
  /// %ext = vector.extract %ins[1, 0]: vector<5> from vector<3x4x5>
  /// // can fold to vector.extract %source[0, 3]
  /// %ext = vector.extract %source[3]: vector<6> from vector<5x6>
  /// ```
  /// To traverse through %source, we need to set the leading dims to 0 and
  /// drop the extra leading dims.
  /// This method updates the internal state.
  LogicalResult handleInsertOpWithPrefixPos(Value &res);

  /// Try to fold in place to extract(source, extractPosition) and return the
  /// folded result. Return null if folding is not possible (e.g. due to an
  /// internal transposition in the result).
  Value tryToFoldExtractOpInPlace(Value source);
1428
1429 ExtractOp extractOp;
1430 int64_t vectorRank;
1431 int64_t extractedRank;
1432
1433 InsertOp nextInsertOp;
1434 TransposeOp nextTransposeOp;
1435
1436 /// Sentinel values that encode the internal permutation status of the result.
1437 /// They are set to (-1, ... , -k) at the beginning and appended to
1438 /// `extractPosition`.
1439 /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1440 /// ensure that there is no internal transposition.
1441 /// Internal transposition cannot be accounted for with a folding pattern.
1442 // TODO: We could relax the internal transposition with an extra transposition
1443 // operation in a future canonicalizer.
1444 SmallVector<int64_t> sentinels;
1445 SmallVector<int64_t> extractPosition;
1446};
1447} // namespace
1448
1449ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1450 ExtractOp e)
1451 : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1452 extractedRank(extractOp.getNumIndices()) {
1453 assert(vectorRank >= extractedRank && "Extracted position overflow");
  sentinels.reserve(vectorRank - extractedRank);
  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
    sentinels.push_back(-(i + 1));
  extractPosition.assign(extractOp.getStaticPosition().begin(),
                         extractOp.getStaticPosition().end());
  llvm::append_range(extractPosition, sentinels);
1460}
1461
1462// Case 1. If we hit a transpose, just compose the map and iterate.
1463// Invariant: insert + transpose do not change rank, we can always compose.
1464LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1465 // TODO: Canonicalization for dynamic position not implemented yet.
1466 if (extractOp.hasDynamicPosition())
1467 return failure();
1468
1469 if (!nextTransposeOp)
1470 return failure();
1471 AffineMap m = inversePermutation(AffineMap::getPermutationMap(
1472 nextTransposeOp.getPermutation(), extractOp.getContext()));
  extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
1474 return success();
1475}
1476
1477// Case 2: the insert position matches extractPosition exactly, early return.
1478LogicalResult
1479ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1480 Value &res) {
1481 // TODO: Canonicalization for dynamic position not implemented yet.
1482 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1483 return failure();
1484
1485 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
  if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1487 return failure();
1488 // Case 2.a. early-exit fold.
1489 res = nextInsertOp.getSource();
1490 // Case 2.b. if internal transposition is present, canFold will be false.
  return success(canFold());
1492}
1493
1494/// Case 3: if inserted position is a prefix of extractPosition,
1495/// extract a portion of the source of the insertion.
1496/// This method updates the internal state.
1497LogicalResult
1498ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1499 // TODO: Canonicalization for dynamic position not implemented yet.
1500 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1501 return failure();
1502
1503 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
  if (!isContainedWithin(insertedPos, extractPosition))
    return failure();
  // Set leading dims to zero.
  std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
  // Drop extra leading dims.
  extractPosition.erase(extractPosition.begin(),
                        extractPosition.begin() + insertedPos.size());
1511 extractedRank = extractPosition.size() - sentinels.size();
1512 // Case 3.a. early-exit fold (break and delegate to post-while path).
1513 res = nextInsertOp.getSource();
1514 // Case 3.b. if internal transposition is present, canFold will be false.
1515 return success();
1516}
1517
1518/// Try to fold in place to extract(source, extractPosition) and return the
1519/// folded result. Return null if folding is not possible (e.g. due to an
/// internal transposition in the result).
1521Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1522 Value source) {
1523 // TODO: Canonicalization for dynamic position not implemented yet.
1524 if (extractOp.hasDynamicPosition())
1525 return Value();
1526
1527 // If we can't fold (either internal transposition, or nothing to fold), bail.
1528 bool nothingToFold = (source == extractOp.getVector());
1529 if (nothingToFold || !canFold())
1530 return Value();
1531
1532 // Otherwise, fold by updating the op inplace and return its result.
1533 OpBuilder b(extractOp.getContext());
1534 extractOp.setStaticPosition(
1535 ArrayRef(extractPosition).take_front(extractedRank));
1536 extractOp.getVectorMutable().assign(source);
1537 return extractOp.getResult();
1538}
1539
1540/// Iterate over producing insert and transpose ops until we find a fold.
1541Value ExtractFromInsertTransposeChainState::fold() {
1542 // TODO: Canonicalization for dynamic position not implemented yet.
1543 if (extractOp.hasDynamicPosition())
1544 return Value();
1545
1546 Value valueToExtractFrom = extractOp.getVector();
  updateStateForNextIteration(valueToExtractFrom);
  while (nextInsertOp || nextTransposeOp) {
    // Case 1. If we hit a transpose, just compose the map and iterate.
    // Invariant: insert + transpose do not change rank, we can always compose.
    if (succeeded(handleTransposeOp())) {
      valueToExtractFrom = nextTransposeOp.getVector();
      updateStateForNextIteration(valueToExtractFrom);
1554 continue;
1555 }
1556
1557 Value result;
    // Case 2: the positions match exactly.
    if (succeeded(handleInsertOpWithMatchingPos(result)))
1560 return result;
1561
1562 // Case 3: if the inserted position is a prefix of extractPosition, we can
1563 // just extract a portion of the source of the insert.
    if (succeeded(handleInsertOpWithPrefixPos(result)))
      return tryToFoldExtractOpInPlace(result);
1566
    // Case 4: extractPosition intersects insertedPos on non-sentinel values.
    // This is a more difficult case and we bail.
    ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
    if (isContainedWithin(extractPosition, insertedPos) ||
        intersectsWhereNonNegative(extractPosition, insertedPos))
1572 return Value();
1573
1574 // Case 5: No intersection, we forward the extract to insertOp.dest().
1575 valueToExtractFrom = nextInsertOp.getDest();
    updateStateForNextIteration(valueToExtractFrom);
1577 }
1578 // If after all this we can fold, go for it.
  return tryToFoldExtractOpInPlace(valueToExtractFrom);
1580}
1581
1582/// Returns true if the operation has a 0-D vector type operand or result.
1583static bool hasZeroDimVectors(Operation *op) {
1584 auto hasZeroDimVectorType = [](Type type) -> bool {
1585 auto vecType = dyn_cast<VectorType>(type);
1586 return vecType && vecType.getRank() == 0;
1587 };
1588
  return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
         llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
1591}
1592
1593/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
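/// A sketch of the kind of IR this targets (shapes chosen arbitrarily):
/// ```
/// %b = vector.broadcast %arg : f32 to vector<2x3xf32>
/// %e = vector.extract %b[0, 1] : f32 from vector<2x3xf32>
/// ```
/// folds to a direct use of %arg.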
1594static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1595 // TODO: Canonicalization for dynamic position not implemented yet.
1596 if (extractOp.hasDynamicPosition())
1597 return Value();
1598
1599 Operation *defOp = extractOp.getVector().getDefiningOp();
1600 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1601 return Value();
1602
1603 // 0-D vectors not supported.
1604 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
  if (hasZeroDimVectors(defOp))
    return Value();

  Value source = defOp->getOperand(0);
1609 if (extractOp.getType() == source.getType())
1610 return source;
1611 auto getRank = [](Type type) {
1612 return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
1613 : 0;
1614 };
1615
1616 // If splat or broadcast from a scalar, just return the source scalar.
1617 unsigned broadcastSrcRank = getRank(source.getType());
1618 if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
1619 return source;
1620
1621 unsigned extractResultRank = getRank(extractOp.getType());
1622 if (extractResultRank >= broadcastSrcRank)
1623 return Value();
  // Check that the dimensions of the result haven't been broadcasted.
1625 auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
1626 auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
1627 if (extractVecType && broadcastVecType &&
1628 extractVecType.getShape() !=
1629 broadcastVecType.getShape().take_back(extractResultRank))
1630 return Value();
1631
1632 auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1633 int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
1634
  // Detect all the positions that come from "dim-1" broadcasting.
  // These dimensions correspond to "dim-1" broadcasted dims; set the matching
  // extract position to `0` when extracting from the source operand.
1638 llvm::SetVector<int64_t> broadcastedUnitDims =
1639 broadcastOp.computeBroadcastedUnitDims();
1640 SmallVector<int64_t> extractPos(extractOp.getStaticPosition());
1641 int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1642 for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
    if (broadcastedUnitDims.contains(i))
1644 extractPos[i] = 0;
1645 // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
1646 // matching extract position when extracting from the source operand.
1647 int64_t rankDiff = broadcastSrcRank - extractResultRank;
  extractPos.erase(extractPos.begin(),
                   std::next(extractPos.begin(), extractPos.size() - rankDiff));
1650 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1651 OpBuilder b(extractOp.getContext());
1652 extractOp.setOperand(0, source);
1653 extractOp.setStaticPosition(extractPos);
1654 return extractOp.getResult();
1655}
1656
1657// Fold extractOp with source coming from ShapeCast op.
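// For intuition (illustrative shapes), the position is re-linearized through
// the shape_cast:
//   %sc = vector.shape_cast %arg : vector<8xf32> to vector<2x4xf32>
//   %e  = vector.extract %sc[1, 2] : f32 from vector<2x4xf32>
// folds in place to
//   %e  = vector.extract %arg[6] : f32 from vector<8xf32>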
1658static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1659 // TODO: Canonicalization for dynamic position not implemented yet.
1660 if (extractOp.hasDynamicPosition())
1661 return Value();
1662
1663 auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
1664 if (!shapeCastOp)
1665 return Value();
1666
1667 // 0-D vectors not supported.
1668 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1669 if (hasZeroDimVectors(shapeCastOp))
1670 return Value();
1671
1672 // Get the nth dimension size starting from lowest dimension.
1673 auto getDimReverse = [](VectorType type, int64_t n) {
1674 return type.getShape().take_back(n + 1).front();
1675 };
1676 int64_t destinationRank =
1677 llvm::isa<VectorType>(extractOp.getType())
1678 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1679 : 0;
1680 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1681 return Value();
1682 if (destinationRank > 0) {
1683 auto destinationType =
1684 llvm::cast<VectorType>(extractOp.getResult().getType());
1685 for (int64_t i = 0; i < destinationRank; i++) {
1686 // The lowest dimension of the destination must match the lowest
1687 // dimension of the shapecast op source.
      // TODO: This case could be supported in a canonicalization pattern.
1689 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1690 getDimReverse(destinationType, i))
1691 return Value();
1692 }
1693 }
1694 // Extract the strides associated with the extract op vector source. Then use
1695 // this to calculate a linearized position for the extract.
1696 SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
  std::reverse(extractedPos.begin(), extractedPos.end());
  SmallVector<int64_t, 4> strides;
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
1702 stride *=
1703 getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1704 }
1705
  int64_t position = linearize(extractedPos, strides);
1707 // Then extract the strides associated to the shapeCast op vector source and
1708 // delinearize the position using those strides.
1709 SmallVector<int64_t, 4> newStrides;
1710 int64_t numDimension =
1711 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1712 stride = 1;
1713 for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  std::reverse(newStrides.begin(), newStrides.end());
  SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
1720 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1721 OpBuilder b(extractOp.getContext());
1722 extractOp.setStaticPosition(newPosition);
1723 extractOp.setOperand(0, shapeCastOp.getSource());
1724 return extractOp.getResult();
1725}
1726
1727/// Fold an ExtractOp from ExtractStridedSliceOp.
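/// Illustrative example (shapes chosen arbitrarily); the extract position is
/// shifted by the slice offsets so that it applies to the slice source:
/// ```
/// %s = vector.extract_strided_slice %arg
///        {offsets = [1, 2], sizes = [2, 4], strides = [1, 1]}
///        : vector<4x8xf32> to vector<2x4xf32>
/// %e = vector.extract %s[0, 1] : f32 from vector<2x4xf32>
/// ```
/// folds to `vector.extract %arg[1, 3] : f32 from vector<4x8xf32>`.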
1728static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1729 // TODO: Canonicalization for dynamic position not implemented yet.
1730 if (extractOp.hasDynamicPosition())
1731 return Value();
1732
1733 auto extractStridedSliceOp =
1734 extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
1735 if (!extractStridedSliceOp)
1736 return Value();
1737
1738 // 0-D vectors not supported.
1739 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1740 if (hasZeroDimVectors(extractStridedSliceOp))
1741 return Value();
1742
1743 // Return if 'extractStridedSliceOp' has non-unit strides.
1744 if (extractStridedSliceOp.hasNonUnitStrides())
1745 return Value();
1746
1747 // Trim offsets for dimensions fully extracted.
1748 auto sliceOffsets =
1749 extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1750 while (!sliceOffsets.empty()) {
1751 size_t lastOffset = sliceOffsets.size() - 1;
1752 if (sliceOffsets.back() != 0 ||
1753 extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1754 extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1755 break;
1756 sliceOffsets.pop_back();
1757 }
1758 unsigned destinationRank = 0;
1759 if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1760 destinationRank = vecType.getRank();
1761 // The dimensions of the result need to be untouched by the
1762 // extractStridedSlice op.
1763 if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1764 sliceOffsets.size())
1765 return Value();
1766
1767 SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1768 assert(extractedPos.size() >= sliceOffsets.size());
1769 for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1770 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1771 extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
1772
1773 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1774 OpBuilder b(extractOp.getContext());
1775 extractOp.setStaticPosition(extractedPos);
1776 return extractOp.getResult();
1777}
1778
1779/// Fold extract_op fed from a chain of insertStridedSlice ops.
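/// Illustrative sketch (shapes chosen arbitrarily): when the extracted chunk
/// lies entirely inside the inserted slice, the extract is redirected to the
/// inserted value:
/// ```
/// %i = vector.insert_strided_slice %src, %dst {offsets = [2], strides = [1]}
///        : vector<4xf32> into vector<8xf32>
/// %e = vector.extract %i[3] : f32 from vector<8xf32>
/// ```
/// folds to `vector.extract %src[1] : f32 from vector<4xf32>`.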
1780static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1781 // TODO: Canonicalization for dynamic position not implemented yet.
1782 if (extractOp.hasDynamicPosition())
1783 return Value();
1784
1785 int64_t destinationRank =
1786 llvm::isa<VectorType>(extractOp.getType())
1787 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1788 : 0;
1789 auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
1790 if (!insertOp)
1791 return Value();
1792
1793 // 0-D vectors not supported.
1794 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1795 if (hasZeroDimVectors(insertOp))
1796 return Value();
1797
1798 while (insertOp) {
1799 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1800 insertOp.getSourceVectorType().getRank();
1801 if (destinationRank > insertOp.getSourceVectorType().getRank())
1802 return Value();
1803 auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1804 ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1805
1806 if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1807 return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1808 }))
1809 return Value();
1810 bool disjoint = false;
1811 SmallVector<int64_t, 4> offsetDiffs;
1812 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1813 int64_t start = insertOffsets[dim];
1814 int64_t size =
1815 (dim < insertRankDiff)
1816 ? 1
1817 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1818 int64_t end = start + size;
1819 int64_t offset = extractOffsets[dim];
1820 // Check if the start of the extract offset is in the interval inserted.
1821 if (start <= offset && offset < end) {
1822 if (dim >= insertRankDiff)
          offsetDiffs.push_back(offset - start);
1824 continue;
1825 }
1826 disjoint = true;
1827 break;
1828 }
    // The extracted chunk overlaps with the inserted vector.
1830 if (!disjoint) {
1831 // If any of the inner dimensions are only partially inserted we have a
1832 // partial overlap.
1833 int64_t srcRankDiff =
1834 insertOp.getSourceVectorType().getRank() - destinationRank;
1835 for (int64_t i = 0; i < destinationRank; i++) {
1836 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1837 insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1838 insertRankDiff))
1839 return Value();
1840 }
1841 extractOp.getVectorMutable().assign(insertOp.getSource());
1842 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1843 OpBuilder b(extractOp.getContext());
1844 extractOp.setStaticPosition(offsetDiffs);
1845 return extractOp.getResult();
1846 }
1847 // If the chunk extracted is disjoint from the chunk inserted, keep
1848 // looking in the insert chain.
1849 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1850 }
1851 return Value();
1852}
1853
1854OpFoldResult ExtractOp::fold(FoldAdaptor) {
1855 if (getNumIndices() == 0)
1856 return getVector();
1857 if (succeeded(foldExtractOpFromExtractChain(*this)))
1858 return getResult();
1859 if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
1860 return res;
1861 if (auto res = foldExtractFromBroadcast(*this))
1862 return res;
1863 if (auto res = foldExtractFromShapeCast(*this))
1864 return res;
1865 if (auto val = foldExtractFromExtractStrided(*this))
1866 return val;
1867 if (auto val = foldExtractStridedOpFromInsertChain(*this))
1868 return val;
1869 return OpFoldResult();
1870}
1871
1872namespace {
1873
// Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
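// E.g. (illustrative shapes), when the extract result still covers the whole
// broadcast source, the pair becomes a smaller broadcast:
//   %b = vector.broadcast %arg : vector<4xf32> to vector<2x3x4xf32>
//   %e = vector.extract %b[1] : vector<3x4xf32> from vector<2x3x4xf32>
// is rewritten to
//   %e = vector.broadcast %arg : vector<4xf32> to vector<3x4xf32>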
1875class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
1876public:
1877 using OpRewritePattern::OpRewritePattern;
1878
1879 LogicalResult matchAndRewrite(ExtractOp extractOp,
1880 PatternRewriter &rewriter) const override {
1881 Operation *defOp = extractOp.getVector().getDefiningOp();
1882 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1883 return failure();
1884
    Value source = defOp->getOperand(0);
1886 if (extractOp.getType() == source.getType())
1887 return failure();
1888 auto getRank = [](Type type) {
1889 return llvm::isa<VectorType>(type)
1890 ? llvm::cast<VectorType>(type).getRank()
1891 : 0;
1892 };
1893 unsigned broadcastSrcRank = getRank(source.getType());
1894 unsigned extractResultRank = getRank(extractOp.getType());
1895 // We only consider the case where the rank of the source is less than or
1896 // equal to the rank of the extract dst. The other cases are handled in the
1897 // folding patterns.
1898 if (extractResultRank < broadcastSrcRank)
1899 return failure();
1900
1901 // Special case if broadcast src is a 0D vector.
1902 if (extractResultRank == 0) {
1903 assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
1904 rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
1905 return success();
1906 }
1907 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1908 extractOp, extractOp.getType(), source);
1909 return success();
1910 }
1911};
1912
// Pattern to rewrite an ExtractOp(splat ConstantOp) -> ConstantOp.
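// E.g. (illustrative):
//   %c = arith.constant dense<1.0> : vector<4x8xf32>
//   %e = vector.extract %c[2] : vector<8xf32> from vector<4x8xf32>
// is rewritten to
//   %e = arith.constant dense<1.0> : vector<8xf32>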
1914class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
1915public:
1916 using OpRewritePattern::OpRewritePattern;
1917
1918 LogicalResult matchAndRewrite(ExtractOp extractOp,
1919 PatternRewriter &rewriter) const override {
1920 // Return if 'ExtractOp' operand is not defined by a splat vector
1921 // ConstantOp.
1922 Value sourceVector = extractOp.getVector();
1923 Attribute vectorCst;
    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
      return failure();
    auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
1927 if (!splat)
1928 return failure();
1929 TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
1930 if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1931 newAttr = DenseElementsAttr::get(vecDstType, newAttr);
1932 rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
1933 return success();
1934 }
1935};
1936
// Pattern to rewrite an ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
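// E.g. (illustrative):
//   %c = arith.constant dense<[[1, 2], [3, 4]]> : vector<2x2xi32>
//   %e = vector.extract %c[1] : vector<2xi32> from vector<2x2xi32>
// is rewritten to
//   %e = arith.constant dense<[3, 4]> : vector<2xi32>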
1938class ExtractOpNonSplatConstantFolder final
1939 : public OpRewritePattern<ExtractOp> {
1940public:
1941 using OpRewritePattern::OpRewritePattern;
1942
1943 LogicalResult matchAndRewrite(ExtractOp extractOp,
1944 PatternRewriter &rewriter) const override {
1945 // TODO: Canonicalization for dynamic position not implemented yet.
1946 if (extractOp.hasDynamicPosition())
1947 return failure();
1948
1949 // Return if 'ExtractOp' operand is not defined by a compatible vector
1950 // ConstantOp.
1951 Value sourceVector = extractOp.getVector();
1952 Attribute vectorCst;
    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
1954 return failure();
1955
1956 auto vecTy = llvm::cast<VectorType>(sourceVector.getType());
1957 if (vecTy.isScalable())
1958 return failure();
1959
1960 // The splat case is handled by `ExtractOpSplatConstantFolder`.
    auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
1962 if (!dense || dense.isSplat())
1963 return failure();
1964
    // Calculate the linearized position of the contiguous chunk of elements to
    // extract.
1967 llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
1968 copy(extractOp.getStaticPosition(), completePositions.begin());
1969 int64_t elemBeginPosition =
1970 linearize(completePositions, computeStrides(vecTy.getShape()));
1971 auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
1972
1973 TypedAttr newAttr;
1974 if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
1975 SmallVector<Attribute> elementValues(
1976 denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
1977 newAttr = DenseElementsAttr::get(resVecTy, elementValues);
1978 } else {
1979 newAttr = *denseValuesBegin;
1980 }
1981
1982 rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
1983 return success();
1984 }
1985};
1986
// Pattern to rewrite an ExtractOp(CreateMask) -> CreateMask.
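// E.g. (illustrative, with %c2 and %c3 constant index values 2 and 3):
//   %m = vector.create_mask %c2, %c3 : vector<4x8xi1>
//   %e = vector.extract %m[1] : vector<8xi1> from vector<4x8xi1>
// is rewritten to
//   %e = vector.create_mask %c3 : vector<8xi1>
// or to an all-false constant when the position is known to lie outside the
// masked region.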
1988class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
1989public:
1990 using OpRewritePattern::OpRewritePattern;
1991
1992 LogicalResult matchAndRewrite(ExtractOp extractOp,
1993 PatternRewriter &rewriter) const override {
1994 auto createMaskOp =
1995 extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
1996 if (!createMaskOp)
1997 return failure();
1998
1999 VectorType extractedMaskType =
2000 llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2001
2002 if (!extractedMaskType)
2003 return failure();
2004
2005 auto maskOperands = createMaskOp.getOperands();
2006 ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2007 VectorType maskType = createMaskOp.getVectorType();
2008
2009 bool containsUnknownDims = false;
2010 bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
2011
2012 for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2013 dimIdx++) {
2014 int64_t pos = extractOpPos[dimIdx];
2015 Value operand = maskOperands[dimIdx];
2016 auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
2017 if (!constantOp) {
2018 // Bounds of this dim unknown.
2019 containsUnknownDims = true;
2020 continue;
2021 }
2022
2023 int64_t createMaskBound =
2024 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2025
2026 if (pos != ShapedType::kDynamic) {
2027 // If any position is outside the range from the `create_mask`, then the
2028 // extracted mask will be all-false.
2029 allFalse |= pos >= createMaskBound;
2030 } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2031 // This dim is not all-true and since this is a dynamic index we don't
2032 // know if the extraction is within the true or false region.
        // Note: Zero dims have already been handled via getMaskFormat().
2034 containsUnknownDims = true;
2035 }
2036 }
2037
2038 if (allFalse) {
2039 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2040 extractOp, DenseElementsAttr::get(extractedMaskType, false));
2041 } else if (!containsUnknownDims) {
2042 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
2043 extractOp, extractedMaskType,
2044 maskOperands.drop_front(extractOpPos.size()));
2045 } else {
2046 return failure();
2047 }
2048 return success();
2049 }
2050};
2051
2052// Folds extract(shape_cast(..)) into shape_cast when the total element count
2053// does not change.
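// E.g. (illustrative):
//   %sc = vector.shape_cast %arg : vector<6xf32> to vector<1x2x3xf32>
//   %e  = vector.extract %sc[0] : vector<2x3xf32> from vector<1x2x3xf32>
// is rewritten to
//   %e  = vector.shape_cast %arg : vector<6xf32> to vector<2x3xf32>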
2054LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2055 PatternRewriter &rewriter) {
2056 auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
2057 if (!castOp)
2058 return failure();
2059
2060 VectorType sourceType = castOp.getSourceVectorType();
2061 auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2062 if (!targetType)
2063 return failure();
2064
2065 if (sourceType.getNumElements() != targetType.getNumElements())
2066 return failure();
2067
2068 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2069 castOp.getSource());
2070 return success();
2071}
2072
2073} // namespace
2074
2075void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2076 MLIRContext *context) {
2077 results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
2078 ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2079 results.add(foldExtractFromShapeCastToShapeCast);
2080}
2081
2082static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
2083 SmallVectorImpl<int64_t> &results) {
2084 for (auto attr : arrayAttr)
2085 results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2086}
2087
2088//===----------------------------------------------------------------------===//
// FMAOp
2090//===----------------------------------------------------------------------===//
2091
2092std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2093 return llvm::to_vector<4>(getVectorType().getShape());
2094}
2095
2096//===----------------------------------------------------------------------===//
2097// BroadcastOp
2098//===----------------------------------------------------------------------===//
2099
/// Return the dimensions of the result vector that were formerly ones in the
/// source vector and thus correspond to "dim-1" broadcasting.
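/// E.g. (illustrative): srcShape = [1, 4], dstShape = [3, 5, 4] returns {1},
/// since dst dim 0 is a new leading dim, dst dim 1 stems from the unit dim of
/// the source, and dst dim 2 matches the source.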
2102static llvm::SetVector<int64_t>
2103computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
2104 ArrayRef<int64_t> dstShape) {
2105 int64_t rankDiff = dstShape.size() - srcShape.size();
2106 int64_t dstDim = rankDiff;
2107 llvm::SetVector<int64_t> res;
  for (auto [s1, s2] :
       llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
    if (s1 != s2) {
      assert(s1 == 1 && "expected dim-1 broadcasting");
      res.insert(dstDim);
2113 }
2114 ++dstDim;
2115 }
2116 return res;
2117}
2118
2119llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
  // A scalar broadcast does not involve any unit-dim broadcasting.
2121 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2122 if (!srcVectorType)
2123 return {};
2124 return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2125 getResultVectorType().getShape());
2126}
2127
2128/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
2129/// `broadcastedDims` dimensions in the dstShape are broadcasted.
2130/// This requires (and asserts) that the broadcast is free of dim-1
2131/// broadcasting.
2132/// Since vector.broadcast only allows expanding leading dimensions, an extra
2133/// vector.transpose may be inserted to make the broadcast possible.
2134/// `value`, `dstShape` and `broadcastedDims` must be properly specified or
2135/// the helper will assert. This means:
2136/// 1. `dstShape` must not be empty.
2137/// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
///   3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
///      must match the `value` shape.
2140Value BroadcastOp::createOrFoldBroadcastOp(
2141 OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
2142 const llvm::SetVector<int64_t> &broadcastedDims) {
2143 assert(!dstShape.empty() && "unexpected empty dst shape");
2144
  // Step 1. Well-formedness check.
2146 SmallVector<int64_t> checkShape;
2147 for (int i = 0, e = dstShape.size(); i < e; ++i) {
2148 if (broadcastedDims.contains(i))
2149 continue;
2150 checkShape.push_back(dstShape[i]);
2151 }
2152 assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2153 "ill-formed broadcastedDims contains values not confined to "
2154 "destVectorShape");
2155
2156 Location loc = value.getLoc();
2157 Type elementType = getElementTypeOrSelf(value.getType());
2158 VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
2159 VectorType dstVectorType = VectorType::get(dstShape, elementType);
2160
2161 // Step 2. If scalar -> dstShape broadcast, just do it.
2162 if (!srcVectorType) {
2163 assert(checkShape.empty() &&
2164 "ill-formed createOrFoldBroadcastOp arguments");
2165 return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2166 }
2167
2168 assert(srcVectorType.getShape().equals(checkShape) &&
2169 "ill-formed createOrFoldBroadcastOp arguments");
2170
2171 // Step 3. Since vector.broadcast only allows creating leading dims,
2172 // vector -> dstShape broadcast may require a transpose.
2173 // Traverse the dims in order and construct:
2174 // 1. The leading entries of the broadcastShape that is guaranteed to be
2175 // achievable by a simple broadcast.
2176 // 2. The induced permutation for the subsequent vector.transpose that will
  //    bring us from `broadcastShape` back to the desired `dstShape`.
2178 // If the induced permutation is not the identity, create a vector.transpose.
2179 SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
2180 broadcastShape.reserve(dstShape.size());
2181 // Consider the example:
2182 // srcShape = 2x4
2183 // dstShape = 1x2x3x4x5
2184 // broadcastedDims = [0, 2, 4]
2185 //
2186 // We want to build:
2187 // broadcastShape = 1x3x5x2x4
2188 // permutation = [0, 2, 4, 1, 3]
2189 // ---V--- -----V-----
2190 // leading broadcast part src shape part
2191 //
2192 // Note that the trailing dims of broadcastShape are exactly the srcShape
2193 // by construction.
2194 // nextSrcShapeDim is used to keep track of where in the permutation the
2195 // "src shape part" occurs.
2196 int64_t nextSrcShapeDim = broadcastedDims.size();
2197 for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2198 if (broadcastedDims.contains(i)) {
2199 // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
2200 // bring it to the head of the broadcastShape.
2201 // It will need to be permuted back from `broadcastShape.size() - 1` into
2202 // position `i`.
2203 broadcastShape.push_back(dstShape[i]);
2204 permutation[i] = broadcastShape.size() - 1;
2205 } else {
2206 // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
2207 // shape and needs to be permuted into position `i`.
2208 // Don't touch `broadcastShape` here, the whole srcShape will be
2209 // appended after.
2210 permutation[i] = nextSrcShapeDim++;
2211 }
2212 }
2213 // 3.c. Append the srcShape.
2214 llvm::append_range(broadcastShape, srcVectorType.getShape());
2215
2216 // Ensure there are no dim-1 broadcasts.
2217 assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
2218 .empty() &&
2219 "unexpected dim-1 broadcast");
2220
2221 VectorType broadcastType = VectorType::get(broadcastShape, elementType);
2222 assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
2223 vector::BroadcastableToResult::Success &&
2224 "must be broadcastable");
2225 Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
2226 // Step 4. If we find any dimension that indeed needs to be permuted,
2227 // immediately return a new vector.transpose.
2228 for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2229 if (permutation[i] != i)
2230 return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
2231 // Otherwise return res.
2232 return res;
2233}
2234
2235BroadcastableToResult
2236mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
2237 std::pair<int, int> *mismatchingDims) {
2238 // Broadcast scalar to vector of the same element type.
2239 if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
      getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
2241 return BroadcastableToResult::Success;
2242 // From now on, only vectors broadcast.
2243 VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2244 if (!srcVectorType)
2245 return BroadcastableToResult::SourceTypeNotAVector;
2246
2247 int64_t srcRank = srcVectorType.getRank();
2248 int64_t dstRank = dstVectorType.getRank();
2249 if (srcRank > dstRank)
2250 return BroadcastableToResult::SourceRankHigher;
2251 // Source has an exact match or singleton value for all trailing dimensions
2252 // (all leading dimensions are simply duplicated).
2253 int64_t lead = dstRank - srcRank;
2254 for (int64_t r = 0; r < srcRank; ++r) {
2255 int64_t srcDim = srcVectorType.getDimSize(r);
2256 int64_t dstDim = dstVectorType.getDimSize(lead + r);
2257 if (srcDim != 1 && srcDim != dstDim) {
2258 if (mismatchingDims) {
2259 mismatchingDims->first = srcDim;
2260 mismatchingDims->second = dstDim;
2261 }
2262 return BroadcastableToResult::DimensionMismatch;
2263 }
2264 }
2265
2266 return BroadcastableToResult::Success;
2267}
2268
2269LogicalResult BroadcastOp::verify() {
2270 std::pair<int, int> mismatchingDims;
2271 BroadcastableToResult res = isBroadcastableTo(
2272 getSourceType(), getResultVectorType(), &mismatchingDims);
2273 if (res == BroadcastableToResult::Success)
2274 return success();
2275 if (res == BroadcastableToResult::SourceRankHigher)
2276 return emitOpError("source rank higher than destination rank");
2277 if (res == BroadcastableToResult::DimensionMismatch)
2278 return emitOpError("dimension mismatch (")
2279 << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
2280 if (res == BroadcastableToResult::SourceTypeNotAVector)
2281 return emitOpError("source type is not a vector");
2282 llvm_unreachable("unexpected vector.broadcast op error");
2283}
2284
2285OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
2286 if (getSourceType() == getResultVectorType())
2287 return getSource();
2288 if (!adaptor.getSource())
2289 return {};
2290 auto vectorType = getResultVectorType();
2291 if (llvm::isa<IntegerAttr, FloatAttr>(adaptor.getSource()))
2292 return DenseElementsAttr::get(vectorType, adaptor.getSource());
2293 if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
2294 return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
2295 return {};
2296}
2297
2298namespace {
2299
2300// Fold broadcast1(broadcast2(x)) into broadcast1(x).
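// E.g. (illustrative):
//   %b1 = vector.broadcast %arg : f32 to vector<4xf32>
//   %b2 = vector.broadcast %b1 : vector<4xf32> to vector<2x4xf32>
// is rewritten so that %b2 becomes
//   %b2 = vector.broadcast %arg : f32 to vector<2x4xf32>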
2301struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
2302 using OpRewritePattern::OpRewritePattern;
2303
2304 LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
2305 PatternRewriter &rewriter) const override {
2306 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
2307 if (!srcBroadcast)
2308 return failure();
2309 rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
2310 broadcastOp.getResultVectorType(),
2311 srcBroadcast.getSource());
2312 return success();
2313 }
2314};
2315} // namespace
2316
2317void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2318 MLIRContext *context) {
2319 // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
2320 // calling `populateCastAwayVectorLeadingOneDimPatterns`
2321 results.add<BroadcastFolder>(context);
2322}
2323
2324//===----------------------------------------------------------------------===//
2325// ShuffleOp
2326//===----------------------------------------------------------------------===//
2327
2328void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
2329 Value v2, ArrayRef<int64_t> mask) {
2330 build(builder, result, v1, v2, getVectorSubscriptAttr(builder, mask));
2331}
2332
2333LogicalResult ShuffleOp::verify() {
2334 VectorType resultType = getResultVectorType();
2335 VectorType v1Type = getV1VectorType();
2336 VectorType v2Type = getV2VectorType();
2337 // Verify ranks.
2338 int64_t resRank = resultType.getRank();
2339 int64_t v1Rank = v1Type.getRank();
2340 int64_t v2Rank = v2Type.getRank();
2341 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
2342 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
2343 if (!wellFormed0DCase && !wellFormedNDCase)
2344 return emitOpError("rank mismatch");
2345
2346 // Verify all but leading dimension sizes.
2347 for (int64_t r = 1; r < v1Rank; ++r) {
2348 int64_t resDim = resultType.getDimSize(r);
2349 int64_t v1Dim = v1Type.getDimSize(r);
2350 int64_t v2Dim = v2Type.getDimSize(r);
2351 if (resDim != v1Dim || v1Dim != v2Dim)
2352 return emitOpError("dimension mismatch");
2353 }
2354 // Verify mask length.
2355 auto maskAttr = getMask().getValue();
2356 int64_t maskLength = maskAttr.size();
2357 if (maskLength <= 0)
2358 return emitOpError("invalid mask length");
2359 if (maskLength != resultType.getDimSize(0))
2360 return emitOpError("mask length mismatch");
2361 // Verify all indices.
2362 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
2363 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
2364 for (const auto &en : llvm::enumerate(maskAttr)) {
2365 auto attr = llvm::dyn_cast<IntegerAttr>(en.value());
2366 if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
2367 return emitOpError("mask index #") << (en.index() + 1) << " out of range";
2368 }
2369 return success();
2370}
2371
2372LogicalResult
2373ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
2374 ShuffleOp::Adaptor adaptor,
2375 SmallVectorImpl<Type> &inferredReturnTypes) {
2376 auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
2377 auto v1Rank = v1Type.getRank();
2378 // Construct resulting type: leading dimension matches mask
2379 // length, all trailing dimensions match the operands.
2380 SmallVector<int64_t, 4> shape;
2381 shape.reserve(v1Rank);
2382 shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
2383 // In the 0-D case there is no trailing shape to append.
2384 if (v1Rank > 0)
2385 llvm::append_range(shape, v1Type.getShape().drop_front());
2386 inferredReturnTypes.push_back(
2387 VectorType::get(shape, v1Type.getElementType()));
2388 return success();
2389}
2390
2391static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width) {
2392 uint64_t expected = begin;
2393 return idxArr.size() == width &&
2394 llvm::all_of(idxArr.getAsValueRange<IntegerAttr>(),
2395 [&expected](auto attr) {
2396 return attr.getZExtValue() == expected++;
2397 });
2398}
2399
2400OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
2401 VectorType v1Type = getV1VectorType();
  // For consistency: a 0-D shuffle returns a 1-D vector, so this cannot be a
  // fold; it must be a canonicalization into a vector.broadcast.
2404 if (v1Type.getRank() == 0)
2405 return {};
2406
2407 // fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
2408 if (!v1Type.isScalable() &&
2409 isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
2410 return getV1();
2411 // fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2
2412 if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
2413 isStepIndexArray(getMask(), getV1VectorType().getDimSize(0),
2414 getV2VectorType().getDimSize(0)))
2415 return getV2();
2416
2417 Attribute lhs = adaptor.getV1(), rhs = adaptor.getV2();
2418 if (!lhs || !rhs)
2419 return {};
2420
2421 auto lhsType =
2422 llvm::cast<VectorType>(llvm::cast<DenseElementsAttr>(lhs).getType());
2423 // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
2424 // manipulation.
2425 if (lhsType.getRank() != 1)
2426 return {};
2427 int64_t lhsSize = lhsType.getDimSize(0);
2428
2429 SmallVector<Attribute> results;
2430 auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<Attribute>();
2431 auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<Attribute>();
2432 for (const auto &index : this->getMask().getAsValueRange<IntegerAttr>()) {
2433 int64_t i = index.getZExtValue();
2434 if (i >= lhsSize) {
2435 results.push_back(rhsElements[i - lhsSize]);
2436 } else {
2437 results.push_back(lhsElements[i]);
2438 }
2439 }
2440
2441 return DenseElementsAttr::get(getResultVectorType(), results);
2442}
2443
2444namespace {
2445
2446// Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
2447// to a broadcast.
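// E.g. (illustrative):
//   %s = vector.shuffle %a, %b [0] : vector<f32>, vector<f32>
// is rewritten to
//   %s = vector.broadcast %a : vector<f32> to vector<1xf32>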
2448struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
2449 using OpRewritePattern::OpRewritePattern;
2450
2451 LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
2452 PatternRewriter &rewriter) const override {
2453 VectorType v1VectorType = shuffleOp.getV1VectorType();
2454 ArrayAttr mask = shuffleOp.getMask();
2455 if (v1VectorType.getRank() > 0)
2456 return failure();
2457 if (mask.size() != 1)
2458 return failure();
2459 VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
2460 if (llvm::cast<IntegerAttr>(mask[0]).getInt() == 0)
2461 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2462 shuffleOp.getV1());
2463 else
2464 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2465 shuffleOp.getV2());
2466 return success();
2467 }
2468};
2469
2470/// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
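/// E.g. (illustrative), when both operands splat the same scalar %x:
/// ```
/// %a = vector.splat %x : vector<4xf32>
/// %b = vector.splat %x : vector<4xf32>
/// %s = vector.shuffle %a, %b [0, 4, 1, 5] : vector<4xf32>, vector<4xf32>
/// ```
/// is rewritten to `vector.splat %x : vector<4xf32>`.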
2471class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
2472public:
2473 using OpRewritePattern::OpRewritePattern;
2474
2475 LogicalResult matchAndRewrite(ShuffleOp op,
2476 PatternRewriter &rewriter) const override {
2477 auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
2478 auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
2479
2480 if (!v1Splat || !v2Splat)
2481 return failure();
2482
2483 if (v1Splat.getInput() != v2Splat.getInput())
2484 return failure();
2485
2486 rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
2487 return success();
2488 }
2489};
2490
2491/// Pattern to rewrite a fixed-size interleave via vector.shuffle to
2492/// vector.interleave.
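/// E.g. (illustrative):
/// ```
/// %s = vector.shuffle %a, %b [0, 4, 1, 5, 2, 6, 3, 7]
///        : vector<4xf32>, vector<4xf32>
/// ```
/// is rewritten to `vector.interleave %a, %b`.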
2493class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
2494public:
2495 using OpRewritePattern::OpRewritePattern;
2496
2497 LogicalResult matchAndRewrite(ShuffleOp op,
2498 PatternRewriter &rewriter) const override {
2499 VectorType resultType = op.getResultVectorType();
2500 if (resultType.isScalable())
2501 return rewriter.notifyMatchFailure(
2502 op, "ShuffleOp can't represent a scalable interleave");
2503
2504 if (resultType.getRank() != 1)
2505 return rewriter.notifyMatchFailure(
2506 op, "ShuffleOp can't represent an n-D interleave");
2507
2508 VectorType sourceType = op.getV1VectorType();
2509 if (sourceType != op.getV2VectorType() ||
2510 sourceType.getNumElements() * 2 != resultType.getNumElements()) {
2511 return rewriter.notifyMatchFailure(
2512 op, "ShuffleOp types don't match an interleave");
2513 }
2514
2515 ArrayAttr shuffleMask = op.getMask();
2516 int64_t resultVectorSize = resultType.getNumElements();
2517 for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
2518 int64_t maskValueA = cast<IntegerAttr>(shuffleMask[i * 2]).getInt();
2519 int64_t maskValueB = cast<IntegerAttr>(shuffleMask[(i * 2) + 1]).getInt();
2520 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
2521 return rewriter.notifyMatchFailure(op,
2522 "ShuffleOp mask not interleaving");
2523 }
2524
2525 rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
2526 return success();
2527 }
2528};
2529
2530} // namespace
2531
2532void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
2533 MLIRContext *context) {
2534 results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
2535 context);
2536}
2537
2538//===----------------------------------------------------------------------===//
2539// InsertElementOp
2540//===----------------------------------------------------------------------===//
2541
2542void InsertElementOp::build(OpBuilder &builder, OperationState &result,
2543 Value source, Value dest) {
2544 build(builder, result, source, dest, {});
2545}
2546
2547LogicalResult InsertElementOp::verify() {
2548 auto dstVectorType = getDestVectorType();
2549 if (dstVectorType.getRank() == 0) {
2550 if (getPosition())
2551 return emitOpError("expected position to be empty with 0-D vector");
2552 return success();
2553 }
2554 if (dstVectorType.getRank() != 1)
2555 return emitOpError("unexpected >1 vector rank");
2556 if (!getPosition())
2557 return emitOpError("expected position for 1-D vector");
2558 return success();
2559}
2560
2561OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
2562 // Skip the 0-D vector here.
2563 if (!adaptor.getPosition())
2564 return {};
2565
2566 auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
2567 auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
2568 auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
2569 if (!src || !dst || !pos)
2570 return {};
2571
2572 if (src.getType() != getDestVectorType().getElementType())
2573 return {};
2574
2575 auto dstElements = dst.getValues<Attribute>();
2576
2577 SmallVector<Attribute> results(dstElements);
2578
2579 uint64_t posIdx = pos.getInt();
2580 if (posIdx >= results.size())
2581 return {};
2582 results[posIdx] = src;
2583
2584 return DenseElementsAttr::get(getDestVectorType(), results);
2585}
2586
2587//===----------------------------------------------------------------------===//
2588// InsertOp
2589//===----------------------------------------------------------------------===//
2590
2591void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2592 Value source, Value dest, int64_t position) {
2593 build(builder, result, source, dest, ArrayRef<int64_t>{position});
2594}
2595
2596void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2597 Value source, Value dest, OpFoldResult position) {
2598 build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
2599}
2600
2601void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2602 Value source, Value dest,
2603 ArrayRef<int64_t> position) {
2604 SmallVector<OpFoldResult> posVals;
2605 posVals.reserve(position.size());
2606 llvm::transform(position, std::back_inserter(posVals),
2607 [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
2608 build(builder, result, source, dest, posVals);
2609}
2610
2611void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
2612 Value source, Value dest,
2613 ArrayRef<OpFoldResult> position) {
2614 SmallVector<int64_t> staticPos;
2615 SmallVector<Value> dynamicPos;
2616 dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
2617 build(builder, result, source, dest, dynamicPos,
2618 builder.getDenseI64ArrayAttr(staticPos));
2619}
2620
2621LogicalResult InsertOp::verify() {
2622 SmallVector<OpFoldResult> position = getMixedPosition();
2623 auto destVectorType = getDestVectorType();
2624 if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
2625 return emitOpError(
2626 "expected position attribute of rank no greater than dest vector rank");
2627 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2628 if (srcVectorType &&
2629 (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
2630 static_cast<unsigned>(destVectorType.getRank())))
2631 return emitOpError("expected position attribute rank + source rank to "
2632 "match dest vector rank");
2633 if (!srcVectorType &&
2634 (position.size() != static_cast<unsigned>(destVectorType.getRank())))
2635 return emitOpError(
2636 "expected position attribute rank to match the dest vector rank");
2637 for (auto [idx, pos] : llvm::enumerate(position)) {
2638 if (auto attr = pos.dyn_cast<Attribute>()) {
2639 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
2640 if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
2641 return emitOpError("expected position attribute #")
2642 << (idx + 1)
2643 << " to be a non-negative integer smaller than the "
2644 "corresponding "
2645 "dest vector dimension";
2646 }
2647 }
2648 }
2649 return success();
2650}
2651
2652namespace {
2653
// If insertOp is only inserting unit dimensions, it can be transformed to a
2655// broadcast.
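// E.g. (illustrative), when the source already provides every element of the
// destination:
//   %i = vector.insert %src, %dst[0] : vector<1x4xf32> into vector<1x1x4xf32>
// is rewritten to
//   %i = vector.broadcast %src : vector<1x4xf32> to vector<1x1x4xf32>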
2656class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
2657public:
2658 using OpRewritePattern::OpRewritePattern;
2659
2660 LogicalResult matchAndRewrite(InsertOp insertOp,
2661 PatternRewriter &rewriter) const override {
2662 auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
2663 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
2664 srcVecType.getNumElements())
2665 return failure();
2666 rewriter.replaceOpWithNewOp<BroadcastOp>(
2667 insertOp, insertOp.getDestVectorType(), insertOp.getSource());
2668 return success();
2669 }
2670};
2671
/// Pattern to rewrite an InsertOp(SplatOp, SplatOp) to SplatOp.
2673class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
2674public:
2675 using OpRewritePattern::OpRewritePattern;
2676
2677 LogicalResult matchAndRewrite(InsertOp op,
2678 PatternRewriter &rewriter) const override {
2679 auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
2680 auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
2681
2682 if (!srcSplat || !dstSplat)
2683 return failure();
2684
2685 if (srcSplat.getInput() != dstSplat.getInput())
2686 return failure();
2687
2688 rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
2689 return success();
2690 }
2691};
2692
// Pattern to rewrite an InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
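// E.g. (illustrative):
//   %v = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
//   %c = arith.constant 7 : i32
//   %i = vector.insert %c, %v[2] : i32 into vector<4xi32>
// is rewritten to
//   %i = arith.constant dense<[1, 2, 7, 4]> : vector<4xi32>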
2694class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
2695public:
2696 using OpRewritePattern::OpRewritePattern;
2697
  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the source vector constant has a single use.
2700 static constexpr int64_t vectorSizeFoldThreshold = 256;
2701
2702 LogicalResult matchAndRewrite(InsertOp op,
2703 PatternRewriter &rewriter) const override {
2704 // TODO: Canonicalization for dynamic position not implemented yet.
2705 if (op.hasDynamicPosition())
2706 return failure();
2707
2708 // Return if 'InsertOp' operand is not defined by a compatible vector
2709 // ConstantOp.
2710 TypedValue<VectorType> destVector = op.getDest();
2711 Attribute vectorDestCst;
    if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
2713 return failure();
2714
2715 VectorType destTy = destVector.getType();
2716 if (destTy.isScalable())
2717 return failure();
2718
2719 // Make sure we do not create too many large constants.
2720 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
2721 !destVector.hasOneUse())
2722 return failure();
2723
    auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);

    Value sourceValue = op.getSource();
    Attribute sourceCst;
    if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
2729 return failure();
2730
2731 // Calculate the linearized position of the continuous chunk of elements to
2732 // insert.
2733 llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
2734 copy(op.getStaticPosition(), completePositions.begin());
2735 int64_t insertBeginPosition =
2736 linearize(completePositions, computeStrides(destTy.getShape()));
2737
2738 SmallVector<Attribute> insertedValues;
    if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst))
      llvm::append_range(insertedValues, denseSource.getValues<Attribute>());
    else
      insertedValues.push_back(sourceCst);
2743
2744 auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
2745 copy(insertedValues, allValues.begin() + insertBeginPosition);
2746 auto newAttr = DenseElementsAttr::get(destTy, allValues);
2747
2748 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
2749 return success();
2750 }
2751};
2752
2753} // namespace
2754
2755void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
2756 MLIRContext *context) {
2757 results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
2758 InsertOpConstantFolder>(context);
2759}
2760
2761// Eliminates insert operations that produce values identical to their source
2762// value. This happens when the source and destination vectors have identical
2763// sizes.
2764OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
2765 if (getNumIndices() == 0)
2766 return getSource();
2767 return {};
2768}
2769
2770//===----------------------------------------------------------------------===//
2771// InsertStridedSliceOp
2772//===----------------------------------------------------------------------===//
2773
2774void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
2775 Value source, Value dest,
2776 ArrayRef<int64_t> offsets,
2777 ArrayRef<int64_t> strides) {
2778 result.addOperands({source, dest});
2779 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
2780 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
2781 result.addTypes(dest.getType());
2782 result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
2783 offsetsAttr);
2784 result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
2785 stridesAttr);
2786}
2787
2788// TODO: Should be moved to Tablegen ConfinedAttr attributes.
2789template <typename OpType>
2790static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
2791 ArrayAttr arrayAttr,
2792 ArrayRef<int64_t> shape,
2793 StringRef attrName) {
2794 if (arrayAttr.size() > shape.size())
2795 return op.emitOpError("expected ")
2796 << attrName << " attribute of rank no greater than vector rank";
2797 return success();
2798}
2799
// Returns success if all integers in `arrayAttr` are confined to the interval
// [min, max). If `halfOpen` is false, the admissible interval is closed:
// [min, max].
2803template <typename OpType>
2804static LogicalResult
2805isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
2806 int64_t max, StringRef attrName,
2807 bool halfOpen = true) {
2808 for (auto attr : arrayAttr) {
2809 auto val = llvm::cast<IntegerAttr>(attr).getInt();
2810 auto upper = max;
2811 if (!halfOpen)
2812 upper += 1;
2813 if (val < min || val >= upper)
2814 return op.emitOpError("expected ") << attrName << " to be confined to ["
2815 << min << ", " << upper << ")";
2816 }
2817 return success();
2818}
2819
// Returns success if, for each index i, `arrayAttr[i]` is confined to the
// interval [min, shape[i]). If `halfOpen` is false, the admissible interval
// is closed: [min, shape[i]].
2823template <typename OpType>
2824static LogicalResult
2825isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
2826 ArrayRef<int64_t> shape, StringRef attrName,
2827 bool halfOpen = true, int64_t min = 0) {
2828 for (auto [index, attrDimPair] :
2829 llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
2830 int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
2831 int64_t max = std::get<1>(attrDimPair);
2832 if (!halfOpen)
2833 max += 1;
2834 if (val < min || val >= max)
2835 return op.emitOpError("expected ")
2836 << attrName << " dimension " << index << " to be confined to ["
2837 << min << ", " << max << ")";
2838 }
2839 return success();
2840}
2841
2842 // Returns success if, for all indices i = 0..shape.size()-1, the sum
2843 //   val = `arrayAttr1[i]` + `arrayAttr2[i]`
2844 // is in the admissible interval, where max is shape[i].
2845 // If `halfOpen` is true then the admissible interval is [min, max). Otherwise,
2846 // the admissible interval is [min, max].
2847template <typename OpType>
2848static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
2849 OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
2850 ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
2851 bool halfOpen = true, int64_t min = 1) {
2852 assert(arrayAttr1.size() <= shape.size());
2853 assert(arrayAttr2.size() <= shape.size());
2854 for (auto [index, it] :
2855 llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
2856 auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
2857 auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
2858 int64_t max = std::get<2>(it);
2859 if (!halfOpen)
2860 max += 1;
2861 if (val1 + val2 < 0 || val1 + val2 >= max)
2862 return op.emitOpError("expected sum(")
2863 << attrName1 << ", " << attrName2 << ") dimension " << index
2864 << " to be confined to [" << min << ", " << max << ")";
2865 }
2866 return success();
2867}
2868
2869static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
2870 MLIRContext *context) {
2871 auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
2872 return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
2873 });
2874 return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
2875}
2876
2877LogicalResult InsertStridedSliceOp::verify() {
2878 auto sourceVectorType = getSourceVectorType();
2879 auto destVectorType = getDestVectorType();
2880 auto offsets = getOffsetsAttr();
2881 auto strides = getStridesAttr();
2882 if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
2883 return emitOpError(
2884 "expected offsets of same size as destination vector rank");
2885 if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
2886 return emitOpError("expected strides of same size as source vector rank");
2887 if (sourceVectorType.getRank() > destVectorType.getRank())
2888 return emitOpError(
2889 "expected source rank to be no greater than destination rank");
2890
2891 auto sourceShape = sourceVectorType.getShape();
2892 auto destShape = destVectorType.getShape();
2893 SmallVector<int64_t, 4> sourceShapeAsDestShape(
2894 destShape.size() - sourceShape.size(), 0);
2895 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
2896 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
2897 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
2898 if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
2899 offName)) ||
2900 failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
2901 /*max=*/1, stridesName,
2902 /*halfOpen=*/false)) ||
2903 failed(isSumOfIntegerArrayAttrConfinedToShape(
2904 *this, offsets,
2905 makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
2906 offName, "source vector shape",
2907 /*halfOpen=*/false, /*min=*/1)))
2908 return failure();
2909
2910 unsigned rankDiff = destShape.size() - sourceShape.size();
2911 for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
2912 if (sourceVectorType.getScalableDims()[idx] !=
2913 destVectorType.getScalableDims()[idx + rankDiff]) {
2914 return emitOpError("mismatching scalable flags (at source vector idx=")
2915 << idx << ")";
2916 }
2917 if (sourceVectorType.getScalableDims()[idx]) {
2918 auto sourceSize = sourceShape[idx];
2919 auto destSize = destShape[idx + rankDiff];
2920 if (sourceSize != destSize) {
2921 return emitOpError("expected size at idx=")
2922 << idx
2923 << (" to match the corresponding base size from the input "
2924 "vector (")
2925 << sourceSize << (" vs ") << destSize << (")");
2926 }
2927 }
2928 }
2929
2930 return success();
2931}
2932
2933namespace {
2934/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
2935/// SplatOp(X):dst_type) to SplatOp(X):dst_type.
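///
/// Illustrative example (types and SSA names invented for exposition):
/// ```
///   %0 = vector.splat %x : vector<4xf32>
///   %1 = vector.splat %x : vector<8xf32>
///   %2 = vector.insert_strided_slice %0, %1 {offsets = [2], strides = [1]}
///          : vector<4xf32> into vector<8xf32>
/// ```
/// Every element of %2 is %x, so %2 can be replaced by %1 directly.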
2936class FoldInsertStridedSliceSplat final
2937 : public OpRewritePattern<InsertStridedSliceOp> {
2938public:
2939 using OpRewritePattern::OpRewritePattern;
2940
2941 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
2942 PatternRewriter &rewriter) const override {
2943 auto srcSplatOp =
2944 insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
2945 auto destSplatOp =
2946 insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
2947
2948 if (!srcSplatOp || !destSplatOp)
2949 return failure();
2950
2951 if (srcSplatOp.getInput() != destSplatOp.getInput())
2952 return failure();
2953
2954 rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
2955 return success();
2956 }
2957};
2958
2959/// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
2960/// to dst.
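///
/// Illustrative example (SSA names invented):
/// ```
///   %0 = vector.extract_strided_slice %dst {offsets = [2], sizes = [4], strides = [1]}
///          : vector<8xf32> to vector<4xf32>
///   %1 = vector.insert_strided_slice %0, %dst {offsets = [2], strides = [1]}
///          : vector<4xf32> into vector<8xf32>
/// ```
/// %1 writes back exactly the elements it read, so it can be replaced by %dst.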
2961class FoldInsertStridedSliceOfExtract final
2962 : public OpRewritePattern<InsertStridedSliceOp> {
2963public:
2964 using OpRewritePattern::OpRewritePattern;
2965
2966 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
2967 PatternRewriter &rewriter) const override {
2968 auto extractStridedSliceOp =
2969 insertStridedSliceOp.getSource()
2970 .getDefiningOp<vector::ExtractStridedSliceOp>();
2971
2972 if (!extractStridedSliceOp)
2973 return failure();
2974
2975 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
2976 return failure();
2977
2978 // Check if have the same strides and offsets.
2979 if (extractStridedSliceOp.getStrides() !=
2980 insertStridedSliceOp.getStrides() ||
2981 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
2982 return failure();
2983
2984 rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
2985 return success();
2986 }
2987};
2988
2989// Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
2990// ConstantOp.
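//
// Illustrative example (constants invented for exposition):
// ```
//   %dst = arith.constant dense<0.0> : vector<8xf32>
//   %src = arith.constant dense<1.0> : vector<4xf32>
//   %0 = vector.insert_strided_slice %src, %dst {offsets = [2], strides = [1]}
//          : vector<4xf32> into vector<8xf32>
// ```
// folds %0 to a single constant
//   dense<[0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]> : vector<8xf32>.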
2991class InsertStridedSliceConstantFolder final
2992 : public OpRewritePattern<InsertStridedSliceOp> {
2993public:
2994 using OpRewritePattern::OpRewritePattern;
2995
2996 // Do not create constants with more than `vectorSizeFoldThreshold` elements,
2997 // unless the destination vector constant has a single use.
2998 static constexpr int64_t vectorSizeFoldThreshold = 256;
2999
3000 LogicalResult matchAndRewrite(InsertStridedSliceOp op,
3001 PatternRewriter &rewriter) const override {
3002 // Return if the 'InsertStridedSliceOp' destination is not defined by a
3003 // compatible vector ConstantOp.
3004 TypedValue<VectorType> destVector = op.getDest();
3005 Attribute vectorDestCst;
3006 if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
3007 return failure();
3008
3009 VectorType destTy = destVector.getType();
3010 if (destTy.isScalable())
3011 return failure();
3012
3013 // Make sure we do not create too many large constants.
3014 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3015 !destVector.hasOneUse())
3016 return failure();
3017
3018 auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3019
3020 TypedValue<VectorType> sourceValue = op.getSource();
3021 Attribute sourceCst;
3022 if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
3023 return failure();
3024
3025 // TODO: Handle non-unit strides when they become available.
3026 if (op.hasNonUnitStrides())
3027 return failure();
3028
3029 VectorType sliceVecTy = sourceValue.getType();
3030 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3031 int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3032 SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
3033 SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
3034
3035 // Calculate the destination element indices by enumerating all slice
3036 // positions within the destination and linearizing them. The enumeration
3037 // order is lexicographic, which yields a sequence of monotonically
3038 // increasing linearized position indices.
3039 // Because the destination may have higher dimensionality than the slice,
3040 // we keep track of two overlapping sets of positions and offsets.
3041 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3042 auto sliceValuesIt = denseSlice.value_begin<Attribute>();
3043 auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
3044 SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
3045 MutableArrayRef<int64_t> currSlicePosition(
3046 currDestPosition.begin() + rankDifference, currDestPosition.end());
3047 ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
3048 offsets.end());
3049 do {
3050 int64_t linearizedPosition = linearize(currDestPosition, destStrides);
3051 assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
3052 assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
3053 "Invalid slice element");
3054 newValues[linearizedPosition] = *sliceValuesIt;
3055 ++sliceValuesIt;
3056 } while (succeeded(
3057 incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
3058
3059 auto newAttr = DenseElementsAttr::get(destTy, newValues);
3060 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
3061 return success();
3062 }
3063};
3064
3065} // namespace
3066
3067void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
3068 RewritePatternSet &results, MLIRContext *context) {
3069 results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
3070 InsertStridedSliceConstantFolder>(context);
3071}
3072
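// Folds an insert of a full-size vector: when the source and destination types
// match, the verifier above forces all offsets to zero, so the insert overwrites
// the whole destination and the result is simply the source. Illustrative
// example (SSA names invented):
// ```
//   %0 = vector.insert_strided_slice %v, %w {offsets = [0, 0], strides = [1, 1]}
//          : vector<2x4xf32> into vector<2x4xf32>
// ```
// folds to %v.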
3073OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
3074 if (getSourceVectorType() == getDestVectorType())
3075 return getSource();
3076 return {};
3077}
3078
3079//===----------------------------------------------------------------------===//
3080// OuterProductOp
3081//===----------------------------------------------------------------------===//
3082
3083 /// Build an op without a mask; use the type of `acc` as the return type.
3084void OuterProductOp::build(OpBuilder &builder, OperationState &result,
3085 Value lhs, Value rhs, Value acc) {
3086 result.addOperands({lhs, rhs, acc});
3087 result.addTypes(acc.getType());
3088}
3089
3090void OuterProductOp::print(OpAsmPrinter &p) {
3091 p << " " << getLhs() << ", " << getRhs();
3092 if (getAcc()) {
3093 p << ", " << getAcc();
3094 p.printOptionalAttrDict((*this)->getAttrs());
3095 }
3096 p << " : " << getLhs().getType() << ", " << getRhs().getType();
3097}
3098
3099ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
3100 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
3101 Type tLHS, tRHS;
3102 if (parser.parseOperandList(operandsInfo) ||
3103 parser.parseOptionalAttrDict(result.attributes) ||
3104 parser.parseColonType(tLHS) || parser.parseComma() ||
3105 parser.parseType(tRHS))
3106 return failure();
3107 if (operandsInfo.size() < 2)
3108 return parser.emitError(parser.getNameLoc(),
3109 "expected at least 2 operands");
3110 VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
3111 VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
3112 if (!vLHS)
3113 return parser.emitError(parser.getNameLoc(),
3114 "expected vector type for operand #1");
3115
3116 VectorType resType;
3117 if (vRHS) {
3118 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
3119 vRHS.getScalableDims()[0]};
3120 resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
3121 vLHS.getElementType(), scalableDimsRes);
3122 } else {
3123 // Scalar RHS operand
3124 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
3125 resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
3126 scalableDimsRes);
3127 }
3128
3129 if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
3130 result.attributes.append(
3131 OuterProductOp::getKindAttrName(result.name),
3132 CombiningKindAttr::get(result.getContext(),
3133 OuterProductOp::getDefaultKind()));
3134 }
3135
3136 return failure(
3137 parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
3138 parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
3139 (operandsInfo.size() > 2 &&
3140 parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
3141 parser.addTypeToList(resType, result.types));
3142}
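
// For reference, the parser above accepts forms such as (illustrative; SSA
// names invented):
// ```
//   %0 = vector.outerproduct %lhs, %rhs, %acc : vector<4xf32>, vector<8xf32>
//   %1 = vector.outerproduct %lhs, %s, %acc1 : vector<4xf32>, f32
// ```
// The first is a proper outer product yielding vector<4x8xf32>; the second,
// with a scalar RHS, is the AXPY form yielding vector<4xf32>.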
3143
3144LogicalResult OuterProductOp::verify() {
3145 Type tRHS = getOperandTypeRHS();
3146 VectorType vLHS = getOperandVectorTypeLHS(),
3147 vRHS = llvm::dyn_cast<VectorType>(tRHS),
3148 vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
3149
3150 if (vLHS.getRank() != 1)
3151 return emitOpError("expected 1-d vector for operand #1");
3152
3153 if (vRHS) {
3154 // Proper OUTER operation.
3155 if (vRHS.getRank() != 1)
3156 return emitOpError("expected 1-d vector for operand #2");
3157 if (vRES.getRank() != 2)
3158 return emitOpError("expected 2-d vector result");
3159 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3160 return emitOpError("expected #1 operand dim to match result dim #1");
3161 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
3162 return emitOpError("expected #2 operand dim to match result dim #2");
3163 if (vLHS.isScalable() && !vRHS.isScalable()) {
3164 // This restriction reflects what's currently supported in terms of
3165 // scalable vectors. However, we could relax this if there's a use case.
3166 return emitOpError(
3167 "expected either both or only #2 operand dim to be scalable");
3168 }
3169 } else {
3170 // An AXPY operation.
3171 if (vRES.getRank() != 1)
3172 return emitOpError("expected 1-d vector result");
3173 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
3174 return emitOpError("expected #1 operand dim to match result dim #1");
3175 }
3176
3177 if (vACC && vACC != vRES)
3178 return emitOpError("expected operand #3 of same type as result type");
3179
3180 // Verify supported combining kind.
3181 if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
3182 return emitOpError("unsupported outerproduct type");
3183
3184 return success();
3185}
3186
3187// MaskableOpInterface methods.
3188
3189/// Returns the mask type expected by this operation. Mostly used for
3190 /// verification purposes. It requires the operation to be vectorized.
3191Type OuterProductOp::getExpectedMaskType() {
3192 auto vecType = this->getResultVectorType();
3193 return VectorType::get(vecType.getShape(),
3194 IntegerType::get(vecType.getContext(), /*width=*/1),
3195 vecType.getScalableDims());
3196}
3197
3198//===----------------------------------------------------------------------===//
3199// ReshapeOp
3200//===----------------------------------------------------------------------===//
3201
3202LogicalResult ReshapeOp::verify() {
3203 // Verify that the number of input/output shape operands plus the number of fixed vector sizes matches the corresponding vector rank.
3204 auto inputVectorType = getInputVectorType();
3205 auto outputVectorType = getOutputVectorType();
3206 int64_t inputShapeRank = getNumInputShapeSizes();
3207 int64_t outputShapeRank = getNumOutputShapeSizes();
3208 SmallVector<int64_t, 4> fixedVectorSizes;
3209 getFixedVectorSizes(fixedVectorSizes);
3210 int64_t numFixedVectorSizes = fixedVectorSizes.size();
3211
3212 if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
3213 return emitError("invalid input shape for vector type ") << inputVectorType;
3214
3215 if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
3216 return emitError("invalid output shape for vector type ")
3217 << outputVectorType;
3218
3219 // Verify that the 'fixedVectorSizes' match an input/output vector shape
3220 // suffix.
3221 unsigned inputVectorRank = inputVectorType.getRank();
3222 for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
3223 unsigned index = inputVectorRank - numFixedVectorSizes - i;
3224 if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
3225 return emitError("fixed vector size must match input vector for dim ")
3226 << i;
3227 }
3228
3229 unsigned outputVectorRank = outputVectorType.getRank();
3230 for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
3231 unsigned index = outputVectorRank - numFixedVectorSizes - i;
3232 if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
3233 return emitError("fixed vector size must match output vector for dim ")
3234 << i;
3235 }
3236
3237 // If all shape operands are produced by constant ops, verify that product
3238 // of dimensions for input/output shape match.
3239 auto isDefByConstant = [](Value operand) {
3240 return getConstantIntValue(operand).has_value();
3241 };
3242 if (llvm::all_of(getInputShape(), isDefByConstant) &&
3243 llvm::all_of(getOutputShape(), isDefByConstant)) {
3244 int64_t numInputElements = 1;
3245 for (auto operand : getInputShape())
3246 numInputElements *= getConstantIntValue(operand).value();
3247 int64_t numOutputElements = 1;
3248 for (auto operand : getOutputShape())
3249 numOutputElements *= getConstantIntValue(operand).value();
3250 if (numInputElements != numOutputElements)
3251 return emitError("product of input and output shape sizes must match");
3252 }
3253 return success();
3254}
3255
3256void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
3257 populateFromInt64AttrArray(getFixedVectorSizes(), results);
3258}
3259
3260//===----------------------------------------------------------------------===//
3261// ExtractStridedSliceOp
3262//===----------------------------------------------------------------------===//
3263
3264 // Inference works as follows:
3265 //   1. For the leading dims covered by 'offsets', use the corresponding 'sizes'.
3266 //   2. For the remaining dims, carry over the sizes from 'vectorType'.
3267// Scalable flags are inherited from 'vectorType'.
3268static Type inferStridedSliceOpResultType(VectorType vectorType,
3269 ArrayAttr offsets, ArrayAttr sizes,
3270 ArrayAttr strides) {
3271 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
3272 SmallVector<int64_t, 4> shape;
3273 shape.reserve(vectorType.getRank());
3274 unsigned idx = 0;
3275 for (unsigned e = offsets.size(); idx < e; ++idx)
3276 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
3277 for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
3278 shape.push_back(vectorType.getShape()[idx]);
3279
3280 return VectorType::get(shape, vectorType.getElementType(),
3281 vectorType.getScalableDims());
3282}
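
// For example (illustrative): extracting from vector<4x8x16xf32> with
// offsets = [2, 0], sizes = [2, 4], strides = [1, 1] infers the result type
// vector<2x4x16xf32>: the two leading dims come from 'sizes' and the trailing
// dim is carried over from the source vector type.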
3283
3284void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3285 Value source, ArrayRef<int64_t> offsets,
3286 ArrayRef<int64_t> sizes,
3287 ArrayRef<int64_t> strides) {
3288 result.addOperands(source);
3289 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3290 auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
3291 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3292 result.addTypes(
3293 inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
3294 offsetsAttr, sizesAttr, stridesAttr));
3295 result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
3296 offsetsAttr);
3297 result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
3298 sizesAttr);
3299 result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
3300 stridesAttr);
3301}
3302
3303LogicalResult ExtractStridedSliceOp::verify() {
3304 auto type = getSourceVectorType();
3305 auto offsets = getOffsetsAttr();
3306 auto sizes = getSizesAttr();
3307 auto strides = getStridesAttr();
3308 if (offsets.size() != sizes.size() || offsets.size() != strides.size())
3309 return emitOpError(
3310 "expected offsets, sizes and strides attributes of same size");
3311
3312 auto shape = type.getShape();
3313 auto offName = getOffsetsAttrName();
3314 auto sizesName = getSizesAttrName();
3315 auto stridesName = getStridesAttrName();
3316 if (failed(
3317 isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
3318 failed(
3319 isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
3320 failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
3321 stridesName)) ||
3322 failed(
3323 isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
3324 failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
3325 /*halfOpen=*/false,
3326 /*min=*/1)) ||
3327 failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3328 /*max=*/1, stridesName,
3329 /*halfOpen=*/false)) ||
3330 failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
3331 shape, offName, sizesName,
3332 /*halfOpen=*/false)))
3333 return failure();
3334
3335 auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
3336 offsets, sizes, strides);
3337 if (getResult().getType() != resultType)
3338 return emitOpError("expected result type to be ") << resultType;
3339
3340 for (unsigned idx = 0; idx < sizes.size(); ++idx) {
3341 if (type.getScalableDims()[idx]) {
3342 auto inputDim = type.getShape()[idx];
3343 auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
3344 if (inputDim != inputSize)
3345 return emitOpError("expected size at idx=")
3346 << idx
3347 << (" to match the corresponding base size from the input "
3348 "vector (")
3349 << inputSize << (" vs ") << inputDim << (")");
3350 }
3351 }
3352
3353 return success();
3354}
3355
3356 // When the source of an ExtractStridedSliceOp comes from a chain of
3357 // InsertStridedSliceOps, try to use the source of one of those inserts if we
3358 // can detect that the extracted vector is a subset of one of the inserted vectors.
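//
// Illustrative example (SSA names invented):
// ```
//   %0 = vector.insert_strided_slice %small, %big {offsets = [2, 4], strides = [1, 1]}
//          : vector<2x4xf32> into vector<8x16xf32>
//   %1 = vector.extract_strided_slice %0 {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]}
//          : vector<8x16xf32> to vector<2x4xf32>
// ```
// %1 reads exactly the inserted chunk, so the extract is rewritten to read from
// %small with its offsets rebased to [0, 0].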
3359static LogicalResult
3360foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
3361 // Helper to extract integer out of ArrayAttr.
3362 auto getElement = [](ArrayAttr array, int idx) {
3363 return llvm::cast<IntegerAttr>(array[idx]).getInt();
3364 };
3365 ArrayAttr extractOffsets = op.getOffsets();
3366 ArrayAttr extractStrides = op.getStrides();
3367 ArrayAttr extractSizes = op.getSizes();
3368 auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
3369 while (insertOp) {
3370 if (op.getSourceVectorType().getRank() !=
3371 insertOp.getSourceVectorType().getRank())
3372 return failure();
3373 ArrayAttr insertOffsets = insertOp.getOffsets();
3374 ArrayAttr insertStrides = insertOp.getStrides();
3375 // If the rank of extract is greater than the rank of insert, we are likely
3376 // extracting a partial chunk of the vector inserted.
3377 if (extractOffsets.size() > insertOffsets.size())
3378 return failure();
3379 bool partialOverlap = false;
3380 bool disjoint = false;
3381 SmallVector<int64_t, 4> offsetDiffs;
3382 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
3383 if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
3384 return failure();
3385 int64_t start = getElement(insertOffsets, dim);
3386 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
3387 int64_t offset = getElement(extractOffsets, dim);
3388 int64_t size = getElement(extractSizes, dim);
3389 // Check if the start of the extract offset is in the interval inserted.
3390 if (start <= offset && offset < end) {
3391 // If the extract interval overlaps but is not fully included we may
3392 // have a partial overlap that will prevent any folding.
3393 if (offset + size > end)
3394 partialOverlap = true;
3395 offsetDiffs.push_back(offset - start);
3396 continue;
3397 }
3398 disjoint = true;
3399 break;
3400 }
3401 // The extracted chunk is a subset of the inserted chunk.
3402 if (!disjoint && !partialOverlap) {
3403 op.setOperand(insertOp.getSource());
3404 // OpBuilder is only used as a helper to build an I64ArrayAttr.
3405 OpBuilder b(op.getContext());
3406 op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
3407 return success();
3408 }
3409 // If the chunk extracted is disjoint from the chunk inserted, keep looking
3410 // in the insert chain.
3411 if (disjoint)
3412 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
3413 else {
3414 // The extracted vector partially overlaps the inserted vector; we cannot
3415 // fold.
3416 return failure();
3417 }
3418 }
3419 return failure();
3420}
3421
3422OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
3423 if (getSourceVectorType() == getResult().getType())
3424 return getVector();
3425 if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
3426 return getResult();
3427 return {};
3428}
3429
3430void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
3431 populateFromInt64AttrArray(getOffsets(), results);
3432}
3433
3434namespace {
3435
3436// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
3437// ConstantMaskOp.
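//
// Illustrative example (values invented):
// ```
//   %mask = vector.constant_mask [4] : vector<8xi1>
//   %0 = vector.extract_strided_slice %mask {offsets = [2], sizes = [4], strides = [1]}
//          : vector<8xi1> to vector<4xi1>
// ```
// rewrites to `vector.constant_mask [2] : vector<4xi1>`, since only the first
// two of the four extracted lanes fall inside the original masked region.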
3438class StridedSliceConstantMaskFolder final
3439 : public OpRewritePattern<ExtractStridedSliceOp> {
3440public:
3441 using OpRewritePattern::OpRewritePattern;
3442
3443 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3444 PatternRewriter &rewriter) const override {
3445 // Return if 'extractStridedSliceOp' operand is not defined by a
3446 // ConstantMaskOp.
3447 auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
3448 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
3449 if (!constantMaskOp)
3450 return failure();
3451 // Return if 'extractStridedSliceOp' has non-unit strides.
3452 if (extractStridedSliceOp.hasNonUnitStrides())
3453 return failure();
3454 // Gather constant mask dimension sizes.
3455 SmallVector<int64_t, 4> maskDimSizes;
3456 populateFromInt64AttrArray(constantMaskOp.getMaskDimSizes(), maskDimSizes);
3457 // Gather strided slice offsets and sizes.
3458 SmallVector<int64_t, 4> sliceOffsets;
3459 populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
3460 sliceOffsets);
3461 SmallVector<int64_t, 4> sliceSizes;
3462 populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
3463
3464 // Compute slice of vector mask region.
3465 SmallVector<int64_t, 4> sliceMaskDimSizes;
3466 sliceMaskDimSizes.reserve(maskDimSizes.size());
3467 for (auto [maskDimSize, sliceOffset, sliceSize] :
3468 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
3469 int64_t sliceMaskDimSize = std::max(
3470 static_cast<int64_t>(0),
3471 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
3472 sliceMaskDimSizes.push_back(sliceMaskDimSize);
3473 }
3474 // Add unchanged dimensions.
3475 if (sliceMaskDimSizes.size() < maskDimSizes.size())
3476 for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
3477 sliceMaskDimSizes.push_back(maskDimSizes[i]);
3478 // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
3479 // region is a conjunction of mask dim intervals).
3480 if (llvm::is_contained(sliceMaskDimSizes, 0))
3481 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
3482
3483 // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
3484 // region.
3485 rewriter.replaceOpWithNewOp<ConstantMaskOp>(
3486 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
3487 vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
3488 return success();
3489 }
3490};
3491
3492 // Pattern to rewrite an ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
3493class StridedSliceSplatConstantFolder final
3494 : public OpRewritePattern<ExtractStridedSliceOp> {
3495public:
3496 using OpRewritePattern::OpRewritePattern;
3497
3498 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3499 PatternRewriter &rewriter) const override {
3500 // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
3501 // ConstantOp.
3502 Value sourceVector = extractStridedSliceOp.getVector();
3503 Attribute vectorCst;
3504 if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3505 return failure();
3506
3507 auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
3508 if (!splat)
3509 return failure();
3510
3511 auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
3512 splat.getSplatValue<Attribute>());
3513 rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3514 newAttr);
3515 return success();
3516 }
3517};
3518
3519 // Pattern to rewrite an ExtractStridedSliceOp(non-splat ConstantOp) ->
3520// ConstantOp.
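//
// Illustrative example (constant invented):
// ```
//   %cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xi32>
//   %0 = vector.extract_strided_slice %cst {offsets = [2], sizes = [3], strides = [1]}
//          : vector<8xi32> to vector<3xi32>
// ```
// folds %0 to `arith.constant dense<[2, 3, 4]> : vector<3xi32>`.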
3521class StridedSliceNonSplatConstantFolder final
3522 : public OpRewritePattern<ExtractStridedSliceOp> {
3523public:
3524 using OpRewritePattern::OpRewritePattern;
3525
3526 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3527 PatternRewriter &rewriter) const override {
3528 // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
3529 // ConstantOp.
3530 Value sourceVector = extractStridedSliceOp.getVector();
3531 Attribute vectorCst;
3532 if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3533 return failure();
3534
3535 // The splat case is handled by `StridedSliceSplatConstantFolder`.
3536 auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
3537 if (!dense || dense.isSplat())
3538 return failure();
3539
3540 // TODO: Handle non-unit strides when they become available.
3541 if (extractStridedSliceOp.hasNonUnitStrides())
3542 return failure();
3543
3544 auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
3545 ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
3546 SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
3547
3548 VectorType sliceVecTy = extractStridedSliceOp.getType();
3549 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3550 int64_t sliceRank = sliceVecTy.getRank();
3551
3552 // Expand offsets and sizes to match the vector rank.
3553 SmallVector<int64_t, 4> offsets(sliceRank, 0);
3554 copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
3555
3556 SmallVector<int64_t, 4> sizes(sourceShape.begin(), sourceShape.end());
3557 copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
3558
3559 // Calculate the slice elements by enumerating all slice positions and
3560 // linearizing them. The enumeration order is lexicographic which yields a
3561 // sequence of monotonically increasing linearized position indices.
3562 auto denseValuesBegin = dense.value_begin<Attribute>();
3563 SmallVector<Attribute> sliceValues;
3564 sliceValues.reserve(sliceVecTy.getNumElements());
3565 SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
3566 do {
3567 int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
3568 assert(linearizedPosition < sourceVecTy.getNumElements() &&
3569 "Invalid index");
3570 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3571 } while (
3572 succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
3573
3574 assert(static_cast<int64_t>(sliceValues.size()) ==
3575 sliceVecTy.getNumElements() &&
3576 "Invalid number of slice elements");
3577 auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
3578 rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3579 newAttr);
3580 return success();
3581 }
3582};
3583
3584// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
3585 // BroadcastOp(ExtractStridedSliceOp).
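//
// Illustrative example (SSA names invented):
// ```
//   %b = vector.broadcast %src : vector<4xf32> to vector<2x4xf32>
//   %0 = vector.extract_strided_slice %b {offsets = [0, 1], sizes = [1, 2], strides = [1, 1]}
//          : vector<2x4xf32> to vector<1x2xf32>
// ```
// becomes an extract of the needed elements from %src followed by a broadcast:
// ```
//   %e = vector.extract_strided_slice %src {offsets = [1], sizes = [2], strides = [1]}
//          : vector<4xf32> to vector<2xf32>
//   %0 = vector.broadcast %e : vector<2xf32> to vector<1x2xf32>
// ```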
3586class StridedSliceBroadcast final
3587 : public OpRewritePattern<ExtractStridedSliceOp> {
3588public:
3589 using OpRewritePattern::OpRewritePattern;
3590
3591 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3592 PatternRewriter &rewriter) const override {
3593 auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
3594 if (!broadcast)
3595 return failure();
3596 auto srcVecType =
3597 llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
3598 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
3599 auto dstVecType = llvm::cast<VectorType>(op.getType());
3600 unsigned dstRank = dstVecType.getRank();
3601 unsigned rankDiff = dstRank - srcRank;
3602 // Check if the innermost dimensions of the broadcast's source match the
3603 // corresponding dimensions of the extract's result. If so, we can just use a
3604 // broadcast, as the original dimensions are untouched.
3605 bool lowerDimMatch = true;
3606 for (unsigned i = 0; i < srcRank; i++) {
3607 if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
3608 lowerDimMatch = false;
3609 break;
3610 }
3611 }
3612 Value source = broadcast.getSource();
3613 // If the inner dimensions don't match, we need to extract from the source of
3614 // the original broadcast and then broadcast the extracted value.
3615 // We also need to handle degenerate cases where the source is effectively
3616 // just a single scalar.
3617 bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
3618 if (!lowerDimMatch && !isScalarSrc) {
3619 source = rewriter.create<ExtractStridedSliceOp>(
3620 op->getLoc(), source,
3621 getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
3622 getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
3623 getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
3624 }
3625 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
3626 return success();
3627 }
3628};
3629
3630/// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
3631class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
3632public:
3633 using OpRewritePattern::OpRewritePattern;
3634
3635 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3636 PatternRewriter &rewriter) const override {
3637 auto splat = op.getVector().getDefiningOp<SplatOp>();
3638 if (!splat)
3639 return failure();
3640 rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
3641 return success();
3642 }
3643};
3644
3645} // namespace
3646
3647void ExtractStridedSliceOp::getCanonicalizationPatterns(
3648 RewritePatternSet &results, MLIRContext *context) {
3649 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) ->
3650 // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
3651 results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
3652 StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
3653 StridedSliceSplat>(context);
3654}
3655
3656//===----------------------------------------------------------------------===//
3657// TransferReadOp
3658//===----------------------------------------------------------------------===//
3659
3660/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
3661void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3662 VectorType vectorType, Value source,
3663 ValueRange indices, AffineMapAttr permutationMapAttr,
3664 /*optional*/ ArrayAttr inBoundsAttr) {
3665 Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
3666 Value padding = builder.create<arith::ConstantOp>(
3667 result.location, elemType, builder.getZeroAttr(elemType));
3668 build(builder, result, vectorType, source, indices, permutationMapAttr,
3669 padding, /*mask=*/Value(), inBoundsAttr);
3670}
3671
3672 /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
3673void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3674 VectorType vectorType, Value source,
3675 ValueRange indices, AffineMap permutationMap,
3676 std::optional<ArrayRef<bool>> inBounds) {
3677 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
3678 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3679 ? builder.getBoolArrayAttr(inBounds.value())
3680 : ArrayAttr();
3681 build(builder, result, vectorType, source, indices, permutationMapAttr,
3682 inBoundsAttr);
3683}
3684
3685/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
3686void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3687 VectorType vectorType, Value source,
3688 ValueRange indices, Value padding,
3689 std::optional<ArrayRef<bool>> inBounds) {
3690 AffineMap permutationMap = getTransferMinorIdentityMap(
3691 llvm::cast<ShapedType>(source.getType()), vectorType);
3692 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
3693 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3694 ? builder.getBoolArrayAttr(inBounds.value())
3695 : ArrayAttr();
3696 build(builder, result, vectorType, source, indices, permutationMapAttr,
3697 padding,
3698 /*mask=*/Value(), inBoundsAttr);
3699}
3700
3701/// 4. Builder that sets padding to zero and permutation map to
3702/// 'getMinorIdentityMap'.
3703void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3704 VectorType vectorType, Value source,
3705 ValueRange indices,
3706 std::optional<ArrayRef<bool>> inBounds) {
3707 Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
3708 Value padding = builder.create<arith::ConstantOp>(
3709 result.location, elemType, builder.getZeroAttr(elemType));
3710 build(builder, result, vectorType, source, indices, padding, inBounds);
3711}
3712
3713template <typename EmitFun>
3714static LogicalResult verifyPermutationMap(AffineMap permutationMap,
3715 EmitFun emitOpError) {
3716 SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
3717 for (auto expr : permutationMap.getResults()) {
3718 auto dim = dyn_cast<AffineDimExpr>(expr);
3719 auto zero = dyn_cast<AffineConstantExpr>(expr);
3720 if (zero) {
3721 if (zero.getValue() != 0) {
3722 return emitOpError(
3723 "requires a projected permutation_map (at most one dim or the zero "
3724 "constant can appear in each result)");
3725 }
3726 continue;
3727 }
3728 if (!dim) {
3729 return emitOpError("requires a projected permutation_map (at most one "
3730 "dim or the zero constant can appear in each result)");
3731 }
3732 if (seen[dim.getPosition()]) {
3733 return emitOpError(
3734 "requires a permutation_map that is a permutation (found one dim "
3735 "used more than once)");
3736 }
3737 seen[dim.getPosition()] = true;
3738 }
3739 return success();
3740}
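
// For example (illustrative): affine_map<(d0, d1, d2) -> (d2, 0)> is accepted
// (each result is a single dim or the zero constant, and no dim repeats), while
// affine_map<(d0, d1) -> (d0, d0)> is rejected because d0 is used twice.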
3741
3742static LogicalResult
3743verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
3744 VectorType vectorType, VectorType maskType,
3745 VectorType inferredMaskType, AffineMap permutationMap,
3746 ArrayAttr inBounds) {
3747 if (op->hasAttr("masked")) {
3748 return op->emitOpError("masked attribute has been removed. "
3749 "Use in_bounds instead.");
3750 }
3751
3752 if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
3753 return op->emitOpError(
3754 "requires source to be a memref or ranked tensor type");
3755
3756 auto elementType = shapedType.getElementType();
3757 DataLayout dataLayout = DataLayout::closest(op);
3758 if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
3759 // Memref or tensor has vector element type.
3760 unsigned sourceVecSize =
3761 dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
3762 vectorElementType.getShape().back();
3763 unsigned resultVecSize =
3764 dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
3765 vectorType.getShape().back();
3766 if (resultVecSize % sourceVecSize != 0)
3767 return op->emitOpError(
3768 "requires the bitwidth of the minor 1-D vector to be an integral "
3769 "multiple of the bitwidth of the minor 1-D vector of the source");
3770
3771 unsigned sourceVecEltRank = vectorElementType.getRank();
3772 unsigned resultVecRank = vectorType.getRank();
3773 if (sourceVecEltRank > resultVecRank)
3774 return op->emitOpError(
3775 "requires source vector element and vector result ranks to match.");
3776 unsigned rankOffset = resultVecRank - sourceVecEltRank;
3777 // Check that permutation map results match 'rankOffset' of vector type.
3778 if (permutationMap.getNumResults() != rankOffset)
3779 return op->emitOpError("requires a permutation_map with result dims of "
3780 "the same rank as the vector type");
3781
3782 if (maskType)
3783 return op->emitOpError("does not support masks with vector element type");
3784 } else {
3785 // Memref or tensor has scalar element type.
3786 unsigned minorSize =
3787 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
3788 unsigned resultVecSize =
3789 dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
3790 if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
3791 return op->emitOpError(
3792 "requires the bitwidth of the minor 1-D vector to be an integral "
3793 "multiple of the bitwidth of the source element type");
3794
3795 // Check that permutation map results match rank of vector type.
3796 if (permutationMap.getNumResults() != vectorType.getRank())
3797 return op->emitOpError("requires a permutation_map with result dims of "
3798 "the same rank as the vector type");
3799 }
3800
3801 if (permutationMap.getNumSymbols() != 0)
3802 return op->emitOpError("requires permutation_map without symbols");
3803
3804 if (permutationMap.getNumInputs() != shapedType.getRank())
3805 return op->emitOpError("requires a permutation_map with input dims of the "
3806 "same rank as the source type");
3807
3808 if (maskType && maskType != inferredMaskType)
3809 return op->emitOpError("inferred mask type (")
3810 << inferredMaskType << ") and mask operand type (" << maskType
3811 << ") don't match";
3812
3813 if (inBounds) {
3814 if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
3815 return op->emitOpError("expects the optional in_bounds attr of same rank "
3816 "as permutation_map results: ")
3817 << AffineMapAttr::get(permutationMap)
3818 << " vs inBounds of size: " << inBounds.size();
3819 for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i)
3820 if (isa<AffineConstantExpr>(permutationMap.getResult(i)) &&
3821 !llvm::cast<BoolAttr>(inBounds.getValue()[i]).getValue())
3822 return op->emitOpError("requires broadcast dimensions to be in-bounds");
3823 }
3824
3825 return success();
3826}
3827
3828static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
3829 SmallVector<StringRef, 3> elidedAttrs;
3830 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
3831 if (op.getPermutationMap().isMinorIdentity())
3832 elidedAttrs.push_back(op.getPermutationMapAttrName());
3833 // Elide in_bounds attribute if all dims are out-of-bounds.
3834 if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
3835 elidedAttrs.push_back(op.getInBoundsAttrName());
3836 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
3837}
3838
3839void TransferReadOp::print(OpAsmPrinter &p) {
3840 p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
3841 if (getMask())
3842 p << ", " << getMask();
3843 printTransferAttrs(p, *this);
3844 p << " : " << getShapedType() << ", " << getVectorType();
3845}
3846
3847VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
3848 AffineMap permMap) {
3849 auto i1Type = IntegerType::get(permMap.getContext(), 1);
3850 AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
3851 assert(invPermMap && "Inverse permutation map couldn't be computed");
3852 SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
3853
3854 SmallVector<bool> scalableDims =
3855 applyPermutationMap(invPermMap, vecType.getScalableDims());
3856
3857 return VectorType::get(maskShape, i1Type, scalableDims);
3858}
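
// For example (illustrative): a transfer with permutation_map
// affine_map<(d0, d1) -> (d1, d0)> and vector type vector<2x4xf32> gets an
// inferred mask type of vector<4x2xi1>, i.e. the mask is shaped like the
// accessed memory region rather than like the result vector.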
3859
3860ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
3861 auto &builder = parser.getBuilder();
3862 SMLoc typesLoc;
3863 OpAsmParser::UnresolvedOperand sourceInfo;
3864 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
3865 OpAsmParser::UnresolvedOperand paddingInfo;
3866 SmallVector<Type, 2> types;
3867 OpAsmParser::UnresolvedOperand maskInfo;
3868 // Parsing with support for paddingValue.
3869 if (parser.parseOperand(sourceInfo) ||
3870 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
3871 parser.parseComma() || parser.parseOperand(paddingInfo))
3872 return failure();
3873 ParseResult hasMask = parser.parseOptionalComma();
3874 if (hasMask.succeeded()) {
3875 if (parser.parseOperand(maskInfo))
3876 return failure();
3877 }
3878 if (parser.parseOptionalAttrDict(result.attributes) ||
3879 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
3880 return failure();
3881 if (types.size() != 2)
3882 return parser.emitError(typesLoc, "requires two types");
3883 auto indexType = builder.getIndexType();
3884 auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
3885 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
3886 return parser.emitError(typesLoc, "requires memref or ranked tensor type");
3887 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
3888 if (!vectorType)
3889 return parser.emitError(typesLoc, "requires vector type");
3890 auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
3891 Attribute permMapAttr = result.attributes.get(permMapAttrName);
3892 AffineMap permMap;
3893 if (!permMapAttr) {
3894 permMap = getTransferMinorIdentityMap(shapedType, vectorType);
3895 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
3896 } else {
3897 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
3898 }
3899 if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
3900 parser.resolveOperands(indexInfo, indexType, result.operands) ||
3901 parser.resolveOperand(paddingInfo, shapedType.getElementType(),
3902 result.operands))
3903 return failure();
3904 if (hasMask.succeeded()) {
3905 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
3906 return parser.emitError(
3907 maskInfo.location, "does not support masks with vector element type");
3908 if (vectorType.getRank() != permMap.getNumResults()) {
3909 return parser.emitError(typesLoc,
3910 "expected the same rank for the vector and the "
3911 "results of the permutation map");
3912 }
3913 // Instead of adding the mask type as an op type, compute it based on the
3914 // vector type and the permutation map (to keep the type signature small).
3915 auto maskType = inferTransferOpMaskType(vectorType, permMap);
3916 if (parser.resolveOperand(maskInfo, maskType, result.operands))
3917 return failure();
3918 }
3919 result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
3920 builder.getDenseI32ArrayAttr(
3921 {1, static_cast<int32_t>(indexInfo.size()), 1,
3922 static_cast<int32_t>(hasMask.succeeded())}));
3923 return parser.addTypeToList(vectorType, result.types);
3924}
3925
3926LogicalResult TransferReadOp::verify() {
3927 // Consistency of elemental types in source and vector.
3928 ShapedType shapedType = getShapedType();
3929 VectorType vectorType = getVectorType();
3930 VectorType maskType = getMaskType();
3931 auto paddingType = getPadding().getType();
3932 auto permutationMap = getPermutationMap();
3933 VectorType inferredMaskType =
3934 maskType ? inferTransferOpMaskType(vectorType, permutationMap)
3935 : VectorType();
3936 auto sourceElementType = shapedType.getElementType();
3937
3938 if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
3939 return emitOpError("requires ") << shapedType.getRank() << " indices";
3940
3941 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
3942 shapedType, vectorType, maskType,
3943 inferredMaskType, permutationMap,
3944 getInBounds() ? *getInBounds() : ArrayAttr())))
3945 return failure();
3946
3947 if (auto sourceVectorElementType =
3948 llvm::dyn_cast<VectorType>(sourceElementType)) {
3949 // Source has vector element type.
3950 // Check that 'sourceVectorElementType' and 'paddingType' types match.
3951 if (sourceVectorElementType != paddingType)
3952 return emitOpError(
3953 "requires source element type and padding type to match.");
3954
3955 } else {
3956 // Check that 'paddingType' is valid to store in a vector type.
3957 if (!VectorType::isValidElementType(paddingType))
3958 return emitOpError("requires valid padding vector elemental type");
3959
3960 // Check that padding type and vector element types match.
3961 if (paddingType != sourceElementType)
3962 return emitOpError(
3963 "requires formal padding and source of the same elemental type");
3964 }
3965
3966 return verifyPermutationMap(permutationMap,
3967 [&](Twine t) { return emitOpError(t); });
3968}
3969
3970// MaskableOpInterface methods.
3971
3972/// Returns the mask type expected by this operation. Mostly used for
3973 /// verification purposes. It requires the operation to be vectorized.
3974Type TransferReadOp::getExpectedMaskType() {
3975 return inferTransferOpMaskType(getVectorType(), getPermutationMap());
3976}
3977
3978template <typename TransferOp>
3979static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
3980 // TODO: support more aggressive createOrFold on:
3981 // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
3982 if (op.getShapedType().isDynamicDim(indicesIdx))
3983 return false;
3984 Value index = op.getIndices()[indicesIdx];
3985 std::optional<int64_t> cstOp = getConstantIntValue(index);
3986 if (!cstOp.has_value())
3987 return false;
3988
3989 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
3990 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
3991
3992 return cstOp.value() + vectorSize <= sourceSize;
3993}
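
// For example (illustrative): reading a vector<4xf32> at constant index %c2
// from a memref dimension of static size 8 is provably in-bounds (2 + 4 <= 8),
// whereas the same read at %c5 is not (5 + 4 > 8).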
3994
3995template <typename TransferOp>
3996static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
3997 // TODO: support 0-d corner case.
3998 // TODO: Be less conservative.
3999 if (op.getTransferRank() == 0)
4000 return failure();
4001 AffineMap permutationMap = op.getPermutationMap();
4002 bool changed = false;
4003 SmallVector<bool, 4> newInBounds;
4004 newInBounds.reserve(op.getTransferRank());
4005 for (unsigned i = 0; i < op.getTransferRank(); ++i) {
4006 // Already marked as in-bounds, nothing to see here.
4007 if (op.isDimInBounds(i)) {
4008 newInBounds.push_back(true);
4009 continue;
4010 }
4011 // Currently out-of-bounds, check whether we can statically determine it is
4012 // inBounds.
4013 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
4014 assert(dimExpr && "Broadcast dims must be in-bounds");
4015 auto inBounds =
4016 isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition());
4017 newInBounds.push_back(inBounds);
4018 // We commit the pattern if it is "more inbounds".
4019 changed |= inBounds;
4020 }
4021 if (!changed)
4022 return failure();
4023 // OpBuilder is only used as a helper to build an I64ArrayAttr.
4024 OpBuilder b(op.getContext());
4025 op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
4026 return success();
4027}
4028
4029template <typename TransferOp>
4030static LogicalResult foldTransferFullMask(TransferOp op) {
4031 auto mask = op.getMask();
4032 if (!mask)
4033 return failure();
4034
4035 auto constantMask = mask.template getDefiningOp<vector::ConstantMaskOp>();
4036 if (!constantMask)
4037 return failure();
4038
4039 if (!constantMask.isAllOnesMask())
4040 return failure();
4041
4042 op.getMaskMutable().clear();
4043 return success();
4044}
4045
4046/// ```
4047/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4048/// : vector<1x4xf32>, tensor<4x4xf32>
4049/// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
4050/// : tensor<4x4xf32>, vector<1x4xf32>
4051/// ```
4052/// -> Folds into
4053/// ```
4054/// %v0
4055/// ```
4056static Value foldRAW(TransferReadOp readOp) {
4057 if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
4058 return {};
4059 auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4060 while (defWrite) {
4061 if (checkSameValueRAW(defWrite, readOp))
4062 return defWrite.getVector();
4063 if (!isDisjointTransferIndices(
4064 cast<VectorTransferOpInterface>(defWrite.getOperation()),
4065 cast<VectorTransferOpInterface>(readOp.getOperation())))
4066 break;
4067 defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4068 }
4069 return {};
4070}
4071
4072OpFoldResult TransferReadOp::fold(FoldAdaptor) {
4073 if (Value vec = foldRAW(*this))
4074 return vec;
4075 /// transfer_read(memrefcast) -> transfer_read
4076 if (succeeded(foldTransferInBoundsAttribute(*this)))
4077 return getResult();
4078 if (succeeded(foldTransferFullMask(*this)))
4079 return getResult();
4080 if (succeeded(memref::foldMemRefCast(*this)))
4081 return getResult();
4082 if (succeeded(tensor::foldTensorCast(*this)))
4083 return getResult();
4084 return OpFoldResult();
4085}
4086
4087std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
4088 return llvm::to_vector<4>(getVectorType().getShape());
4089}
4090
4091void TransferReadOp::getEffects(
4092 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4093 &effects) {
4094 if (llvm::isa<MemRefType>(getShapedType()))
4095 effects.emplace_back(MemoryEffects::Read::get(), getSource(),
4096 SideEffects::DefaultResource::get());
4097}
4098
4099namespace {
4100 /// Store to load forwarding for transfer operations with permutation maps.
4101/// Even if the permutation maps are different we can still propagate the store
4102/// into the load if the size of the dimensions read and written match. Then we
4103/// can replace the transfer_read + transfer_write by vector.broadcast and
4104/// vector.transpose.
4105/// Example:
4106/// ```
4107/// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
4108/// {in_bounds = [true, true],
4109/// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
4110/// vector<4x1xf32>, tensor<4x4x4xf32>
4111/// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
4112/// {in_bounds = [true, true, true, true],
4113/// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
4114/// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
4115/// ```
4116/// To:
4117/// ```
4118/// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
4119/// %r = vector.transpose %0, [3, 0, 2, 1] :
4120/// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
4121/// ```
4122struct TransferReadAfterWriteToBroadcast
4123 : public OpRewritePattern<TransferReadOp> {
4124 using OpRewritePattern::OpRewritePattern;
4125
4126 LogicalResult matchAndRewrite(TransferReadOp readOp,
4127 PatternRewriter &rewriter) const override {
4128 if (readOp.hasOutOfBoundsDim() ||
4129 !llvm::isa<RankedTensorType>(readOp.getShapedType()))
4130 return failure();
4131 auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4132 if (!defWrite)
4133 return failure();
4134 // TODO: If the written transfer chunk is a superset of the read transfer
4135 // chunk we could do an extract_strided_slice.
4136 if (readOp.getTransferChunkAccessed() !=
4137 defWrite.getTransferChunkAccessed())
4138 return failure();
4139 // TODO: Support cases where a dim is explicitly written but implicitly
4140 // read (i.e., a unit dim that is rank reduced).
4141 if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
4142 getUnusedDimsBitVector({defWrite.getPermutationMap()}))
4143 return failure();
4144 if (readOp.getIndices() != defWrite.getIndices() ||
4145 readOp.getMask() != defWrite.getMask())
4146 return failure();
4147 Value vec = defWrite.getVector();
4148 // TODO: loop through the chain of transfer_write if we can prove that they
4149 // don't overlap with the transfer_read. This requires improving
4150 // `isDisjointTransferIndices` helper.
4151 AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
4152 AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
4153 AffineMap map = readMap.compose(writeMap);
4154 if (map.getNumResults() == 0)
4155 return failure();
4156 // Calculate the permutation to apply to go from the vector stored to the
4157 // vector read.
4158 SmallVector<unsigned> permutation;
4159 if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
4160 return failure();
4161
4162 Location loc = readOp.getLoc();
4163 // Calculate the broadcast shape by applying the reverse permutation to the
4164 // final shape we want.
4165 ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
4166 SmallVector<int64_t> broadcastShape(destShape.size());
4167 SmallVector<bool> broadcastScalableFlags(destShape.size());
4168 for (const auto &pos : llvm::enumerate(permutation)) {
4169 broadcastShape[pos.value()] = destShape[pos.index()];
4170 broadcastScalableFlags[pos.value()] =
4171 readOp.getVectorType().getScalableDims()[pos.index()];
4172 }
4173 VectorType broadcastedType = VectorType::get(
4174 broadcastShape, defWrite.getVectorType().getElementType(),
4175 broadcastScalableFlags);
4176 vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
4177 SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
4178 rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
4179 transposePerm);
4180 return success();
4181 }
4182};
4183} // namespace
4184
4185void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4186 MLIRContext *context) {
4187 results.add<TransferReadAfterWriteToBroadcast>(context);
4188}
4189
4190//===----------------------------------------------------------------------===//
4191// TransferWriteOp
4192//===----------------------------------------------------------------------===//
4193
4194/// 1. Builder with type inference.
4195void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4196 Value vector, Value dest, ValueRange indices,
4197 AffineMapAttr permutationMapAttr,
4198 /*optional*/ Value mask,
4199 /*optional*/ ArrayAttr inBoundsAttr) {
4200 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
4201 build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
4202 mask, inBoundsAttr);
4203}
4204
4205/// 2. Builder with type inference that sets an empty mask (variant with attrs).
4206void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4207 Value vector, Value dest, ValueRange indices,
4208 AffineMapAttr permutationMapAttr,
4209 /*optional*/ ArrayAttr inBoundsAttr) {
4210 build(builder, result, vector, dest, indices, permutationMapAttr,
4211 /*mask=*/Value(), inBoundsAttr);
4212}
4213
4214/// 3. Builder with type inference that sets an empty mask (variant without
4215/// attrs).
4216void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4217 Value vector, Value dest, ValueRange indices,
4218 AffineMap permutationMap,
4219 std::optional<ArrayRef<bool>> inBounds) {
4220 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4221 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4222 ? builder.getBoolArrayAttr(inBounds.value())
4223 : ArrayAttr();
4224 build(builder, result, vector, dest, indices, permutationMapAttr,
4225 /*mask=*/Value(), inBoundsAttr);
4226}
4227
4228/// 4. Builder with type inference that sets an empty mask and sets permutation
4229/// map to 'getMinorIdentityMap'.
4230void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
4231 Value vector, Value dest, ValueRange indices,
4232 std::optional<ArrayRef<bool>> inBounds) {
4233 auto vectorType = llvm::cast<VectorType>(vector.getType());
4234 AffineMap permutationMap = getTransferMinorIdentityMap(
4235 llvm::cast<ShapedType>(dest.getType()), vectorType);
4236 build(builder, result, vector, dest, indices, permutationMap, inBounds);
4237}
4238
4239ParseResult TransferWriteOp::parse(OpAsmParser &parser,
4240 OperationState &result) {
4241 auto &builder = parser.getBuilder();
4242 SMLoc typesLoc;
4243 OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
4244 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4245 SmallVector<Type, 2> types;
4246 OpAsmParser::UnresolvedOperand maskInfo;
4247 if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
4248 parser.parseOperand(sourceInfo) ||
4249 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
4250 return failure();
4251 ParseResult hasMask = parser.parseOptionalComma();
4252 if (hasMask.succeeded() && parser.parseOperand(maskInfo))
4253 return failure();
4254 if (parser.parseOptionalAttrDict(result.attributes) ||
4255 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4256 return failure();
4257 if (types.size() != 2)
4258 return parser.emitError(typesLoc, "requires two types");
4259 auto indexType = builder.getIndexType();
4260 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
4261 if (!vectorType)
4262 return parser.emitError(typesLoc, "requires vector type");
4263 ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
4264 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4265 return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4266 auto permMapAttrName =
4267 TransferWriteOp::getPermutationMapAttrName(result.name);
4268 auto permMapAttr = result.attributes.get(permMapAttrName);
4269 AffineMap permMap;
4270 if (!permMapAttr) {
4271 permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4272 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4273 } else {
4274 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4275 }
4276 if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
4277 parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4278 parser.resolveOperands(indexInfo, indexType, result.operands))
4279 return failure();
4280 if (hasMask.succeeded()) {
4281 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4282 return parser.emitError(
4283 maskInfo.location, "does not support masks with vector element type");
4284 if (vectorType.getRank() != permMap.getNumResults()) {
4285 return parser.emitError(typesLoc,
4286 "expected the same rank for the vector and the "
4287 "results of the permutation map");
4288 }
4289 auto maskType = inferTransferOpMaskType(vectorType, permMap);
4290 if (parser.resolveOperand(maskInfo, maskType, result.operands))
4291 return failure();
4292 }
4293 result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
4294 builder.getDenseI32ArrayAttr(
4295 {1, 1, static_cast<int32_t>(indexInfo.size()),
4296 static_cast<int32_t>(hasMask.succeeded())}));
4297 return failure(llvm::isa<RankedTensorType>(shapedType) &&
4298 parser.addTypeToList(shapedType, result.types));
4299}
4300
4301void TransferWriteOp::print(OpAsmPrinter &p) {
4302 p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]";
4303 if (getMask())
4304 p << ", " << getMask();
4305 printTransferAttrs(p, *this);
4306 p << " : " << getVectorType() << ", " << getShapedType();
4307}
4308
4309LogicalResult TransferWriteOp::verify() {
4310 // Consistency of elemental types in shape and vector.
4311 ShapedType shapedType = getShapedType();
4312 VectorType vectorType = getVectorType();
4313 VectorType maskType = getMaskType();
4314 auto permutationMap = getPermutationMap();
4315 VectorType inferredMaskType =
4316 maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4317 : VectorType();
4318
4319 if (llvm::size(getIndices()) != shapedType.getRank())
4320 return emitOpError("requires ") << shapedType.getRank() << " indices";
4321
4322 // We do not allow broadcast dimensions on TransferWriteOps for the moment,
4323 // as the semantics is unclear. This can be revisited later if necessary.
4324 if (hasBroadcastDim())
4325 return emitOpError("should not have broadcast dimensions");
4326
4327 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4328 shapedType, vectorType, maskType,
4329 inferredMaskType, permutationMap,
4330 getInBounds() ? *getInBounds() : ArrayAttr())))
4331 return failure();
4332
4333 return verifyPermutationMap(permutationMap,
4334 [&](Twine t) { return emitOpError(t); });
4335}
4336
4337// MaskableOpInterface methods.
4338
4339/// Returns the mask type expected by this operation. Mostly used for
4340/// verification purposes.
4341Type TransferWriteOp::getExpectedMaskType() {
4342 return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4343}
4344
4345/// Fold:
4346/// ```
4347/// %t1 = ...
4348/// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
4349/// tensor<static_sizesxf32>, vector<static_sizesxf32>
4350/// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
4351/// vector<static_sizesxf32>, tensor<static_sizesxf32>
4352/// ```
4353///
4354/// into:
4355///
4356/// ```
4357/// %t0
4358/// ```
4359///
4360/// The producer of t1 may or may not be DCE'd depending on whether it is a
4361/// block argument or has side effects.
4362static LogicalResult foldReadInitWrite(TransferWriteOp write,
4363 ArrayRef<Attribute>,
4364 SmallVectorImpl<OpFoldResult> &results) {
4365 // TODO: support 0-d corner case.
4366 if (write.getTransferRank() == 0)
4367 return failure();
4368 auto rankedTensorType =
4369 llvm::dyn_cast<RankedTensorType>(write.getSource().getType());
4370 // If not operating on tensors, bail.
4371 if (!rankedTensorType)
4372 return failure();
4373 // If no read, bail.
4374 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4375 if (!read)
4376 return failure();
4377 // TODO: support 0-d corner case.
4378 if (read.getTransferRank() == 0)
4379 return failure();
4380  // For now, only accept minor identity. Future: handle compositions that net to a minor identity.
4381 if (!read.getPermutationMap().isMinorIdentity() ||
4382 !write.getPermutationMap().isMinorIdentity())
4383 return failure();
4384 // Bail on mismatching ranks.
4385 if (read.getTransferRank() != write.getTransferRank())
4386 return failure();
4387 // Bail on potential out-of-bounds accesses.
4388 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
4389 return failure();
4390 // Tensor types must be the same.
4391 if (read.getSource().getType() != rankedTensorType)
4392 return failure();
4393 // Vector types must be the same.
4394 if (read.getVectorType() != write.getVectorType())
4395 return failure();
4396 // Vector and Tensor shapes must match.
4397 if (read.getVectorType().getShape() != rankedTensorType.getShape())
4398 return failure();
4399  // Bail if any index is nonzero.
4400 auto isNotConstantZero = [](Value v) {
4401    auto cstOp = getConstantIntValue(v);
4402 return !cstOp.has_value() || cstOp.value() != 0;
4403 };
4404 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
4405 llvm::any_of(write.getIndices(), isNotConstantZero))
4406 return failure();
4407 // Success.
4408  results.push_back(read.getSource());
4409 return success();
4410}
4411
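/// Returns true when `write` stores back exactly what `read` loaded from the
/// same tensor: same source value, indices, permutation map and vector type,
/// and neither transfer is masked.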
4412static bool checkSameValueWAR(vector::TransferReadOp read,
4413 vector::TransferWriteOp write) {
4414 return read.getSource() == write.getSource() &&
4415 read.getIndices() == write.getIndices() &&
4416 read.getPermutationMap() == write.getPermutationMap() &&
4417 read.getVectorType() == write.getVectorType() && !read.getMask() &&
4418 !write.getMask();
4419}
4420/// Fold transfer_write write after read:
4421/// ```
4422/// %t0 = ...
4423/// %v = vector.transfer_read %t0[%c0...] :
4424/// tensor<static_sizesxf32>, vector<static_sizesxf32>
4425/// %t1 = vector.transfer_write %v, %t0[%c0...] :
4426/// vector<static_sizesxf32>, tensor<static_sizesxf32>
4427/// ```
4428///
4429/// into:
4430///
4431/// ```
4432/// %t0
4433/// ```
4434static LogicalResult foldWAR(TransferWriteOp write,
4435 SmallVectorImpl<OpFoldResult> &results) {
4436 if (!llvm::isa<RankedTensorType>(write.getSource().getType()))
4437 return failure();
4438 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4439 if (!read)
4440 return failure();
4441
4442 if (!checkSameValueWAR(read, write))
4443 return failure();
4444  results.push_back(read.getSource());
4445 return success();
4446}
4447
4448LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
4449 SmallVectorImpl<OpFoldResult> &results) {
4450 if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
4451 return success();
4452 if (succeeded(foldWAR(*this, results)))
4453 return success();
4454 if (succeeded(foldTransferInBoundsAttribute(*this)))
4455 return success();
4456 if (succeeded(foldTransferFullMask(*this)))
4457 return success();
4458 return memref::foldMemRefCast(*this);
4459}
4460
4461std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
4462 return llvm::to_vector<4>(getVectorType().getShape());
4463}
4464
4465void TransferWriteOp::getEffects(
4466 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4467 &effects) {
4468 if (llvm::isa<MemRefType>(getShapedType()))
4469 effects.emplace_back(MemoryEffects::Write::get(), getSource(),
4470 SideEffects::DefaultResource::get());
4471}
4472
4473namespace {
4474/// Remove dead transfer write from the SSA chain so that it can be eliminated
4475/// by DCE.
4476/// ```
4477/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4478/// : vector<1x4xf32>, tensor<4x4xf32>
4479/// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
4480/// : vector<1x4xf32>, tensor<4x4xf32>
4481/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4482/// : vector<1x4xf32>, tensor<4x4xf32>
4483/// ```
4484///
4485/// into:
4486///
4487/// ```
4488/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4489/// : vector<1x4xf32>, tensor<4x4xf32>
4490/// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
4491/// : vector<1x4xf32>, tensor<4x4xf32>
4492/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4493/// : vector<1x4xf32>, tensor<4x4xf32>
4494/// ```
4495///
4496/// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
4497/// any other uses.
4498class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
4499public:
4500 using OpRewritePattern::OpRewritePattern;
4501 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
4502 PatternRewriter &rewriter) const override {
4503 if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
4504 return failure();
4505 vector::TransferWriteOp writeToModify = writeOp;
4506
4507 auto defWrite =
4508 writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4509 while (defWrite) {
4510 if (checkSameValueWAW(writeOp, defWrite)) {
4511 rewriter.modifyOpInPlace(writeToModify, [&]() {
4512 writeToModify.getSourceMutable().assign(defWrite.getSource());
4513 });
4514 return success();
4515 }
4516 if (!isDisjointTransferIndices(
4517 cast<VectorTransferOpInterface>(defWrite.getOperation()),
4518 cast<VectorTransferOpInterface>(writeOp.getOperation())))
4519 break;
4520      // If the previous write op doesn't have any other use we can safely look
4521 // at the previous store to see if it can be removed.
4522 if (!defWrite->hasOneUse())
4523 break;
4524 writeToModify = defWrite;
4525 defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4526 }
4527 return failure();
4528 }
4529};
4530
4531/// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
4532/// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
4533/// overwritten and inserted into another tensor. After this rewrite, the
4534/// operations bufferize in-place since all of them work on the same slice.
4535///
4536/// For example:
4537/// ```mlir
4538/// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
4539/// : vector<8x16xf32>, tensor<8x16xf32>
4540/// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
4541/// : tensor<8x16xf32> to tensor<?x?xf32>
4542/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4543/// : tensor<?x?xf32> into tensor<27x37xf32>
4544/// ```
4545/// folds to
4546/// ```mlir
4547/// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4548/// : tensor<27x37xf32> to tensor<?x?xf32>
4549/// %1 = vector.transfer_write %vec, %0[%c0, %c0]
4550/// : vector<8x16xf32>, tensor<?x?xf32>
4551/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4552/// : tensor<?x?xf32> into tensor<27x37xf32>
4553/// ```
4554struct SwapExtractSliceOfTransferWrite
4555 : public OpRewritePattern<tensor::InsertSliceOp> {
4556public:
4557 using OpRewritePattern::OpRewritePattern;
4558
4559 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
4560 PatternRewriter &rewriter) const override {
4561 if (!insertOp.hasUnitStride())
4562 return failure();
4563 auto extractOp =
4564 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
4565 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
4566 return failure();
4567 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
4568 if (!transferOp || !transferOp->hasOneUse())
4569 return failure();
4570
4571 // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
4572 // rank-reducing.
4573 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
4574 return rewriter.notifyMatchFailure(insertOp,
4575 "use-def chain is rank-reducing");
4576 }
4577
4578 // Fail if tensor::ExtractSliceOp has non-zero offset.
4579 if (!extractOp.hasZeroOffset()) {
4580 return rewriter.notifyMatchFailure(insertOp,
4581 "ExtractSliceOp has non-zero offset");
4582 }
4583
4584    // Fail if vector::TransferWriteOp has non-zero offset.
4585 if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
4586          return getConstantIntValue(value) == static_cast<int64_t>(0);
4587 })) {
4588 return rewriter.notifyMatchFailure(insertOp,
4589                                         "TransferWriteOp has non-zero offset");
4590 }
4591
4592 // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
4593 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
4594 return rewriter.notifyMatchFailure(
4595 insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
4596 }
4597
4598 for (auto [insertSize, extractSize] :
4599 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
4600 if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
4601 return rewriter.notifyMatchFailure(
4602 insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
4603 }
4604 }
4605
4606 // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
4607 assert(transferOp.getVectorType().hasStaticShape() &&
4608 "expected vector to have a static shape");
4609 ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
4610 SmallVector<int64_t> resultShape = applyPermutationMap(
4611 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
4612    if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
4613 return rewriter.notifyMatchFailure(
4614 insertOp, "TransferWriteOp may not write the full tensor.");
4615 }
4616
4617 // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
4618 // Set all in_bounds to false and let the folder infer them.
4619 SmallVector<bool> newInBounds(vectorShape.size(), false);
4620 auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
4621 extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
4622 insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
4623 insertOp.getMixedStrides());
4624 auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
4625 transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
4626 transferOp.getIndices(), transferOp.getPermutationMapAttr(),
4627 rewriter.getBoolArrayAttr(newInBounds));
4628 rewriter.modifyOpInPlace(insertOp, [&]() {
4629 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
4630 });
4631 return success();
4632 }
4633};
4634
4635} // namespace
4636
4637void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
4638 MLIRContext *context) {
4639 results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
4640}
4641
4642//===----------------------------------------------------------------------===//
4643// LoadOp
4644//===----------------------------------------------------------------------===//
4645
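/// Common verification for vector.load / vector.store: the base memref must
/// have a unit stride in its most minor dimension so that the loaded or
/// stored elements are contiguous in memory.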
4646static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
4647 MemRefType memRefTy) {
4648 if (!isLastMemrefDimUnitStride(memRefTy))
4649    return op->emitOpError("most minor memref dim must have unit stride");
4650 return success();
4651}
4652
4653LogicalResult vector::LoadOp::verify() {
4654 VectorType resVecTy = getVectorType();
4655 MemRefType memRefTy = getMemRefType();
4656
4657 if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
4658 return failure();
4659
4660 // Checks for vector memrefs.
4661 Type memElemTy = memRefTy.getElementType();
4662 if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
4663 if (memVecTy != resVecTy)
4664 return emitOpError("base memref and result vector types should match");
4665 memElemTy = memVecTy.getElementType();
4666 }
4667
4668 if (resVecTy.getElementType() != memElemTy)
4669 return emitOpError("base and result element types should match");
4670 if (llvm::size(getIndices()) != memRefTy.getRank())
4671 return emitOpError("requires ") << memRefTy.getRank() << " indices";
4672 return success();
4673}
4674
4675OpFoldResult LoadOp::fold(FoldAdaptor) {
4676 if (succeeded(memref::foldMemRefCast(*this)))
4677 return getResult();
4678 return OpFoldResult();
4679}
4680
4681//===----------------------------------------------------------------------===//
4682// StoreOp
4683//===----------------------------------------------------------------------===//
4684
4685LogicalResult vector::StoreOp::verify() {
4686 VectorType valueVecTy = getVectorType();
4687 MemRefType memRefTy = getMemRefType();
4688
4689 if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
4690 return failure();
4691
4692 // Checks for vector memrefs.
4693 Type memElemTy = memRefTy.getElementType();
4694 if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
4695 if (memVecTy != valueVecTy)
4696 return emitOpError(
4697 "base memref and valueToStore vector types should match");
4698 memElemTy = memVecTy.getElementType();
4699 }
4700
4701 if (valueVecTy.getElementType() != memElemTy)
4702 return emitOpError("base and valueToStore element type should match");
4703 if (llvm::size(getIndices()) != memRefTy.getRank())
4704 return emitOpError("requires ") << memRefTy.getRank() << " indices";
4705 return success();
4706}
4707
4708LogicalResult StoreOp::fold(FoldAdaptor adaptor,
4709 SmallVectorImpl<OpFoldResult> &results) {
4710 return memref::foldMemRefCast(*this);
4711}
4712
4713//===----------------------------------------------------------------------===//
4714// MaskedLoadOp
4715//===----------------------------------------------------------------------===//
4716
4717LogicalResult MaskedLoadOp::verify() {
4718 VectorType maskVType = getMaskVectorType();
4719 VectorType passVType = getPassThruVectorType();
4720 VectorType resVType = getVectorType();
4721 MemRefType memType = getMemRefType();
4722
4723 if (resVType.getElementType() != memType.getElementType())
4724 return emitOpError("base and result element type should match");
4725 if (llvm::size(getIndices()) != memType.getRank())
4726 return emitOpError("requires ") << memType.getRank() << " indices";
4727 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
4728 return emitOpError("expected result dim to match mask dim");
4729 if (resVType != passVType)
4730 return emitOpError("expected pass_thru of same type as result type");
4731 return success();
4732}
4733
4734namespace {
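/// Folds vector.maskedload with a statically known mask:
///   - an all-true mask turns the op into a plain vector.load;
///   - an all-false mask replaces the op with its pass_thru operand.
/// For example (illustrative 1-D types):
///
///   %l = vector.maskedload %base[%i], %all_true, %pass
///       : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
/// becomes
///   %l = vector.load %base[%i] : memref<?xf32>, vector<16xf32>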
4735class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
4736public:
4737 using OpRewritePattern::OpRewritePattern;
4738 LogicalResult matchAndRewrite(MaskedLoadOp load,
4739 PatternRewriter &rewriter) const override {
4740 switch (getMaskFormat(load.getMask())) {
4741 case MaskFormat::AllTrue:
4742 rewriter.replaceOpWithNewOp<vector::LoadOp>(
4743 load, load.getType(), load.getBase(), load.getIndices());
4744 return success();
4745 case MaskFormat::AllFalse:
4746 rewriter.replaceOp(load, load.getPassThru());
4747 return success();
4748 case MaskFormat::Unknown:
4749 return failure();
4750 }
4751 llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
4752 }
4753};
4754} // namespace
4755
4756void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4757 MLIRContext *context) {
4758 results.add<MaskedLoadFolder>(context);
4759}
4760
4761OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
4762 if (succeeded(memref::foldMemRefCast(*this)))
4763 return getResult();
4764 return OpFoldResult();
4765}
4766
4767//===----------------------------------------------------------------------===//
4768// MaskedStoreOp
4769//===----------------------------------------------------------------------===//
4770
4771LogicalResult MaskedStoreOp::verify() {
4772 VectorType maskVType = getMaskVectorType();
4773 VectorType valueVType = getVectorType();
4774 MemRefType memType = getMemRefType();
4775
4776 if (valueVType.getElementType() != memType.getElementType())
4777 return emitOpError("base and valueToStore element type should match");
4778 if (llvm::size(getIndices()) != memType.getRank())
4779 return emitOpError("requires ") << memType.getRank() << " indices";
4780 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4781 return emitOpError("expected valueToStore dim to match mask dim");
4782 return success();
4783}
4784
4785namespace {
4786class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
4787public:
4788 using OpRewritePattern::OpRewritePattern;
4789 LogicalResult matchAndRewrite(MaskedStoreOp store,
4790 PatternRewriter &rewriter) const override {
4791 switch (getMaskFormat(store.getMask())) {
4792 case MaskFormat::AllTrue:
4793 rewriter.replaceOpWithNewOp<vector::StoreOp>(
4794 store, store.getValueToStore(), store.getBase(), store.getIndices());
4795 return success();
4796 case MaskFormat::AllFalse:
4797      rewriter.eraseOp(store);
4798 return success();
4799 case MaskFormat::Unknown:
4800 return failure();
4801 }
4802 llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
4803 }
4804};
4805} // namespace
4806
4807void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
4808 MLIRContext *context) {
4809 results.add<MaskedStoreFolder>(context);
4810}
4811
4812LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
4813 SmallVectorImpl<OpFoldResult> &results) {
4814 return memref::foldMemRefCast(*this);
4815}
4816
4817//===----------------------------------------------------------------------===//
4818// GatherOp
4819//===----------------------------------------------------------------------===//
4820
4821LogicalResult GatherOp::verify() {
4822 VectorType indVType = getIndexVectorType();
4823 VectorType maskVType = getMaskVectorType();
4824 VectorType resVType = getVectorType();
4825 ShapedType baseType = getBaseType();
4826
4827 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
4828 return emitOpError("requires base to be a memref or ranked tensor type");
4829
4830 if (resVType.getElementType() != baseType.getElementType())
4831 return emitOpError("base and result element type should match");
4832 if (llvm::size(getIndices()) != baseType.getRank())
4833 return emitOpError("requires ") << baseType.getRank() << " indices";
4834 if (resVType.getShape() != indVType.getShape())
4835 return emitOpError("expected result dim to match indices dim");
4836 if (resVType.getShape() != maskVType.getShape())
4837 return emitOpError("expected result dim to match mask dim");
4838 if (resVType != getPassThruVectorType())
4839 return emitOpError("expected pass_thru of same type as result type");
4840 return success();
4841}
4842
4843// MaskableOpInterface methods.
4844
4845/// Returns the mask type expected by this operation. Mostly used for
4846/// verification purposes. It requires the operation to be vectorized.
4847Type GatherOp::getExpectedMaskType() {
4848 auto vecType = this->getIndexVectorType();
4849 return VectorType::get(vecType.getShape(),
4850 IntegerType::get(vecType.getContext(), /*width=*/1),
4851 vecType.getScalableDims());
4852}
4853
4854std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
4855 return llvm::to_vector<4>(getVectorType().getShape());
4856}
4857
4858namespace {
4859class GatherFolder final : public OpRewritePattern<GatherOp> {
4860public:
4861 using OpRewritePattern::OpRewritePattern;
4862 LogicalResult matchAndRewrite(GatherOp gather,
4863 PatternRewriter &rewriter) const override {
4864 switch (getMaskFormat(gather.getMask())) {
4865 case MaskFormat::AllTrue:
4866 return failure(); // no unmasked equivalent
4867 case MaskFormat::AllFalse:
4868 rewriter.replaceOp(gather, gather.getPassThru());
4869 return success();
4870 case MaskFormat::Unknown:
4871 return failure();
4872 }
4873 llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
4874 }
4875};
4876} // namespace
4877
4878void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
4879 MLIRContext *context) {
4880 results.add<GatherFolder>(context);
4881}
4882
4883//===----------------------------------------------------------------------===//
4884// ScatterOp
4885//===----------------------------------------------------------------------===//
4886
4887LogicalResult ScatterOp::verify() {
4888 VectorType indVType = getIndexVectorType();
4889 VectorType maskVType = getMaskVectorType();
4890 VectorType valueVType = getVectorType();
4891 MemRefType memType = getMemRefType();
4892
4893 if (valueVType.getElementType() != memType.getElementType())
4894 return emitOpError("base and valueToStore element type should match");
4895 if (llvm::size(getIndices()) != memType.getRank())
4896 return emitOpError("requires ") << memType.getRank() << " indices";
4897 if (valueVType.getDimSize(0) != indVType.getDimSize(0))
4898 return emitOpError("expected valueToStore dim to match indices dim");
4899 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4900 return emitOpError("expected valueToStore dim to match mask dim");
4901 return success();
4902}
4903
4904namespace {
4905class ScatterFolder final : public OpRewritePattern<ScatterOp> {
4906public:
4907 using OpRewritePattern::OpRewritePattern;
4908 LogicalResult matchAndRewrite(ScatterOp scatter,
4909 PatternRewriter &rewriter) const override {
4910 switch (getMaskFormat(scatter.getMask())) {
4911 case MaskFormat::AllTrue:
4912 return failure(); // no unmasked equivalent
4913 case MaskFormat::AllFalse:
4914      rewriter.eraseOp(scatter);
4915 return success();
4916 case MaskFormat::Unknown:
4917 return failure();
4918 }
4919 llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
4920 }
4921};
4922} // namespace
4923
4924void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
4925 MLIRContext *context) {
4926 results.add<ScatterFolder>(context);
4927}
4928
4929//===----------------------------------------------------------------------===//
4930// ExpandLoadOp
4931//===----------------------------------------------------------------------===//
4932
4933LogicalResult ExpandLoadOp::verify() {
4934 VectorType maskVType = getMaskVectorType();
4935 VectorType passVType = getPassThruVectorType();
4936 VectorType resVType = getVectorType();
4937 MemRefType memType = getMemRefType();
4938
4939 if (resVType.getElementType() != memType.getElementType())
4940 return emitOpError("base and result element type should match");
4941 if (llvm::size(getIndices()) != memType.getRank())
4942 return emitOpError("requires ") << memType.getRank() << " indices";
4943 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
4944 return emitOpError("expected result dim to match mask dim");
4945 if (resVType != passVType)
4946 return emitOpError("expected pass_thru of same type as result type");
4947 return success();
4948}
4949
4950namespace {
4951class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
4952public:
4953 using OpRewritePattern::OpRewritePattern;
4954 LogicalResult matchAndRewrite(ExpandLoadOp expand,
4955 PatternRewriter &rewriter) const override {
4956 switch (getMaskFormat(expand.getMask())) {
4957 case MaskFormat::AllTrue:
4958 rewriter.replaceOpWithNewOp<vector::LoadOp>(
4959 expand, expand.getType(), expand.getBase(), expand.getIndices());
4960 return success();
4961 case MaskFormat::AllFalse:
4962 rewriter.replaceOp(expand, expand.getPassThru());
4963 return success();
4964 case MaskFormat::Unknown:
4965 return failure();
4966 }
4967 llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
4968 }
4969};
4970} // namespace
4971
4972void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4973 MLIRContext *context) {
4974 results.add<ExpandLoadFolder>(context);
4975}
4976
4977//===----------------------------------------------------------------------===//
4978// CompressStoreOp
4979//===----------------------------------------------------------------------===//
4980
4981LogicalResult CompressStoreOp::verify() {
4982 VectorType maskVType = getMaskVectorType();
4983 VectorType valueVType = getVectorType();
4984 MemRefType memType = getMemRefType();
4985
4986 if (valueVType.getElementType() != memType.getElementType())
4987 return emitOpError("base and valueToStore element type should match");
4988 if (llvm::size(getIndices()) != memType.getRank())
4989 return emitOpError("requires ") << memType.getRank() << " indices";
4990 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4991 return emitOpError("expected valueToStore dim to match mask dim");
4992 return success();
4993}
4994
4995namespace {
4996class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
4997public:
4998 using OpRewritePattern::OpRewritePattern;
4999 LogicalResult matchAndRewrite(CompressStoreOp compress,
5000 PatternRewriter &rewriter) const override {
5001 switch (getMaskFormat(compress.getMask())) {
5002 case MaskFormat::AllTrue:
5003 rewriter.replaceOpWithNewOp<vector::StoreOp>(
5004 compress, compress.getValueToStore(), compress.getBase(),
5005 compress.getIndices());
5006 return success();
5007 case MaskFormat::AllFalse:
5008      rewriter.eraseOp(compress);
5009 return success();
5010 case MaskFormat::Unknown:
5011 return failure();
5012 }
5013 llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
5014 }
5015};
5016} // namespace
5017
5018void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5019 MLIRContext *context) {
5020 results.add<CompressStoreFolder>(context);
5021}
5022
5023//===----------------------------------------------------------------------===//
5024// ShapeCastOp
5025//===----------------------------------------------------------------------===//
5026
5027/// Returns true if each element of 'a' is equal to the product of a contiguous
5028/// sequence of the elements of 'b'. Returns false otherwise.
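/// For example, isValidShapeCast({4, 6}, {2, 2, 6}) is true (4 = 2 * 2 and
/// 6 = 6), whereas isValidShapeCast({4, 6}, {2, 3, 4}) is false because no
/// contiguous grouping of {2, 3, 4} multiplies to 4 and then to 6.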
5029static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
5030 unsigned rankA = a.size();
5031 unsigned rankB = b.size();
5032 assert(rankA < rankB);
5033
5034 auto isOne = [](int64_t v) { return v == 1; };
5035
5036 // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
5037 // casted to a 0-d vector.
5038  if (rankA == 0 && llvm::all_of(b, isOne))
5039 return true;
5040
5041 unsigned i = 0;
5042 unsigned j = 0;
5043 while (i < rankA && j < rankB) {
5044 int64_t dimA = a[i];
5045 int64_t dimB = 1;
5046 while (dimB < dimA && j < rankB)
5047 dimB *= b[j++];
5048 if (dimA != dimB)
5049 break;
5050 ++i;
5051
5052 // Handle the case when trailing dimensions are of size 1.
5053 // Include them into the contiguous sequence.
5054    if (i < rankA && llvm::all_of(a.slice(i), isOne))
5055      i = rankA;
5056    if (j < rankB && llvm::all_of(b.slice(j), isOne))
5057 j = rankB;
5058 }
5059
5060 return i == rankA && j == rankB;
5061}
5062
5063static LogicalResult verifyVectorShapeCast(Operation *op,
5064 VectorType sourceVectorType,
5065 VectorType resultVectorType) {
5066 // Check that element type is the same.
5067 if (sourceVectorType.getElementType() != resultVectorType.getElementType())
5068    return op->emitOpError("source/result vectors must have same element type");
5069 auto sourceShape = sourceVectorType.getShape();
5070 auto resultShape = resultVectorType.getShape();
5071
5072 // Check that product of source dim sizes matches product of result dim sizes.
5073 int64_t sourceDimProduct = std::accumulate(
5074 sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
5075 int64_t resultDimProduct = std::accumulate(
5076 resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
5077 if (sourceDimProduct != resultDimProduct)
5078    return op->emitOpError("source/result number of elements must match");
5079
5080  // Check the expanding/contracting rank cases.
5081 unsigned sourceRank = sourceVectorType.getRank();
5082 unsigned resultRank = resultVectorType.getRank();
5083 if (sourceRank < resultRank) {
5084 if (!isValidShapeCast(sourceShape, resultShape))
5085      return op->emitOpError("invalid shape cast");
5086  } else if (sourceRank > resultRank) {
5087    if (!isValidShapeCast(resultShape, sourceShape))
5088      return op->emitOpError("invalid shape cast");
5089 }
5090 return success();
5091}
5092
5093LogicalResult ShapeCastOp::verify() {
5094 auto sourceVectorType =
5095 llvm::dyn_cast_or_null<VectorType>(getSource().getType());
5096 auto resultVectorType =
5097 llvm::dyn_cast_or_null<VectorType>(getResult().getType());
5098
5099 // Check if source/result are of vector type.
5100 if (sourceVectorType && resultVectorType)
5101 return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
5102
5103 return success();
5104}
5105
5106OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
5107 // No-op shape cast.
5108 if (getSource().getType() == getResult().getType())
5109 return getSource();
5110
5111 // Canceling shape casts.
5112 if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
5113 if (getResult().getType() == otherOp.getSource().getType())
5114 return otherOp.getSource();
5115
5116 // Only allows valid transitive folding.
5117 VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
5118 VectorType resultType = llvm::cast<VectorType>(getResult().getType());
5119 if (srcType.getRank() < resultType.getRank()) {
5120 if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
5121 return {};
5122 } else if (srcType.getRank() > resultType.getRank()) {
5123 if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
5124 return {};
5125 } else {
5126 return {};
5127 }
5128
5129 setOperand(otherOp.getSource());
5130 return getResult();
5131 }
5132
5133  // Canceling broadcast and shape cast ops.
5134 if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5135 if (bcastOp.getSourceType() == getType())
5136 return bcastOp.getSource();
5137 }
5138
5139 return {};
5140}
5141
5142namespace {
5143// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
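// For example (with an assumed splat constant):
//
//   %cst = arith.constant dense<1.0> : vector<8xf32>
//   %0   = vector.shape_cast %cst : vector<8xf32> to vector<2x4xf32>
// folds to
//   %0 = arith.constant dense<1.0> : vector<2x4xf32>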
5144class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
5145public:
5146 using OpRewritePattern::OpRewritePattern;
5147
5148 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5149 PatternRewriter &rewriter) const override {
5150 auto constantOp =
5151 shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
5152 if (!constantOp)
5153 return failure();
5154 // Only handle splat for now.
5155 auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
5156 if (!dense)
5157 return failure();
5158 auto newAttr =
5159 DenseElementsAttr::get(llvm::cast<VectorType>(shapeCastOp.getType()),
5160 dense.getSplatValue<Attribute>());
5161 rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
5162 return success();
5163 }
5164};
5165
5166/// Helper function that computes a new vector type based on the input vector
5167/// type by removing the trailing one dims:
5168///
5169/// vector<4x1x1xi1> --> vector<4xi1>
5170///
5171static VectorType trimTrailingOneDims(VectorType oldType) {
5172 ArrayRef<int64_t> oldShape = oldType.getShape();
5173 ArrayRef<int64_t> newShape = oldShape;
5174
5175 ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
5176 ArrayRef<bool> newScalableDims = oldScalableDims;
5177
5178 while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
5179    newShape = newShape.drop_back(1);
5180    newScalableDims = newScalableDims.drop_back(1);
5181 }
5182
5183 // Make sure we have at least 1 dimension.
5184 // TODO: Add support for 0-D vectors.
5185 if (newShape.empty()) {
5186 newShape = oldShape.take_back();
5187 newScalableDims = oldScalableDims.take_back();
5188 }
5189
5190 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
5191}
5192
5193/// Folds qualifying shape_cast(create_mask) into a new create_mask
5194///
5195/// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
5196/// dimension. If the input vector comes from `vector.create_mask` for which
5197/// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
5198/// to fold shape_cast into create_mask.
5199///
5200/// BEFORE:
5201/// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
5202/// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
5203/// AFTER:
5204/// %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
5205class ShapeCastCreateMaskFolderTrailingOneDim final
5206 : public OpRewritePattern<ShapeCastOp> {
5207public:
5208 using OpRewritePattern::OpRewritePattern;
5209
5210 LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
5211 PatternRewriter &rewriter) const override {
5212 Value shapeOpSrc = shapeOp->getOperand(0);
5213 auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
5214 auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
5215 if (!createMaskOp && !constantMaskOp)
5216 return failure();
5217
5218 VectorType shapeOpResTy = shapeOp.getResultVectorType();
5219 VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
5220
5221 VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
5222 if (newVecType != shapeOpResTy)
5223 return failure();
5224
5225 auto numDimsToDrop =
5226 shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
5227
5228 // No unit dims to drop
5229 if (!numDimsToDrop)
5230 return failure();
5231
5232 if (createMaskOp) {
5233 auto maskOperands = createMaskOp.getOperands();
5234 auto numMaskOperands = maskOperands.size();
5235
5236 // Check every mask dim size to see whether it can be dropped
5237 for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5238 --i) {
5239 auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
5240 if (!constant || (constant.value() != 1))
5241 return failure();
5242 }
5243 SmallVector<Value> newMaskOperands =
5244 maskOperands.drop_back(numDimsToDrop);
5245
5246 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
5247 newMaskOperands);
5248 return success();
5249 }
5250
5251 if (constantMaskOp) {
5252 auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
5253 auto numMaskOperands = maskDimSizes.size();
5254
5255 // Check every mask dim size to see whether it can be dropped
5256 for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
5257 --i) {
5258 if (cast<IntegerAttr>(maskDimSizes[i]).getValue() != 1)
5259 return failure();
5260 }
5261
5262 auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
5263 ArrayAttr newMaskOperandsAttr = rewriter.getArrayAttr(newMaskOperands);
5264
5265 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
5266 newMaskOperandsAttr);
5267 return success();
5268 }
5269
5270 return failure();
5271 }
5272};
5273
5274/// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
5275/// This only applies when the shape of the broadcast source
5276/// 1. is a suffix of the shape of the result (i.e. when broadcast without
5277/// reshape is expressive enough to capture the result in a single op), or
5278/// 2. has the same element count as the shape cast result.
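/// Example of case 1., where the broadcast source shape is a suffix of the
/// shape cast result shape:
///
///   %b = vector.broadcast %v : vector<4xf32> to vector<2x3x4xf32>
///   %s = vector.shape_cast %b : vector<2x3x4xf32> to vector<6x4xf32>
/// folds to
///   %s = vector.broadcast %v : vector<4xf32> to vector<6x4xf32>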
5279class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
5280public:
5281 using OpRewritePattern::OpRewritePattern;
5282
5283 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5284 PatternRewriter &rewriter) const override {
5285 auto broadcastOp =
5286 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
5287 if (!broadcastOp)
5288 return failure();
5289
5290 ArrayRef<int64_t> broadcastSourceShape;
5291 if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
5292 broadcastSourceShape = srcType.getShape();
5293 ArrayRef<int64_t> shapeCastTargetShape =
5294 shapeCastOp.getResultVectorType().getShape();
5295
5296 // If `broadcastSourceShape` is a suffix of the result, we can just replace
5297 // with a broadcast to the final shape.
5298 if (broadcastSourceShape ==
5299        shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
5300 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5301 shapeCastOp, shapeCastOp.getResultVectorType(),
5302 broadcastOp.getSource());
5303 return success();
5304 }
5305
5306 // Otherwise, if the final result has the same element count, we can replace
5307 // with a shape cast.
5308 if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
5309 if (srcType.getNumElements() ==
5310 shapeCastOp.getResultVectorType().getNumElements()) {
5311 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
5312 shapeCastOp, shapeCastOp.getResultVectorType(),
5313 broadcastOp.getSource());
5314 return success();
5315 }
5316 }
5317
5318 return failure();
5319 }
5320};
5321
5322} // namespace
5323
5324void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
5325 MLIRContext *context) {
5326 results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
5327 ShapeCastBroadcastFolder>(context);
5328}
5329
5330//===----------------------------------------------------------------------===//
5331// VectorBitCastOp
5332//===----------------------------------------------------------------------===//
5333
5334LogicalResult BitCastOp::verify() {
5335 auto sourceVectorType = getSourceVectorType();
5336 auto resultVectorType = getResultVectorType();
5337
5338 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
5339 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
5340 return emitOpError("dimension size mismatch at: ") << i;
5341 }
5342
5343 DataLayout dataLayout = DataLayout::closest(*this);
5344 auto sourceElementBits =
5345 dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
5346 auto resultElementBits =
5347 dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
5348
5349 if (sourceVectorType.getRank() == 0) {
5350 if (sourceElementBits != resultElementBits)
5351 return emitOpError("source/result bitwidth of the 0-D vector element "
5352 "types must be equal");
5353 } else if (sourceElementBits * sourceVectorType.getShape().back() !=
5354 resultElementBits * resultVectorType.getShape().back()) {
5355 return emitOpError(
5356 "source/result bitwidth of the minor 1-D vectors must be equal");
5357 }
5358
5359 return success();
5360}
5361
5362OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
5363 // Nop cast.
5364 if (getSource().getType() == getResult().getType())
5365 return getSource();
5366
5367 // Canceling bitcasts.
5368 if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
5369 if (getResult().getType() == otherOp.getSource().getType())
5370 return otherOp.getSource();
5371
5372 setOperand(otherOp.getSource());
5373 return getResult();
5374 }
5375
5376 Attribute sourceConstant = adaptor.getSource();
5377 if (!sourceConstant)
5378 return {};
5379
5380 Type srcElemType = getSourceVectorType().getElementType();
5381 Type dstElemType = getResultVectorType().getElementType();
5382
5383 if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
5384 if (floatPack.isSplat()) {
5385 auto splat = floatPack.getSplatValue<FloatAttr>();
5386
5387 // Casting fp16 into fp32.
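      // Note: for a splat input, bitcasting two adjacent f16 lanes into one
      // f32 lane yields a 32-bit value whose high and low halves are both the
      // original 16-bit pattern, so the result is again a splat.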
5388 if (srcElemType.isF16() && dstElemType.isF32()) {
5389 uint32_t bits = static_cast<uint32_t>(
5390 splat.getValue().bitcastToAPInt().getZExtValue());
5391 // Duplicate the 16-bit pattern.
5392 bits = (bits << 16) | (bits & 0xffff);
5393 APInt intBits(32, bits);
5394 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
5395 return DenseElementsAttr::get(getResultVectorType(), floatBits);
5396 }
5397 }
5398 }
5399
5400 if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
5401 if (intPack.isSplat()) {
5402 auto splat = intPack.getSplatValue<IntegerAttr>();
5403
5404 if (llvm::isa<IntegerType>(dstElemType)) {
5405 uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
5406 uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
5407
5408 // Casting to a larger integer bit width.
5409 if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
5410 APInt intBits = splat.getValue().zext(dstBitWidth);
5411
5412 // Duplicate the lower width element.
5413 for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
5414 intBits = (intBits << srcBitWidth) | intBits;
5415 return DenseElementsAttr::get(getResultVectorType(), intBits);
5416 }
5417 }
5418 }
5419 }
5420
5421 return {};
5422}
5423
5424//===----------------------------------------------------------------------===//
5425// TypeCastOp
5426//===----------------------------------------------------------------------===//
5427
5428static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
5429 auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
5430 SmallVector<int64_t, 8> res(memRefType.getShape().begin(),
5431 memRefType.getShape().end());
5432 if (vectorType)
5433 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
5434 return res;
5435}
5436
5437/// Build the canonical memRefType with a single vector.
5438/// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
5439void TypeCastOp::build(OpBuilder &builder, OperationState &result,
5440 Value source) {
5441 result.addOperands(source);
5442 MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
5443 VectorType vectorType =
5444 VectorType::get(extractShape(memRefType),
5445 getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
5446 result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
5447 memRefType.getMemorySpace()));
5448}
5449
5450LogicalResult TypeCastOp::verify() {
5451 MemRefType canonicalType = canonicalizeStridedLayout(getMemRefType());
5452 if (!canonicalType.getLayout().isIdentity())
5453 return emitOpError("expects operand to be a memref with identity layout");
5454 if (!getResultMemRefType().getLayout().isIdentity())
5455 return emitOpError("expects result to be a memref with identity layout");
5456 if (getResultMemRefType().getMemorySpace() !=
5457 getMemRefType().getMemorySpace())
5458 return emitOpError("expects result in same memory space");
5459
5460 auto sourceType = getMemRefType();
5461 auto resultType = getResultMemRefType();
5462 if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
5463 getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
5464 return emitOpError(
5465 "expects result and operand with same underlying scalar type: ")
5466 << resultType;
5467 if (extractShape(sourceType) != extractShape(resultType))
5468 return emitOpError(
5469 "expects concatenated result and operand shapes to be equal: ")
5470 << resultType;
5471 return success();
5472}
5473
5474//===----------------------------------------------------------------------===//
5475// TransposeOp
5476//===----------------------------------------------------------------------===//
5477
5478void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
5479 Value vector, ArrayRef<int64_t> permutation) {
5480 VectorType vt = llvm::cast<VectorType>(vector.getType());
5481 SmallVector<int64_t, 4> transposedShape(vt.getRank());
5482 SmallVector<bool, 4> transposedScalableDims(vt.getRank());
5483 for (unsigned i = 0; i < permutation.size(); ++i) {
5484 transposedShape[i] = vt.getShape()[permutation[i]];
5485 transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
5486 }
5487
5488 result.addOperands(vector);
5489 result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
5490 transposedScalableDims));
5491 result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
5492 builder.getDenseI64ArrayAttr(permutation));
5493}
5494
5495OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
5496 // Eliminate splat constant transpose ops.
5497 if (auto attr =
5498 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
5499 if (attr.isSplat())
5500 return attr.reshape(getResultVectorType());
5501
5502 // Eliminate identity transpose ops. This happens when the dimensions of the
5503 // input vector remain in their original order after the transpose operation.
5504 ArrayRef<int64_t> perm = getPermutation();
5505
5506 // Check if the permutation of the dimensions contains sequential values:
5507 // {0, 1, 2, ...}.
5508 for (int64_t i = 0, e = perm.size(); i < e; i++) {
5509 if (perm[i] != i)
5510 return {};
5511 }
5512
5513 return getVector();
5514}
5515
5516LogicalResult vector::TransposeOp::verify() {
5517 VectorType vectorType = getSourceVectorType();
5518 VectorType resultType = getResultVectorType();
5519 int64_t rank = resultType.getRank();
5520 if (vectorType.getRank() != rank)
5521 return emitOpError("vector result rank mismatch: ") << rank;
5522 // Verify transposition array.
5523 ArrayRef<int64_t> perm = getPermutation();
5524 int64_t size = perm.size();
5525 if (rank != size)
5526 return emitOpError("transposition length mismatch: ") << size;
5527 SmallVector<bool, 8> seen(rank, false);
5528 for (const auto &ta : llvm::enumerate(perm)) {
5529 if (ta.value() < 0 || ta.value() >= rank)
5530 return emitOpError("transposition index out of range: ") << ta.value();
5531 if (seen[ta.value()])
5532 return emitOpError("duplicate position index: ") << ta.value();
5533 seen[ta.value()] = true;
5534 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
5535 return emitOpError("dimension size mismatch at: ") << ta.value();
5536 }
5537 return success();
5538}
5539
5540std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
5541 return llvm::to_vector<4>(getResultVectorType().getShape());
5542}
5543
5544namespace {
5545
5546// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
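// For example:
//
//   %t0 = vector.transpose %v,  [1, 0] : vector<2x3xf32> to vector<3x2xf32>
//   %t1 = vector.transpose %t0, [1, 0] : vector<3x2xf32> to vector<2x3xf32>
// is rewritten into a single transpose with the composed permutation [0, 1],
// which the identity fold above then removes entirely.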
5547class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
5548public:
5549 using OpRewritePattern::OpRewritePattern;
5550
5551 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
5552 PatternRewriter &rewriter) const override {
5553 // Composes two permutations: result[i] = permutation1[permutation2[i]].
5554 auto composePermutations = [](ArrayRef<int64_t> permutation1,
5555 ArrayRef<int64_t> permutation2) {
5556 SmallVector<int64_t, 4> result;
5557 for (auto index : permutation2)
5558        result.push_back(permutation1[index]);
5559 return result;
5560 };
5561
5562 // Return if the input of 'transposeOp' is not defined by another transpose.
5563 vector::TransposeOp parentTransposeOp =
5564 transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
5565 if (!parentTransposeOp)
5566 return failure();
5567
5568 SmallVector<int64_t, 4> permutation = composePermutations(
5569 parentTransposeOp.getPermutation(), transposeOp.getPermutation());
5570 // Replace 'transposeOp' with a new transpose operation.
5571 rewriter.replaceOpWithNewOp<vector::TransposeOp>(
5572 transposeOp, transposeOp.getResult().getType(),
5573 parentTransposeOp.getVector(), permutation);
5574 return success();
5575 }
5576};
5577
5578// Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
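// For example:
//
//   %b = vector.broadcast %scalar : f32 to vector<4x8xf32>
//   %t = vector.transpose %b, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
// folds to
//   %t = vector.broadcast %scalar : f32 to vector<8x4xf32>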
5579struct FoldTransposedScalarBroadcast final
5580 : public OpRewritePattern<vector::TransposeOp> {
5581 using OpRewritePattern::OpRewritePattern;
5582
5583 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
5584 PatternRewriter &rewriter) const override {
5585 auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
5586 if (!bcastOp)
5587 return failure();
5588
5589 auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
5590 if (!srcVectorType || srcVectorType.getNumElements() == 1) {
5591 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5592 transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
5593 return success();
5594 }
5595
5596 return failure();
5597 }
5598};
5599
5600// Folds transpose(splat x : src_type) : res_type into splat x : res_type.
5601class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
5602public:
5603 using OpRewritePattern::OpRewritePattern;
5604
5605 LogicalResult matchAndRewrite(TransposeOp transposeOp,
5606 PatternRewriter &rewriter) const override {
5607 auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
5608 if (!splatOp)
5609 return failure();
5610
5611 rewriter.replaceOpWithNewOp<vector::SplatOp>(
5612 transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
5613 return success();
5614 }
5615};
5616
5617/// Folds transpose(create_mask) into a new transposed create_mask.
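///
/// For illustration (names are hypothetical):
///   %m = vector.create_mask %a, %b : vector<4x8xi1>
///   %t = vector.transpose %m, [1, 0] : vector<4x8xi1> to vector<8x4xi1>
/// becomes:
///   %t = vector.create_mask %b, %a : vector<8x4xi1>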
5618class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
5619public:
5620 using OpRewritePattern::OpRewritePattern;
5621
5622 LogicalResult matchAndRewrite(TransposeOp transpOp,
5623 PatternRewriter &rewriter) const override {
5624 Value transposeSrc = transpOp.getVector();
5625 auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
5626 auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
5627 if (!createMaskOp && !constantMaskOp)
5628 return failure();
5629
5630 // Get the transpose permutation and apply it to the vector.create_mask or
5631 // vector.constant_mask operands.
5632 ArrayRef<int64_t> permutation = transpOp.getPermutation();
5633
5634 if (createMaskOp) {
5635 auto maskOperands = createMaskOp.getOperands();
5636 SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
      applyPermutationToVector(newOperands, permutation);
5638
5639 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
5640 transpOp, transpOp.getResultVectorType(), newOperands);
5641 return success();
5642 }
5643
5644 // ConstantMaskOp case.
5645 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5646 SmallVector<Attribute> newMaskDimSizes(maskDimSizes.getValue());
    applyPermutationToVector(newMaskDimSizes, permutation);
5648
5649 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
5650 transpOp, transpOp.getResultVectorType(),
5651 ArrayAttr::get(transpOp.getContext(), newMaskDimSizes));
5652 return success();
5653 }
5654};
5655
5656} // namespace
5657
5658void vector::TransposeOp::getCanonicalizationPatterns(
5659 RewritePatternSet &results, MLIRContext *context) {
5660 results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
5661 TransposeFolder, FoldTransposeSplat>(context);
5662}
5663
5664//===----------------------------------------------------------------------===//
5665// ConstantMaskOp
5666//===----------------------------------------------------------------------===//
5667
5668LogicalResult ConstantMaskOp::verify() {
5669 auto resultType = llvm::cast<VectorType>(getResult().getType());
5670 // Check the corner case of 0-D vectors first.
5671 if (resultType.getRank() == 0) {
5672 if (getMaskDimSizes().size() != 1)
5673 return emitError("array attr must have length 1 for 0-D vectors");
5674 auto dim = llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt();
5675 if (dim != 0 && dim != 1)
5676 return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
5677 return success();
5678 }
5679
5680 // Verify that array attr size matches the rank of the vector result.
5681 if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
5682 return emitOpError(
5683 "must specify array attr of size equal vector result rank");
5684 // Verify that each array attr element is in bounds of corresponding vector
5685 // result dimension size.
5686 auto resultShape = resultType.getShape();
5687 auto resultScalableDims = resultType.getScalableDims();
5688 SmallVector<int64_t, 4> maskDimSizes;
5689 for (const auto [index, intAttr] : llvm::enumerate(getMaskDimSizes())) {
5690 int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
5691 if (maskDimSize < 0 || maskDimSize > resultShape[index])
5692 return emitOpError(
5693 "array attr of size out of bounds of vector result dimension size");
5694 if (resultScalableDims[index] && maskDimSize != 0 &&
5695 maskDimSize != resultShape[index])
5696 return emitOpError(
5697 "only supports 'none set' or 'all set' scalable dimensions");
5698 maskDimSizes.push_back(maskDimSize);
5699 }
5700 // Verify that if one mask dim size is zero, they all should be zero (because
5701 // the mask region is a conjunction of each mask dimension interval).
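  // For example, a hypothetical vector.constant_mask [2, 0] : vector<4x3xi1>
  // would select no elements at all, so it must be written as [0, 0] instead.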
5702 bool anyZeros = llvm::is_contained(maskDimSizes, 0);
5703 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
5704 if (anyZeros && !allZeros)
5705 return emitOpError("expected all mask dim sizes to be zeros, "
5706 "as a result of conjunction with zero mask dim");
5707 return success();
5708}
5709
5710bool ConstantMaskOp::isAllOnesMask() {
5711 auto resultType = getVectorType();
5712 // Check the corner case of 0-D vectors first.
5713 if (resultType.getRank() == 0) {
5714 assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
5715 return llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt() == 1;
5716 }
5717 for (const auto [resultSize, intAttr] :
5718 llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
5719 int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
5720 if (maskDimSize < resultSize)
5721 return false;
5722 }
5723 return true;
5724}
5725
5726//===----------------------------------------------------------------------===//
5727// CreateMaskOp
5728//===----------------------------------------------------------------------===//
5729
5730void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
5731 VectorType type,
5732 ArrayRef<OpFoldResult> mixedOperands) {
5733 SmallVector<Value> operands =
5734 getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
5735 build(builder, result, type, operands);
5736}
5737
5738LogicalResult CreateMaskOp::verify() {
5739 auto vectorType = llvm::cast<VectorType>(getResult().getType());
  // Verify that an operand was specified for each result vector dimension.
  if (vectorType.getRank() == 0) {
    if (getNumOperands() != 1)
      return emitOpError(
          "must specify exactly one operand for 0-D create_mask");
  } else if (getNumOperands() != vectorType.getRank()) {
5747 return emitOpError(
5748 "must specify an operand for each result vector dimension");
5749 }
5750 return success();
5751}
5752
5753namespace {
5754
5755/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
5756///
5757/// Ex 1:
5758/// %c2 = arith.constant 2 : index
5759/// %c3 = arith.constant 3 : index
5760/// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
5761/// Becomes:
5762/// vector.constant_mask [3, 2] : vector<4x3xi1>
5763///
5764/// Ex 2:
5765/// %c_neg_1 = arith.constant -1 : index
5766/// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
5767/// becomes:
5768/// vector.constant_mask [0] : vector<[8]xi1>
5769///
5770/// Ex 3:
5771/// %c8 = arith.constant 8 : index
5772/// %c16 = arith.constant 16 : index
5773/// %0 = vector.vscale
5774/// %1 = arith.muli %0, %c16 : index
5775/// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
5776/// becomes:
5777/// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
5778class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
5779public:
5780 using OpRewritePattern::OpRewritePattern;
5781
5782 LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
5783 PatternRewriter &rewriter) const override {
5784 VectorType retTy = createMaskOp.getResult().getType();
5785 bool isScalable = retTy.isScalable();
5786
5787 // Check every mask operand
5788 for (auto [opIdx, operand] : llvm::enumerate(createMaskOp.getOperands())) {
5789 if (auto cst = getConstantIntValue(operand)) {
5790 // Most basic case - this operand is a constant value. Note that for
5791 // scalable dimensions, CreateMaskOp can be folded only if the
5792 // corresponding operand is negative or zero.
5793 if (retTy.getScalableDims()[opIdx] && *cst > 0)
5794 return failure();
5795
5796 continue;
5797 }
5798
5799 // Non-constant operands are not allowed for non-scalable vectors.
5800 if (!isScalable)
5801 return failure();
5802
      // For scalable vectors, "arith.muli %vscale, %dimSize" means an "all
      // true" mask, so it can also be treated as constant.
      auto mul = operand.getDefiningOp<arith::MulIOp>();
      if (!mul)
        return failure();
      auto mulLHS = mul.getLhs();
      auto mulRHS = mul.getRhs();
      // Use the null-tolerant cast: either multiplicand may be a block
      // argument without a defining op.
      bool isOneOpVscale =
          (isa_and_nonnull<vector::VectorScaleOp>(mulLHS.getDefiningOp()) ||
           isa_and_nonnull<vector::VectorScaleOp>(mulRHS.getDefiningOp()));
5813
5814 auto isConstantValMatchingDim =
5815 [=, dim = retTy.getShape()[opIdx]](Value operand) {
5816 auto constantVal = getConstantIntValue(operand);
5817 return (constantVal.has_value() && constantVal.value() == dim);
5818 };
5819
5820 bool isOneOpConstantMatchingDim =
5821 isConstantValMatchingDim(mulLHS) || isConstantValMatchingDim(mulRHS);
5822
5823 if (!isOneOpVscale || !isOneOpConstantMatchingDim)
5824 return failure();
5825 }
5826
5827 // Gather constant mask dimension sizes.
5828 SmallVector<int64_t, 4> maskDimSizes;
    maskDimSizes.reserve(createMaskOp->getNumOperands());
5830 for (auto [operand, maxDimSize] : llvm::zip_equal(
5831 createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
5832 std::optional dimSize = getConstantIntValue(operand);
5833 if (!dimSize) {
5834 // Although not a constant, it is safe to assume that `operand` is
5835 // "vscale * maxDimSize".
5836 maskDimSizes.push_back(maxDimSize);
5837 continue;
5838 }
5839 int64_t dimSizeVal = std::min(dimSize.value(), maxDimSize);
5840 // If one of dim sizes is zero, set all dims to zero.
5841 if (dimSize <= 0) {
5842 maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
5843 break;
5844 }
5845 maskDimSizes.push_back(dimSizeVal);
5846 }
5847
5848 // Replace 'createMaskOp' with ConstantMaskOp.
5849 rewriter.replaceOpWithNewOp<ConstantMaskOp>(
5850 createMaskOp, retTy,
5851 vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
5852 return success();
5853 }
5854};
5855
5856} // namespace
5857
5858void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
5859 MLIRContext *context) {
5860 results.add<CreateMaskFolder>(context);
5861}
5862
5863//===----------------------------------------------------------------------===//
5864// MaskOp
5865//===----------------------------------------------------------------------===//
5866
5867void MaskOp::build(
5868 OpBuilder &builder, OperationState &result, Value mask,
5869 Operation *maskableOp,
5870 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
5871 assert(maskRegionBuilder &&
5872 "builder callback for 'maskRegion' must be present");
5873
5874 result.addOperands(mask);
5875 OpBuilder::InsertionGuard guard(builder);
5876 Region *maskRegion = result.addRegion();
5877 builder.createBlock(maskRegion);
5878 maskRegionBuilder(builder, maskableOp);
5879}
5880
5881void MaskOp::build(
5882 OpBuilder &builder, OperationState &result, TypeRange resultTypes,
5883 Value mask, Operation *maskableOp,
5884 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
5885 build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
5886 maskRegionBuilder);
5887}
5888
5889void MaskOp::build(
5890 OpBuilder &builder, OperationState &result, TypeRange resultTypes,
5891 Value mask, Value passthru, Operation *maskableOp,
5892 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
5893 build(builder, result, mask, maskableOp, maskRegionBuilder);
5894 if (passthru)
5895 result.addOperands(passthru);
5896 result.addTypes(resultTypes);
5897}
5898
5899ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
5900 // Create the op region.
5901 result.regions.reserve(1);
5902 Region &maskRegion = *result.addRegion();
5903
5904 auto &builder = parser.getBuilder();
5905
5906 // Parse all the operands.
5907 OpAsmParser::UnresolvedOperand mask;
5908 if (parser.parseOperand(mask))
5909 return failure();
5910
5911 // Optional passthru operand.
5912 OpAsmParser::UnresolvedOperand passthru;
5913 ParseResult parsePassthru = parser.parseOptionalComma();
5914 if (parsePassthru.succeeded() && parser.parseOperand(passthru))
5915 return failure();
5916
5917 // Parse op region.
5918 if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
5919 return failure();
5920
5921 MaskOp::ensureTerminator(maskRegion, builder, result.location);
5922
5923 // Parse the optional attribute list.
5924 if (parser.parseOptionalAttrDict(result.attributes))
5925 return failure();
5926
5927 // Parse all the types.
5928 Type maskType;
5929 if (parser.parseColonType(maskType))
5930 return failure();
5931
5932 SmallVector<Type> resultTypes;
5933 if (parser.parseOptionalArrowTypeList(resultTypes))
5934 return failure();
5935 result.types.append(resultTypes);
5936
5937 // Resolve operands.
5938 if (parser.resolveOperand(mask, maskType, result.operands))
5939 return failure();
5940
5941 if (parsePassthru.succeeded())
5942 if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
5943 return failure();
5944
5945 return success();
5946}
5947
5948void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
5949 p << " " << getMask();
5950 if (getPassthru())
5951 p << ", " << getPassthru();
5952
5953 // Print single masked operation and skip terminator.
5954 p << " { ";
5955 Block *singleBlock = &getMaskRegion().getBlocks().front();
5956 if (singleBlock && !singleBlock->getOperations().empty())
5957 p.printCustomOrGenericOp(&singleBlock->front());
5958 p << " }";
5959
5960 p.printOptionalAttrDict(getOperation()->getAttrs());
5961
5962 p << " : " << getMask().getType();
5963 if (getNumResults() > 0)
5964 p << " -> " << getResultTypes();
5965}
5966
5967void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
5968 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
5969 MaskOp>::ensureTerminator(region, builder, loc);
  // Keep the default yield terminator if the number of masked operations is
  // not the expected one. This case will trigger a verification failure.
5972 Block &block = region.front();
5973 if (block.getOperations().size() != 2)
5974 return;
5975
5976 // Replace default yield terminator with a new one that returns the results
5977 // from the masked operation.
5978 OpBuilder opBuilder(builder.getContext());
5979 Operation *maskedOp = &block.front();
5980 Operation *oldYieldOp = &block.back();
5981 assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");
5982
5983 // Empty vector.mask op.
5984 if (maskedOp == oldYieldOp)
5985 return;
5986
5987 opBuilder.setInsertionPoint(oldYieldOp);
5988 opBuilder.create<vector::YieldOp>(loc, maskedOp->getResults());
5989 oldYieldOp->dropAllReferences();
5990 oldYieldOp->erase();
5991}
5992
5993LogicalResult MaskOp::verify() {
5994 // Structural checks.
5995 Block &block = getMaskRegion().getBlocks().front();
5996 if (block.getOperations().empty())
5997 return emitOpError("expects a terminator within the mask region");
5998 if (block.getOperations().size() > 2)
5999 return emitOpError("expects only one operation to mask");
6000
6001 // Terminator checks.
6002 auto terminator = dyn_cast<vector::YieldOp>(block.back());
6003 if (!terminator)
6004 return emitOpError("expects a terminator within the mask region");
6005
6006 if (terminator->getNumOperands() != getNumResults())
6007 return emitOpError(
6008 "expects number of results to match mask region yielded values");
6009
6010 auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
6011 // Empty vector.mask. Nothing else to check.
6012 if (!maskableOp)
6013 return success();
6014
6015 // Result checks.
6016 if (maskableOp->getNumResults() != getNumResults())
6017 return emitOpError("expects number of results to match maskable operation "
6018 "number of results");
6019
6020 if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
6021 return emitOpError(
6022 "expects result type to match maskable operation result type");
6023
6024 if (llvm::count_if(maskableOp->getResultTypes(),
6025 [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
6026 return emitOpError("multiple vector results not supported");
6027
6028 // Mask checks.
6029 Type expectedMaskType = maskableOp.getExpectedMaskType();
6030 if (getMask().getType() != expectedMaskType)
6031 return emitOpError("expects a ")
6032 << expectedMaskType << " mask for the maskable operation";
6033
6034 // Passthru checks.
6035 Value passthru = getPassthru();
6036 if (passthru) {
6037 if (!maskableOp.supportsPassthru())
6038 return emitOpError(
6039 "doesn't expect a passthru argument for this maskable operation");
6040
6041 if (maskableOp->getNumResults() != 1)
6042 return emitOpError("expects result when passthru argument is provided");
6043
6044 if (passthru.getType() != maskableOp->getResultTypes()[0])
6045 return emitOpError("expects passthru type to match result type");
6046 }
6047
6048 return success();
6049}
6050
6051/// Folds vector.mask ops with an all-true mask.
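///
/// For illustration (names are hypothetical):
///   %m = vector.constant_mask [4] : vector<4xi1>
///   %r = vector.mask %m {
///          vector.reduction <add>, %v : vector<4xf32> into f32
///        } : vector<4xi1> -> f32
/// folds to the unmasked operation:
///   %r = vector.reduction <add>, %v : vector<4xf32> into f32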
6052LogicalResult MaskOp::fold(FoldAdaptor adaptor,
6053 SmallVectorImpl<OpFoldResult> &results) {
6054 MaskFormat maskFormat = getMaskFormat(getMask());
6055 if (isEmpty())
6056 return failure();
6057
6058 if (maskFormat != MaskFormat::AllTrue)
6059 return failure();
6060
6061 // Move maskable operation outside of the `vector.mask` region.
6062 Operation *maskableOp = getMaskableOp();
6063 maskableOp->dropAllUses();
6064 maskableOp->moveBefore(getOperation());
6065
6066 llvm::append_range(results, maskableOp->getResults());
6067 return success();
6068}
6069
// Elides empty vector.mask operations with or without return values.
// Propagates the values yielded by the vector.yield terminator, if any, or
// erases the op otherwise.
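// For illustration (names are hypothetical), an empty mask op such as:
//   %r = vector.mask %m { vector.yield %v : vector<4xf32> }
//          : vector<4xi1> -> vector<4xf32>
// is replaced by %v directly; an empty vector.mask without results is erased.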
6073class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
6074 using OpRewritePattern::OpRewritePattern;
6075
6076 LogicalResult matchAndRewrite(MaskOp maskOp,
6077 PatternRewriter &rewriter) const override {
6078 auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
6079 if (maskingOp.getMaskableOp())
6080 return failure();
6081
6082 if (!maskOp.isEmpty())
6083 return failure();
6084
6085 Block *block = maskOp.getMaskBlock();
6086 auto terminator = cast<vector::YieldOp>(block->front());
6087 if (terminator.getNumOperands() == 0)
      rewriter.eraseOp(maskOp);
6089 else
6090 rewriter.replaceOp(maskOp, terminator.getOperands());
6091
6092 return success();
6093 }
6094};
6095
6096void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
6097 MLIRContext *context) {
6098 results.add<ElideEmptyMaskOp>(context);
6099}
6100
6101// MaskingOpInterface definitions.
6102
6103/// Returns the operation masked by this 'vector.mask'.
6104Operation *MaskOp::getMaskableOp() {
6105 Block *block = getMaskBlock();
6106 if (block->getOperations().size() < 2)
6107 return nullptr;
6108
6109 return &block->front();
6110}
6111
6112/// Returns true if 'vector.mask' has a passthru value.
6113bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
6114
6115//===----------------------------------------------------------------------===//
6116// ScanOp
6117//===----------------------------------------------------------------------===//
6118
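// For illustration (names and shapes are hypothetical), a scan over dim 1:
//   %d, %a = vector.scan <add>, %src, %init {inclusive = true, reduction_dim = 1}
//     : vector<4x8xf32>, vector<4xf32>
// must use an initial value whose shape is the source shape with the
// reduction dimension dropped (here [4]), which is what the verifier checks.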
6119LogicalResult ScanOp::verify() {
6120 VectorType srcType = getSourceType();
6121 VectorType initialType = getInitialValueType();
6122 // Check reduction dimension < rank.
6123 int64_t srcRank = srcType.getRank();
6124 int64_t reductionDim = getReductionDim();
6125 if (reductionDim >= srcRank)
6126 return emitOpError("reduction dimension ")
6127 << reductionDim << " has to be less than " << srcRank;
6128
6129 // Check that rank(initial_value) = rank(src) - 1.
6130 int64_t initialValueRank = initialType.getRank();
6131 if (initialValueRank != srcRank - 1)
6132 return emitOpError("initial value rank ")
6133 << initialValueRank << " has to be equal to " << srcRank - 1;
6134
6135 // Check shapes of initial value and src.
6136 ArrayRef<int64_t> srcShape = srcType.getShape();
6137 ArrayRef<int64_t> initialValueShapes = initialType.getShape();
6138 SmallVector<int64_t> expectedShape;
6139 for (int i = 0; i < srcRank; i++) {
6140 if (i != reductionDim)
6141 expectedShape.push_back(srcShape[i]);
6142 }
6143 if (!llvm::equal(initialValueShapes, expectedShape)) {
6144 return emitOpError("incompatible input/initial value shapes");
6145 }
6146
6147 // Verify supported reduction kind.
6148 Type eltType = getDestType().getElementType();
6149 if (!isSupportedCombiningKind(getKind(), eltType))
6150 return emitOpError("unsupported reduction type ")
6151 << eltType << " for kind '" << stringifyCombiningKind(getKind())
6152 << "'";
6153
6154 return success();
6155}
6156
6157void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
6158 RewritePatternSet &patterns, PatternBenefit benefit) {
6159 patterns
6160 .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
6161 ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
6162 StridedSliceConstantMaskFolder, TransposeFolder>(
          patterns.getContext(), benefit);
6164}
6165
6166//===----------------------------------------------------------------------===//
6167// SplatOp
6168//===----------------------------------------------------------------------===//
6169
6170OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
6171 auto constOperand = adaptor.getInput();
6172 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
6173 return {};
6174
  // SplatElementsAttr::get treats a single value for the second arg as a splat.
6176 return SplatElementsAttr::get(getType(), {constOperand});
6177}
6178
6179//===----------------------------------------------------------------------===//
6180// WarpExecuteOnLane0Op
6181//===----------------------------------------------------------------------===//
6182
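// The custom assembly emitted below looks roughly like (illustrative only):
//   %r = vector.warp_execute_on_lane_0(%laneid)[32]
//        args(%v : vector<128xf32>) -> (vector<4xf32>) {
//   ^bb0(%arg0: vector<128xf32>):
//     ...
//     vector.yield %y : vector<128xf32>
//   }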
6183void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
6184 p << "(" << getLaneid() << ")";
6185
6186 SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
6187 auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
6188 p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";
6189
6190 if (!getArgs().empty())
6191 p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
6192 if (!getResults().empty())
6193 p << " -> (" << getResults().getTypes() << ')';
6194 p << " ";
6195 p.printRegion(getRegion(),
6196 /*printEntryBlockArgs=*/true,
6197 /*printBlockTerminators=*/!getResults().empty());
6198 p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
6199}
6200
6201ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
6202 OperationState &result) {
6203 // Create the region.
6204 result.regions.reserve(1);
6205 Region *warpRegion = result.addRegion();
6206
6207 auto &builder = parser.getBuilder();
6208 OpAsmParser::UnresolvedOperand laneId;
6209
6210 // Parse predicate operand.
6211 if (parser.parseLParen() ||
6212 parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
6213 parser.parseRParen())
6214 return failure();
6215
6216 int64_t warpSize;
6217 if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
6218 parser.parseRSquare())
6219 return failure();
6220 result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
6221 builder.getContext())),
6222 builder.getI64IntegerAttr(warpSize));
6223
6224 if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
6225 return failure();
6226
6227 llvm::SMLoc inputsOperandsLoc;
6228 SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
6229 SmallVector<Type> inputTypes;
6230 if (succeeded(parser.parseOptionalKeyword("args"))) {
6231 if (parser.parseLParen())
6232 return failure();
6233
6234 inputsOperandsLoc = parser.getCurrentLocation();
6235 if (parser.parseOperandList(inputsOperands) ||
6236 parser.parseColonTypeList(inputTypes) || parser.parseRParen())
6237 return failure();
6238 }
6239 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
6240 result.operands))
6241 return failure();
6242
6243 // Parse optional results type list.
6244 if (parser.parseOptionalArrowTypeList(result.types))
6245 return failure();
6246 // Parse the region.
6247 if (parser.parseRegion(*warpRegion, /*arguments=*/{},
6248 /*argTypes=*/{}))
6249 return failure();
6250 WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
6251
6252 // Parse the optional attribute list.
6253 if (parser.parseOptionalAttrDict(result.attributes))
6254 return failure();
6255 return success();
6256}
6257
6258void WarpExecuteOnLane0Op::getSuccessorRegions(
6259 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
6260 if (!point.isParent()) {
6261 regions.push_back(RegionSuccessor(getResults()));
6262 return;
6263 }
6264
6265 // The warp region is always executed
6266 regions.push_back(RegionSuccessor(&getWarpRegion()));
6267}
6268
6269void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
6270 TypeRange resultTypes, Value laneId,
6271 int64_t warpSize) {
6272 build(builder, result, resultTypes, laneId, warpSize,
6273 /*operands=*/std::nullopt, /*argTypes=*/std::nullopt);
6274}
6275
6276void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
6277 TypeRange resultTypes, Value laneId,
6278 int64_t warpSize, ValueRange args,
6279 TypeRange blockArgTypes) {
6280 result.addOperands(laneId);
6281 result.addAttribute(getAttributeNames()[0],
6282 builder.getI64IntegerAttr(warpSize));
6283 result.addTypes(resultTypes);
6284 result.addOperands(args);
6285 assert(args.size() == blockArgTypes.size());
6286 OpBuilder::InsertionGuard guard(builder);
6287 Region *warpRegion = result.addRegion();
6288 Block *block = builder.createBlock(warpRegion);
6289 for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
6290 block->addArgument(type, arg.getLoc());
6291}
6292
/// Helper to check that the distributed vector type is consistent with the
/// expanded type and the distribution (warp) size.
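///
/// For example (illustrative), with a warp size of 32, an expanded
/// vector<64x4xf32> may be distributed as vector<2x4xf32>: each dimension
/// divides evenly and the per-dimension ratios (32 and 1) multiply to the
/// warp size.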
6295static LogicalResult verifyDistributedType(Type expanded, Type distributed,
6296 int64_t warpSize, Operation *op) {
  // If the types match, there is no distribution.
6298 if (expanded == distributed)
6299 return success();
6300 auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
6301 auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
6302 if (!expandedVecType || !distributedVecType)
    return op->emitOpError("expected vector type for distributed operands.");
6304 if (expandedVecType.getRank() != distributedVecType.getRank() ||
6305 expandedVecType.getElementType() != distributedVecType.getElementType())
    return op->emitOpError(
        "expected distributed vectors to have same rank and element type.");
6308
6309 SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
6310 for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
6311 int64_t eDim = expandedVecType.getDimSize(i);
6312 int64_t dDim = distributedVecType.getDimSize(i);
6313 if (eDim == dDim)
6314 continue;
6315 if (eDim % dDim != 0)
6316 return op->emitOpError()
6317 << "expected expanded vector dimension #" << i << " (" << eDim
             << ") to be a multiple of the distributed vector dimension ("
6319 << dDim << ")";
6320 scales[i] = eDim / dDim;
6321 }
  if (std::accumulate(scales.begin(), scales.end(), 1,
                      std::multiplies<int64_t>()) != warpSize)
6324 return op->emitOpError()
6325 << "incompatible distribution dimensions from " << expandedVecType
6326 << " to " << distributedVecType << " with warp size = " << warpSize;
6327
6328 return success();
6329}
6330
6331LogicalResult WarpExecuteOnLane0Op::verify() {
6332 if (getArgs().size() != getWarpRegion().getNumArguments())
6333 return emitOpError(
        "expected same number of op arguments and block arguments.");
6335 auto yield =
6336 cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
6337 if (yield.getNumOperands() != getNumResults())
6338 return emitOpError(
6339 "expected same number of yield operands and return values.");
6340 int64_t warpSize = getWarpSize();
6341 for (auto [regionArg, arg] :
6342 llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
6343 if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
6344 warpSize, getOperation())))
6345 return failure();
6346 }
6347 for (auto [yieldOperand, result] :
6348 llvm::zip_equal(yield.getOperands(), getResults())) {
6349 if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
6350 warpSize, getOperation())))
6351 return failure();
6352 }
6353 return success();
6354}
6355
6356bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
6357 return succeeded(
6358 verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
6359}
6360
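// Usage sketch (illustrative; `b`, `loc`, `lhs`, `acc`, `fmf`, and `mask` are
// assumed to be in scope):
//   Value red = makeArithReduction(b, loc, CombiningKind::ADD, lhs, acc, fmf,
//                                  mask);
// For float operands this creates arith.addf and, when a mask is provided, an
// arith.select that keeps `acc` in the masked-out lanes.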
6361Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
6362 CombiningKind kind, Value v1, Value acc,
6363 arith::FastMathFlagsAttr fastmath,
6364 Value mask) {
  Type t1 = getElementTypeOrSelf(v1.getType());
  Type tAcc = getElementTypeOrSelf(acc.getType());
6367 Value result;
6368
6369 switch (kind) {
6370 case CombiningKind::ADD:
6371 if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
6372 result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6374 result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
6375 else
6376 llvm_unreachable("invalid value types for ADD reduction");
6377 break;
6378 case CombiningKind::AND:
6379 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6380 result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
6381 break;
6382 case CombiningKind::MAXNUMF:
6383 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6384 "expected float values");
6385 result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
6386 break;
6387 case CombiningKind::MAXIMUMF:
6388 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6389 "expected float values");
6390 result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
6391 break;
6392 case CombiningKind::MINNUMF:
6393 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6394 "expected float values");
6395 result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
6396 break;
6397 case CombiningKind::MINIMUMF:
6398 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
6399 "expected float values");
6400 result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
6401 break;
6402 case CombiningKind::MAXSI:
6403 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6404 result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
6405 break;
6406 case CombiningKind::MINSI:
6407 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6408 result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
6409 break;
6410 case CombiningKind::MAXUI:
6411 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6412 result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
6413 break;
6414 case CombiningKind::MINUI:
6415 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6416 result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
6417 break;
6418 case CombiningKind::MUL:
6419 if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
6420 result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
6422 result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
6423 else
6424 llvm_unreachable("invalid value types for MUL reduction");
6425 break;
6426 case CombiningKind::OR:
6427 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6428 result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
6429 break;
6430 case CombiningKind::XOR:
6431 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
6432 result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
6433 break;
6434 };
6435
6436 assert(result && "unknown CombiningKind");
  return selectPassthru(b, mask, result, acc);
6438}
6439
6440//===----------------------------------------------------------------------===//
6441// Vector Masking Utilities
6442//===----------------------------------------------------------------------===//
6443
6444/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
6445/// as masked operation.
6446void mlir::vector::createMaskOpRegion(OpBuilder &builder,
6447 Operation *maskableOp) {
6448 assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
6449 Block *insBlock = builder.getInsertionBlock();
6450 // Create a block and move the op to that block.
  insBlock->getOperations().splice(
      insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
6453 builder.create<YieldOp>(maskableOp->getLoc(), maskableOp->getResults());
6454}
6455
6456/// Creates a vector.mask operation around a maskable operation. Returns the
6457/// vector.mask operation if the mask provided is valid. Otherwise, returns
6458/// the maskable operation itself.
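///
/// Usage sketch (illustrative; `builder`, `readOp`, and `maskVal` are assumed
/// to exist in the caller):
///   Operation *masked = maskOperation(builder, readOp, maskVal);
/// which produces IR of the form:
///   %0 = vector.mask %maskVal { vector.transfer_read ... }
///          : vector<...xi1> -> vector<...xf32>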
6459Operation *mlir::vector::maskOperation(OpBuilder &builder,
6460 Operation *maskableOp, Value mask,
6461 Value passthru) {
6462 if (!mask)
6463 return maskableOp;
6464 if (passthru)
6465 return builder.create<MaskOp>(maskableOp->getLoc(),
6466 maskableOp->getResultTypes(), mask, passthru,
6467 maskableOp, createMaskOpRegion);
6468 return builder.create<MaskOp>(maskableOp->getLoc(),
6469 maskableOp->getResultTypes(), mask, maskableOp,
6470 createMaskOpRegion);
6471}
6472
/// Creates a vector select operation that picks values from `newValue` or
/// `passthru` for each result vector lane based on `mask`. This utility is
/// used to propagate the pass-thru value of vector.mask or for cases where
/// only the pass-thru value propagation is needed. VP intrinsics do not
/// support pass-thru values and every masked-out lane is set to poison. LLVM
/// backends are usually able to match op + select patterns and fold them into
/// native target instructions.
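///
/// For illustration (names are hypothetical), with a vector<4xi1> mask this
/// materializes:
///   %sel = arith.select %mask, %newValue, %passthru
///            : vector<4xi1>, vector<4xf32>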
6480Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
6481 Value newValue, Value passthru) {
6482 if (!mask)
6483 return newValue;
6484
6485 return builder.create<arith::SelectOp>(newValue.getLoc(), newValue.getType(),
6486 mask, newValue, passthru);
6487}
6488
6489//===----------------------------------------------------------------------===//
6490// TableGen'd op method definitions
6491//===----------------------------------------------------------------------===//
6492
6493#define GET_ATTRDEF_CLASSES
6494#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
6495
6496#define GET_OP_CLASSES
6497#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
6498
