1//===- LowerVectorContract.cpp - Lower 'vector.contract' operation --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent rewrites and utilities to lower the
10// 'vector.contract' operation.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Affine/IR/AffineOps.h"
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/Arith/Utils/Utils.h"
17#include "mlir/Dialect/Linalg/IR/Linalg.h"
18#include "mlir/Dialect/MemRef/IR/MemRef.h"
19#include "mlir/Dialect/SCF/IR/SCF.h"
20#include "mlir/Dialect/Tensor/IR/Tensor.h"
21#include "mlir/Dialect/Utils/IndexingUtils.h"
22#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
23#include "mlir/Dialect/Vector/IR/VectorOps.h"
24#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
25#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
26#include "mlir/IR/BuiltinAttributeInterfaces.h"
27#include "mlir/IR/BuiltinTypes.h"
28#include "mlir/IR/ImplicitLocOpBuilder.h"
29#include "mlir/IR/Location.h"
30#include "mlir/IR/Matchers.h"
31#include "mlir/IR/PatternMatch.h"
32#include "mlir/IR/TypeUtilities.h"
33#include "mlir/Interfaces/VectorInterfaces.h"
34#include "mlir/Support/LogicalResult.h"
35
36#define DEBUG_TYPE "vector-contract-lowering"
37
38using namespace mlir;
39using namespace mlir::vector;
40
41//===----------------------------------------------------------------------===//
42// Helper functions
43//===----------------------------------------------------------------------===//
44// Helper to find an index in an affine map.
45static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
46 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
47 int64_t idx = map.getDimPosition(idx: i);
48 if (idx == index)
49 return i;
50 }
51 return std::nullopt;
52}
53
54// Helper to construct iterator types with one index removed.
55static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes,
56 int64_t index) {
57 SmallVector<Attribute> results;
58 for (const auto &it : llvm::enumerate(iteratorTypes)) {
59 int64_t idx = it.index();
60 if (idx == index)
61 continue;
62 results.push_back(it.value());
63 }
64 return results;
65}
66
67// Helper to construct an affine map with one index removed.
68static AffineMap adjustMap(AffineMap map, int64_t index,
69 PatternRewriter &rewriter) {
70 auto *ctx = rewriter.getContext();
71 SmallVector<AffineExpr> results;
72 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
73 int64_t idx = map.getDimPosition(idx: i);
74 if (idx == index)
75 continue;
76 // Re-insert remaining indices, but renamed when occurring
77 // after the removed index.
78 auto targetExpr = getAffineDimExpr(position: idx < index ? idx : idx - 1, context: ctx);
79 results.push_back(Elt: targetExpr);
80 }
81 return AffineMap::get(dimCount: map.getNumDims() - 1, symbolCount: 0, results, context: ctx);
82}
83
// Helper method to possibly drop a dimension in a load.
// Extracts, from vector `val` of type `type`, the slice at position `pos`
// along dimension `index`, by recursively unrolling the leading dimensions
// until `index` is the outermost dimension, where a plain vector.extract
// applies. `index == -1` means no dimension is dropped and `val` is returned
// unchanged.
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  // Nothing to drop; pass the value through untouched.
  if (index == -1)
    return val;

  // At extraction dimension?
  if (index == 0)
    return rewriter.create<vector::ExtractOp>(loc, val, pos);

  // Unroll leading dimensions.
  // vType: element type one level down; resType: result with `index` dropped.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  VectorType resType = VectorType::Builder(type).dropDim(index);
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));
  for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
    // Recurse on each leading-dimension slice, then reassemble the result.
    Value ext = rewriter.create<vector::ExtractOp>(loc, val, d);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, load, result, d);
  }
  return result;
}
108
// Helper method to possibly drop a dimension in a store.
// Inverse of reshapeLoad: inserts `val` into `result` at position `pos`
// along dimension `index`, recursively unrolling the leading dimensions
// until `index` is the outermost dimension, where a plain vector.insert
// applies. `index == -1` means nothing is dropped and `val` is returned
// unchanged.
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0)
    return rewriter.create<vector::InsertOp>(loc, val, result, pos);

  // Unroll leading dimensions.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    // Recurse into each leading slice of both the value being stored and the
    // destination, then re-insert the updated slice.
    Value ext = rewriter.create<vector::ExtractOp>(loc, result, d);
    Value ins = rewriter.create<vector::ExtractOp>(loc, val, d);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, sto, result, d);
  }
  return result;
}
131
132/// Helper to create arithmetic operation associated with a kind of contraction.
133static std::optional<Value>
134createContractArithOp(Location loc, Value x, Value y, Value acc,
135 vector::CombiningKind kind, PatternRewriter &rewriter,
136 bool isInt, Value mask = Value()) {
137 using vector::CombiningKind;
138 Value mul;
139
140 if (isInt) {
141 if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
142 kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
143 // Only valid for floating point types.
144 return std::nullopt;
145 mul = rewriter.create<arith::MulIOp>(loc, x, y);
146 } else {
147 // Float case.
148 if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
149 kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
150 kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
151 kind == CombiningKind::XOR)
152 // Only valid for integer types.
153 return std::nullopt;
154 // Special case for fused multiply-add.
155 if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) {
156 Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
157 if (mask)
158 // The fma op doesn't need explicit masking. However, fma ops used in
159 // reductions must preserve previous 'acc' values for masked-out lanes.
160 fma = selectPassthru(builder&: rewriter, mask, newValue: fma, passthru: acc);
161 return fma;
162 }
163 mul = rewriter.create<arith::MulFOp>(loc, x, y);
164 }
165
166 if (!acc)
167 return std::optional<Value>(mul);
168
169 return makeArithReduction(rewriter, loc, kind, mul, acc,
170 /*fastmath=*/nullptr, mask);
171}
172
173/// Return the positions of the reductions in the given map.
174static SmallVector<int64_t> getReductionIndex(AffineMap map,
175 ArrayAttr iteratorTypes) {
176 SmallVector<int64_t> dimsIdx;
177 for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
178 if (isReductionIterator(iteratorTypes[map.getDimPosition(idx: i)]))
179 dimsIdx.push_back(Elt: i);
180 }
181 return dimsIdx;
182}
183
184/// Look for a given dimension in an affine map and return its position. Return
185/// std::nullopt if the dimension is not in the map results.
186static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
187 for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
188 if (map.getDimPosition(idx: i) == dim)
189 return i;
190 }
191 return std::nullopt;
192}
193
194/// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
195/// operands `x` and `y`.
196static Value createAdd(Location loc, Value x, Value y, bool isInt,
197 PatternRewriter &rewriter) {
198 if (isInt)
199 return rewriter.create<arith::AddIOp>(loc, x, y);
200 return rewriter.create<arith::AddFOp>(loc, x, y);
201}
202
203/// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
204/// operands `x and `y`.
205static Value createMul(Location loc, Value x, Value y, bool isInt,
206 PatternRewriter &rewriter) {
207 if (isInt)
208 return rewriter.create<arith::MulIOp>(loc, x, y);
209 return rewriter.create<arith::MulFOp>(loc, x, y);
210}
211
212namespace {
213
214/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
215/// semantics to:
216/// ```
217/// %flattened_a = vector.shape_cast %a
218/// %flattened_b = vector.shape_cast %b
219/// %flattened_d = vector.matmul %flattened_a, %flattened_b
/// %d = vector.shape_cast %flattened_d
221/// %e = add %c, %d
222/// ```
223/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
///
/// This only kicks in when VectorTransformsOptions is set to Matmul and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
    : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  /// Callback type deciding whether a given contraction op may be rewritten
  /// by this pattern.
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accept every contraction op.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToMatmulOpLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  /// Predicate restricting which contraction ops this pattern applies to.
  FilterConstraintType filter;
};
256
257/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
258/// semantics to a reduction_size-unrolled sequence:
259/// ```
260/// %at = vector.transpose %a, [1, 0]
261/// %bRow0 = vector.extract %b[0]
262/// %atRow0 = vector.extract %at[0]
263/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
264/// ...
265/// %bRowK = vector.extract %b[K]
266/// %atRowK = vector.extract %at[K]
267/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
268/// ```
269///
270/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
271/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
    : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  /// Callback type deciding whether a given contraction op may be rewritten
  /// by this pattern.
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accept every contraction op.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToOuterProductOpLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  /// Predicate restricting which contraction ops this pattern applies to.
  FilterConstraintType filter;
};
301
302/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
303/// semantics to an output-size-unrolled sequence:
304/// ```
305/// %out = arith.constant ... : vector<MxNxelt_type>
306/// %bt = vector.transpose %b, [1, 0]
/// %aRow0 = vector.extract %a[0]
/// %btRow0 = vector.extract %bt[0]
/// %c00 = vector.reduce %aRow0, %btRow0
/// %out00 = vector.insert %c00, %out[0, 0]
/// ...
/// %aRowLast = vector.extract %a[M-1]
/// %btRowLast = vector.extract %bt[N-1]
/// %cLastLast = vector.reduce %aRowLast, %btRowLast
/// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
316/// ```
317///
318/// This only kicks in when VectorTransformsOptions is set to Dot and
319/// the vector.contract op is a row-major matmul or matvec.
320class ContractionOpToDotLowering
321 : public MaskableOpRewritePattern<vector::ContractionOp> {
322public:
323 using MaskableOpRewritePattern::MaskableOpRewritePattern;
324
325 using FilterConstraintType =
326 std::function<LogicalResult(vector::ContractionOp op)>;
327
328 static LogicalResult defaultFilter(vector::ContractionOp op) {
329 return success();
330 }
331
332 ContractionOpToDotLowering(
333 vector::VectorTransformsOptions vectorTransformOptions,
334 MLIRContext *context, PatternBenefit benefit = 1,
335 const FilterConstraintType &constraint = defaultFilter)
336 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
337 vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
338
339 FailureOr<Value>
340 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
341 PatternRewriter &rewriter) const override;
342
343private:
344 /// Options to control the vector patterns.
345 vector::VectorTransformsOptions vectorTransformOptions;
346 FilterConstraintType filter;
347};
348
349/// Progressive lowering of ContractionOp.
350///
351/// One:
352/// %x = vector.contract with at least one free/batch dimension
353/// is replaced by:
354/// %a = vector.contract with one less free/batch dimension
355/// %b = vector.contract with one less free/batch dimension
356/// ..
357/// %x = combine %a %b ..
358/// until a pure contraction is reached (no free/batch dimensions),
359/// which is replaced by a dot-product.
360///
361/// This only kicks in when either VectorTransformsOptions is set
362/// to Dot or when other contraction patterns fail.
class ContractionOpLowering
    : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;
  /// Callback type deciding whether a given contraction op may be rewritten
  /// by this pattern.
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accept every contraction op.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                        MLIRContext *context, PatternBenefit benefit = 1,
                        FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorTransformOptions(vectorTransformOptions),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  /// Predicate restricting which contraction ops this pattern applies to.
  FilterConstraintType filter;
  // Lower one parallel dimension.
  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
                                 vector::ContractionOp op, int64_t lhsIndex,
                                 int64_t rhsIndex, Value mask) const;
  // Lower one reduction dimension.
  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
                                  vector::ContractionOp op, Value mask) const;
};
397
398/// Generate a vector implementation for matmat, matvec and tmatvec.
399/// This unrolls outer-products along the reduction dimension.
400struct UnrolledOuterProductGenerator
401 : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
402 UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
403 : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
404 kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
405 res(op.getAcc()), lhsType(op.getLhsType()) {
406 auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
407 if (maskableOp.isMasked())
408 mask = maskableOp.getMaskingOp().getMask();
409 }
410
411 Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
412 if (!v)
413 return v;
414 return rewriter.create<vector::TransposeOp>(loc, v, perm);
415 }
416
417 Value promote(Value v, Type dstElementType) {
418 Type elementType = v.getType();
419 auto vecType = dyn_cast<VectorType>(elementType);
420 if (vecType)
421 elementType = vecType.getElementType();
422 if (elementType == dstElementType)
423 return v;
424 Type promotedType = dstElementType;
425 if (vecType)
426 promotedType = vecType.clone(promotedType);
427 if (isa<FloatType>(dstElementType))
428 return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
429 return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
430 }
431
432 FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
433 VectorType lhsType, int reductionSize,
434 std::optional<Value> maybeMask = std::nullopt) {
435 // Incremental support for masking.
436 if (mask && !maybeMask.has_value())
437 return failure();
438
439 Type resElementType = cast<VectorType>(res.getType()).getElementType();
440 for (int64_t k = 0; k < reductionSize; ++k) {
441 Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
442 Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
443 extractA = promote(v: extractA, dstElementType: resElementType);
444 extractB = promote(v: extractB, dstElementType: resElementType);
445 Value extractMask;
446 if (maybeMask.has_value() && maybeMask.value())
447 extractMask =
448 rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);
449
450 Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
451 loc, res.getType(), extractA, extractB, res, kind);
452 res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
453 }
454 return res;
455 }
456
457 /// Helper function for `matmat`, `matvec`, `tmatvec`. Returns the size of
458 /// dimension `reductionDim`. If the dimension is a scalable dimension,
459 /// returns "nullopt".
460 std::optional<int64_t> getReductionSize(VectorType vecType,
461 int64_t reductionDim) {
462 // Cannot unroll scalable dimension.
463 if (vecType.getScalableDims()[reductionDim])
464 return std::nullopt;
465 int64_t reductionSize = vecType.getDimSize(reductionDim);
466 assert(reductionSize > 0 &&
467 "Reduction dim must be a known static size to allow unrolling");
468 return reductionSize;
469 }
470
471 /// Two outer parallel, one inner reduction (matmat flavor).
472 FailureOr<Value> matmat() {
473 if (!iters({Par(), Par(), Red()}))
474 return failure();
475 // Set up the parallel/reduction structure in the right form.
476 AffineExpr m, n, k;
477 bindDims(rewriter.getContext(), m, n, k);
478
479 // Classical row-major matmul: Just permute the lhs.
480 if (layout({{m, k}, {k, n}, {m, n}})) {
481 if (auto reductionSize = getReductionSize(lhsType, 1)) {
482 // Note: `t` creates new IR. It must be nested within this `if` check
483 // so that no IR is created when then pattern returns "failure".
484 Value tLhs = t(v: lhs);
485 Value tMask = t(v: mask, perm: {2, 0, 1});
486 return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
487 }
488 }
489 // TODO: may be better to fail and use some vector<k> -> scalar reduction.
490 if (layout({{m, k}, {n, k}, {m, n}})) {
491 if (auto reductionSize = getReductionSize(lhsType, 1)) {
492 Value tLhs = t(v: lhs);
493 Value tRhs = t(v: rhs);
494 Value tMask = t(v: mask, perm: {2, 0, 1});
495 return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
496 }
497 }
498 // No need to permute anything.
499 if (layout({{k, m}, {k, n}, {m, n}})) {
500 if (auto reductionSize = getReductionSize(lhsType, 0)) {
501 Value tMask = t(v: mask, perm: {2, 0, 1});
502 return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
503 }
504 }
505 // Just permute the rhs.
506 if (layout({{k, m}, {n, k}, {m, n}})) {
507 if (auto reductionSize = getReductionSize(lhsType, 0)) {
508 Value tRhs = t(v: rhs);
509 Value tMask = t(v: mask, perm: {2, 0, 1});
510 return outerProd(lhs, tRhs, res, lhsType, *reductionSize, tMask);
511 }
512 }
513 // Transposed output: swap RHS and LHS.
514 // Classical row-major matmul: permute the lhs.
515 if (layout({{m, k}, {k, n}, {n, m}})) {
516 if (auto reductionSize = getReductionSize(lhsType, 1)) {
517 Value tLhs = t(v: lhs);
518 Value tMask = t(v: mask, perm: {2, 0, 1});
519 return outerProd(rhs, tLhs, res, lhsType, *reductionSize, tMask);
520 }
521 }
522 // TODO: may be better to fail and use some vector<k> -> scalar reduction.
523 if (layout({{m, k}, {n, k}, {n, m}})) {
524 if (auto reductionSize = getReductionSize(lhsType, 1)) {
525 Value tRhs = t(v: rhs);
526 Value tLhs = t(v: lhs);
527 Value tMask = t(v: mask, perm: {2, 0, 1});
528 return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
529 }
530 }
531 if (layout({{k, m}, {k, n}, {n, m}})) {
532 if (auto reductionSize = getReductionSize(lhsType, 0)) {
533 Value tMask = t(v: mask, perm: {2, 0, 1});
534 return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
535 }
536 }
537 if (layout({{k, m}, {n, k}, {n, m}})) {
538 if (auto reductionSize = getReductionSize(lhsType, 0)) {
539 Value tRhs = t(v: rhs);
540 Value tMask = t(v: mask, perm: {2, 0, 1});
541 return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
542 }
543 }
544 return failure();
545 }
546
547 //
548 // One outer parallel, one inner reduction (matvec flavor).
549 // Mask needs to be transposed everywhere to turn the reduction dimension
550 // outermost as required by outerproduct.
551 //
552 FailureOr<Value> matvec() {
553 if (!iters({Par(), Red()}))
554 return failure();
555 AffineExpr m, k;
556 bindDims(rewriter.getContext(), m, k);
557
558 // Case mat-vec: transpose.
559 if (layout({{m, k}, {k}, {m}})) {
560 if (auto reductionSize = getReductionSize(lhsType, 1)) {
561 Value tLhs = t(v: lhs);
562 Value tMask = t(v: mask);
563 return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
564 }
565 }
566 // Case mat-trans-vec: ready to go.
567 if (layout({{k, m}, {k}, {m}})) {
568 if (auto reductionSize = getReductionSize(lhsType, 0)) {
569 Value tMask = t(v: mask);
570 return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
571 }
572 }
573 // Case vec-mat: swap and transpose.
574 if (layout({{k}, {m, k}, {m}})) {
575 if (auto reductionSize = getReductionSize(lhsType, 0)) {
576 Value tRhs = t(v: rhs);
577 Value tMask = t(v: mask);
578 return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
579 }
580 }
581 // Case vec-mat-trans: swap and ready to go.
582 if (layout({{k}, {k, m}, {m}})) {
583 if (auto reductionSize = getReductionSize(lhsType, 0)) {
584 Value tMask = t(v: mask);
585 return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
586 }
587 }
588 return failure();
589 }
590
591 //
592 // One outer reduction, one inner parallel (tmatvec flavor).
593 // Mask already has the shape of the outer product.
594 //
595 FailureOr<Value> tmatvec() {
596 if (!iters({Red(), Par()}))
597 return failure();
598 AffineExpr k, m;
599 bindDims(rewriter.getContext(), k, m);
600
601 // Case mat-vec: transpose.
602 if (layout({{m, k}, {k}, {m}}))
603 if (auto reductionSize = getReductionSize(lhsType, 1))
604 return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
605 // Case mat-trans-vec: ready to go.
606 if (layout({{k, m}, {k}, {m}}))
607 if (auto reductionSize = getReductionSize(lhsType, 0))
608 return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
609 // Case vec-mat: swap and transpose.
610 if (layout({{k}, {m, k}, {m}}))
611 if (auto reductionSize = getReductionSize(lhsType, 0))
612 return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
613 // Case vec-mat-trans: swap and ready to go.
614 if (layout({{k}, {k, m}, {m}}))
615 if (auto reductionSize = getReductionSize(lhsType, 0))
616 return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
617 return failure();
618 }
619
620private:
621 vector::CombiningKind kind;
622 Value lhs, rhs, res, mask;
623 VectorType lhsType;
624};
625
626/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
627/// semantics to a reduction_size-unrolled sequence:
628/// ```
629/// %at = vector.transpose %a, [1, 0]
630/// %bRow0 = vector.extract %b[0]
631/// %atRow0 = vector.extract %at[0]
632/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
633/// ...
634/// %bRowK = vector.extract %b[K]
635/// %atRowK = vector.extract %at[K]
636/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
637/// ```
638///
639/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
640/// otherwise supports any layout permutation of the matrix-multiply.
641FailureOr<Value>
642ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
643 vector::ContractionOp op, MaskingOpInterface maskOp,
644 PatternRewriter &rewriter) const {
645 if (vectorTransformOptions.vectorContractLowering !=
646 vector::VectorContractLowering::OuterProduct)
647 return failure();
648
649 if (failed(filter(op)))
650 return failure();
651
652 UnrolledOuterProductGenerator e(rewriter, op);
653 FailureOr<Value> matmatRes = e.matmat();
654 if (succeeded(result: matmatRes)) {
655 return matmatRes;
656 }
657 FailureOr<Value> matvecRes = e.matvec();
658 if (succeeded(result: matvecRes)) {
659 return matvecRes;
660 }
661
662 FailureOr<Value> tmatvecRes = e.tmatvec();
663 return tmatvecRes;
664}
665
666FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
667 vector::ContractionOp op, MaskingOpInterface maskOp,
668 PatternRewriter &rewriter) const {
669 // TODO: Support vector.mask.
670 if (maskOp)
671 return failure();
672
673 if (failed(filter(op)))
674 return failure();
675
676 if (vectorTransformOptions.vectorContractLowering !=
677 vector::VectorContractLowering::Dot)
678 return failure();
679
680 auto iteratorTypes = op.getIteratorTypes().getValue();
681 static constexpr std::array<int64_t, 2> perm = {1, 0};
682 Location loc = op.getLoc();
683 Value lhs = op.getLhs(), rhs = op.getRhs();
684
685 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
686 auto infer = [&](MapList m) {
687 return AffineMap::inferFromExprList(m, op.getContext());
688 };
689 AffineExpr m, n, k;
690 bindDims(ctx: rewriter.getContext(), exprs&: m, exprs&: n, exprs&: k);
691 SmallVector<AffineMap> maps = op.getIndexingMapsArray();
692 //
693 // In the following we wish to make the reduction dimension innermost so we
694 // can load vectors and just fmul + reduce into a scalar.
695 //
696 if (isParallelIterator(iteratorTypes[0]) &&
697 isParallelIterator(iteratorTypes[1]) &&
698 isReductionIterator(iteratorTypes[2])) {
699 //
700 // Two outer parallel, one inner reduction (matmat flavor).
701 //
702 if (maps == infer({{m, k}, {k, n}, {m, n}})) {
703 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
704 } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
705 // No need to permute anything.
706 } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
707 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
708 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
709 } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
710 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
711 } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
712 // This is the classical row-major matmul. Just permute the lhs.
713 Value tmp = lhs;
714 lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
715 rhs = tmp;
716 } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
717 std::swap(a&: lhs, b&: rhs);
718 } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
719 Value tmp = lhs;
720 lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
721 rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
722 } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
723 Value tmp = rhs;
724 rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
725 lhs = tmp;
726 } else {
727 return failure();
728 }
729 } else if (isParallelIterator(iteratorTypes[0]) &&
730 isReductionIterator(iteratorTypes[1])) {
731 //
732 // One outer parallel, one inner reduction (matvec flavor)
733 //
734 if (maps == infer({{m, n}, {n}, {m}})) {
735 // No need to permute anything.
736 } else if (maps == infer({{n, m}, {n}, {m}})) {
737 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
738 } else if (maps == infer({{n}, {m, n}, {m}})) {
739 std::swap(a&: lhs, b&: rhs);
740 } else if (maps == infer({{n}, {n, m}, {m}})) {
741 std::swap(a&: lhs, b&: rhs);
742 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
743 } else {
744 return failure();
745 }
746 } else {
747 return failure();
748 }
749
750 VectorType dstType = cast<VectorType>(op.getResultType());
751 assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
752 "Expected dst type of rank 1 or 2");
753
754 unsigned rank = dstType.getRank();
755 unsigned dstRows = dstType.getShape()[0];
756 unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
757
758 // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
759 Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
760 rewriter.getZeroAttr(dstType));
761 bool isInt = isa<IntegerType>(dstType.getElementType());
762 for (unsigned r = 0; r < dstRows; ++r) {
763 Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
764 for (unsigned c = 0; c < dstColumns; ++c) {
765 Value b = rank == 1
766 ? rhs
767 : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
768 Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
769 Value reduced = rewriter.create<vector::ReductionOp>(
770 op.getLoc(), vector::CombiningKind::ADD, m);
771
772 SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
773 : SmallVector<int64_t, 2>{r, c};
774 res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
775 }
776 }
777 if (auto acc = op.getAcc())
778 res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
779 return res;
780}
781
782/// Lower vector.contract with all size one reduction dimensions to
783/// elementwise ops when possible.
784struct ContractOpToElementwise
785 : public MaskableOpRewritePattern<vector::ContractionOp> {
786 using MaskableOpRewritePattern::MaskableOpRewritePattern;
787 using FilterConstraintType =
788 std::function<LogicalResult(vector::ContractionOp op)>;
789 static LogicalResult defaultFilter(vector::ContractionOp op) {
790 return success();
791 }
792 ContractOpToElementwise(
793 vector::VectorTransformsOptions vectorTransformOptions,
794 MLIRContext *context, PatternBenefit benefit = 1,
795 const FilterConstraintType &constraint = defaultFilter)
796 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
797 vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
798
799 FailureOr<Value>
800 matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
801 MaskingOpInterface maskOp,
802 PatternRewriter &rewriter) const override {
803 // TODO: Support vector.mask.
804 if (maskOp)
805 return failure();
806
807 if (failed(filter(contractOp)))
808 return failure();
809
810 if (vectorTransformOptions.vectorContractLowering !=
811 vector::VectorContractLowering::ParallelArith)
812 return failure();
813
814 ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
815 ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
816 AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
817 AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
818 SmallVector<int64_t> lhsReductionDims =
819 getReductionIndex(lhsMap, contractOp.getIteratorTypes());
820 SmallVector<int64_t> rhsReductionDims =
821 getReductionIndex(rhsMap, contractOp.getIteratorTypes());
822 // All the reduction dimensions must be a size 1.
823 for (int64_t dim : lhsReductionDims) {
824 if (lhsShape[dim] != 1)
825 return failure();
826 }
827 for (int64_t dim : rhsReductionDims) {
828 if (rhsShape[dim] != 1)
829 return failure();
830 }
831 AffineMap accMap = contractOp.getIndexingMapsArray()[2];
832 unsigned numParallelDims = accMap.getNumResults();
833 unsigned numLhsDimToBroadcast =
834 numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
835 unsigned numRhsDimToBroadcast =
836 numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
837 SmallVector<int64_t> lhsDims;
838 SmallVector<int64_t> lhsTranspose;
839 SmallVector<int64_t> rhsDims;
840 SmallVector<int64_t> rhsTranspose;
841 for (int64_t dim : lhsReductionDims)
842 lhsTranspose.push_back(numLhsDimToBroadcast + dim);
843 for (int64_t dim : rhsReductionDims)
844 rhsTranspose.push_back(numRhsDimToBroadcast + dim);
845 // Loop through the parallel dimensions to calculate the dimensions to
846 // broadcast and to permute in order to extract only parallel dimensions.
847 for (unsigned i = 0; i < numParallelDims; i++) {
848 std::optional<unsigned> lhsDim =
849 getDimPosition(map: lhsMap, dim: accMap.getDimPosition(idx: i));
850 if (lhsDim) {
851 lhsTranspose.push_back(Elt: numLhsDimToBroadcast + *lhsDim);
852 } else {
853 // If the parallel dimension doesn't exist we will have to broadcast it.
854 lhsDims.push_back(
855 Elt: cast<VectorType>(contractOp.getResultType()).getDimSize(i));
856 lhsTranspose.push_back(Elt: lhsDims.size() - 1);
857 }
858 std::optional<unsigned> rhsDim =
859 getDimPosition(map: rhsMap, dim: accMap.getDimPosition(idx: i));
860 if (rhsDim) {
861 rhsTranspose.push_back(Elt: numRhsDimToBroadcast + *rhsDim);
862 } else {
863 // If the parallel dimension doesn't exist we will have to broadcast it.
864 rhsDims.push_back(
865 Elt: cast<VectorType>(contractOp.getResultType()).getDimSize(i));
866 rhsTranspose.push_back(Elt: rhsDims.size() - 1);
867 }
868 }
869 Value newLhs = contractOp.getLhs();
870 Value newRhs = contractOp.getRhs();
871 Location loc = contractOp.getLoc();
872 if (!lhsDims.empty()) {
873 lhsDims.append(in_start: lhsShape.begin(), in_end: lhsShape.end());
874 auto expandedType =
875 VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
876 newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
877 }
878 if (!rhsDims.empty()) {
879 rhsDims.append(in_start: rhsShape.begin(), in_end: rhsShape.end());
880 auto expandedType =
881 VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
882 newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
883 }
884 bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
885 newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
886 newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
887 SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
888 SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
889 newLhs = rewriter.create<vector::ExtractOp>(loc, newLhs, lhsOffsets);
890 newRhs = rewriter.create<vector::ExtractOp>(loc, newRhs, rhsOffsets);
891 std::optional<Value> result =
892 createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
893 contractOp.getKind(), rewriter, isInt);
894 if (result)
895 return *result;
896
897 return failure();
898 }
899
900private:
901 /// Options to control the vector patterns.
902 vector::VectorTransformsOptions vectorTransformOptions;
903 FilterConstraintType filter;
904};
905
906/// Progressive lowering of ContractionOp.
907/// One:
908/// %x = vector.contract with at least one free/batch dimension
909/// is replaced by:
910/// %a = vector.contract with one less free/batch dimension
911/// %b = vector.contract with one less free/batch dimension
912/// ..
913/// %x = combine %a %b ..
914/// until a pure contraction is reached (no free/batch dimensions),
915/// which is replaced by a dot-product.
916///
917/// This only kicks in when either VectorTransformsOptions is set
918/// to DOT or when other contraction patterns fail.
919//
920// TODO: break down into transpose/reshape/cast ops
921// when they become available to avoid code dup
922// TODO: investigate lowering order impact on performance
923FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
924 vector::ContractionOp op, MaskingOpInterface maskOp,
925 PatternRewriter &rewriter) const {
926 if (failed(filter(op)))
927 return failure();
928
929 // TODO: support mixed mode contract lowering.
930 if (op.getLhsType().getElementType() !=
931 getElementTypeOrSelf(op.getAccType()) ||
932 op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
933 return failure();
934
935 // TODO: the code below assumes the default contraction, make sure it supports
936 // other kinds before enabling this lowering.
937 if (op.getKind() != vector::CombiningKind::ADD) {
938 return rewriter.notifyMatchFailure(
939 op, "contractions other than 'add' not supported");
940 }
941
942 // TODO: implement benefits, cost models.
943 MLIRContext *ctx = op.getContext();
944
945 ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
946 FailureOr<Value> newVal1 =
947 pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
948 if (!failed(result: newVal1))
949 return newVal1;
950
951 ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
952 FailureOr<Value> newVal2 =
953 pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
954 if (!failed(result: newVal2))
955 return newVal2;
956
957 ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
958 FailureOr<Value> newVal3 =
959 pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
960 if (!failed(result: newVal3))
961 return newVal3;
962
963 ContractOpToElementwise pat4(vectorTransformOptions, ctx);
964 FailureOr<Value> newVal4 =
965 pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
966 if (!failed(result: newVal4))
967 return newVal4;
968
969 // Vector mask setup.
970
971 Value mask;
972 if (maskOp)
973 mask = maskOp.getMask();
974 // Find first batch dimension in LHS/RHS, and lower when found.
975 std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
976 if (!batchDimMap.empty()) {
977 int64_t lhsIndex = batchDimMap[0].first;
978 int64_t rhsIndex = batchDimMap[0].second;
979 auto newOp = lowerParallel(rewriter, op: op, lhsIndex, rhsIndex, mask);
980 if (failed(newOp))
981 return failure();
982 return newOp;
983 }
984
985 // Collect contracting dimensions.
986 std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
987 op.getContractingDimMap();
988 DenseSet<int64_t> lhsContractingDimSet;
989 DenseSet<int64_t> rhsContractingDimSet;
990 for (auto &dimPair : contractingDimMap) {
991 lhsContractingDimSet.insert(dimPair.first);
992 rhsContractingDimSet.insert(dimPair.second);
993 }
994
995 // Find first free dimension in LHS, and lower when found.
996 VectorType lhsType = op.getLhsType();
997 for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
998 if (lhsContractingDimSet.count(V: lhsIndex) == 0) {
999 auto newOp = lowerParallel(rewriter, op: op, lhsIndex, /*rhsIndex=*/-1, mask);
1000 if (failed(newOp))
1001 return failure();
1002 return newOp;
1003 }
1004 }
1005
1006 // Find first free dimension in RHS, and lower when found.
1007 VectorType rhsType = op.getRhsType();
1008 for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
1009 if (rhsContractingDimSet.count(V: rhsIndex) == 0) {
1010 auto newOp = lowerParallel(rewriter, op: op, /*lhsIndex=*/-1, rhsIndex, mask);
1011 if (failed(newOp))
1012 return failure();
1013 return newOp;
1014 }
1015 }
1016
1017 // Lower the first remaining reduction dimension.
1018 if (!contractingDimMap.empty()) {
1019 auto newOp = lowerReduction(rewriter, op: op, mask);
1020 if (failed(newOp))
1021 return failure();
1022 return newOp;
1023 }
1024
1025 return failure();
1026}
1027
1028// Lower one parallel dimension.
1029// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
1030// TODO: consider reusing existing contract unrolling
1031FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
1032 vector::ContractionOp op,
1033 int64_t lhsIndex,
1034 int64_t rhsIndex,
1035 Value mask) const {
1036 VectorType lhsType = op.getLhsType();
1037 VectorType rhsType = op.getRhsType();
1038 VectorType resType = cast<VectorType>(op.getResultType());
1039 // Find the iterator type index and result index.
1040 SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
1041 int64_t iterIndex = -1;
1042 int64_t dimSize = -1;
1043 if (lhsIndex >= 0) {
1044 iterIndex = iMap[0].getDimPosition(idx: lhsIndex);
1045 if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(idx: rhsIndex))
1046 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1047 diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
1048 << " to map to the same dimension";
1049 });
1050 if (lhsType.getScalableDims()[lhsIndex])
1051 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1052 diag << "Unrolling scalable dimension (lhsIndex=" << lhsIndex
1053 << ") is not supported yet";
1054 });
1055 dimSize = lhsType.getDimSize(lhsIndex);
1056 } else if (rhsIndex >= 0) {
1057 iterIndex = iMap[1].getDimPosition(idx: rhsIndex);
1058 if (rhsType.getScalableDims()[rhsIndex])
1059 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1060 diag << "Unrolling scalable dimension (rhsIndex=" << rhsIndex
1061 << ") is not supported yet";
1062 });
1063 dimSize = rhsType.getDimSize(rhsIndex);
1064 }
1065 if (iterIndex < 0)
1066 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1067 diag << "expected either lhsIndex=" << lhsIndex
1068 << " or rhsIndex=" << rhsIndex << " to be nonnegative";
1069 });
1070 // value_or(-1) means that we tolerate a dimension not appearing
1071 // in the result map. That can't happen for actual parallel iterators, but
1072 // the caller ContractionOpLowering::matchAndRewrite is currently calling
1073 // lowerParallel also for the case of unit-size reduction dims appearing only
1074 // on one of LHS or RHS, not both. At the moment, such cases are created by
1075 // CastAwayContractionLeadingOneDim, so we need to either support that or
1076 // modify that pattern.
1077 int64_t resIndex = getResultIndex(map: iMap[2], index: iterIndex).value_or(u: -1);
1078 if (resIndex == -1 && dimSize != 1)
1079 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1080 diag << "expected the dimension for iterIndex=" << iterIndex
1081 << " to either appear in the result map, or to be a unit dimension";
1082 });
1083
1084 // Construct new iterator types and affine map array attribute.
1085 std::array<AffineMap, 3> lowIndexingMaps = {
1086 adjustMap(map: iMap[0], index: iterIndex, rewriter),
1087 adjustMap(map: iMap[1], index: iterIndex, rewriter),
1088 adjustMap(map: iMap[2], index: iterIndex, rewriter)};
1089 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1090 auto lowIter =
1091 rewriter.getArrayAttr(value: adjustIter(op.getIteratorTypes(), iterIndex));
1092 // Unroll into a series of lower dimensional vector.contract ops.
1093 Location loc = op.getLoc();
1094 Value result = rewriter.create<arith::ConstantOp>(
1095 loc, resType, rewriter.getZeroAttr(resType));
1096
1097 for (int64_t d = 0; d < dimSize; ++d) {
1098 auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1099 auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1100 auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
1101
1102 Value lowMask;
1103 if (mask)
1104 lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
1105 iterIndex, d, rewriter);
1106
1107 Operation *lowContract = rewriter.create<vector::ContractionOp>(
1108 loc, lhs, rhs, acc, lowAffine, lowIter);
1109 lowContract = maskOperation(builder&: rewriter, maskableOp: lowContract, mask: lowMask);
1110 result = reshapeStore(loc, lowContract->getResult(idx: 0), result, resType,
1111 resIndex, d, rewriter);
1112 }
1113 return result;
1114}
1115
1116// Lower one reduction dimension.
1117FailureOr<Value> ContractionOpLowering::lowerReduction(
1118 PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
1119 auto loc = op.getLoc();
1120 VectorType lhsType = op.getLhsType();
1121 VectorType rhsType = op.getRhsType();
1122 Type resType = op.getResultType();
1123 if (isa<VectorType>(Val: resType))
1124 return rewriter.notifyMatchFailure(op,
1125 "did not expect a VectorType result");
1126 bool isInt = isa<IntegerType>(Val: resType);
1127 // Use iterator index 0.
1128 int64_t iterIndex = 0;
1129 SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
1130 std::optional<int64_t> lookupLhs = getResultIndex(map: iMap[0], index: iterIndex);
1131 std::optional<int64_t> lookupRhs = getResultIndex(map: iMap[1], index: iterIndex);
1132 if (!lookupLhs.has_value())
1133 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1134 diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension";
1135 });
1136 if (!lookupRhs.has_value())
1137 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1138 diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension";
1139 });
1140 int64_t lhsIndex = *lookupLhs;
1141 int64_t rhsIndex = *lookupRhs;
1142 int64_t dimSize = lhsType.getDimSize(lhsIndex);
1143 if (dimSize != rhsType.getDimSize(rhsIndex))
1144 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1145 diag << "expect LHS dimension " << lhsIndex
1146 << " to have the same size as RHS dimension " << rhsIndex;
1147 });
1148 // Base case.
1149 if (lhsType.getRank() == 1) {
1150 if (rhsType.getRank() != 1)
1151 return rewriter.notifyMatchFailure(
1152 op, "When LHS has rank 1, expected also RHS to have rank 1");
1153 Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
1154 auto kind = vector::CombiningKind::ADD;
1155
1156 Value acc = op.getAcc();
1157 Operation *reductionOp =
1158 acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
1159 : rewriter.create<vector::ReductionOp>(loc, kind, m);
1160 return maskOperation(builder&: rewriter, maskableOp: reductionOp, mask)->getResult(idx: 0);
1161 }
1162 // Construct new iterator types and affine map array attribute.
1163 std::array<AffineMap, 3> lowIndexingMaps = {
1164 adjustMap(map: iMap[0], index: iterIndex, rewriter),
1165 adjustMap(map: iMap[1], index: iterIndex, rewriter),
1166 adjustMap(map: iMap[2], index: iterIndex, rewriter)};
1167 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1168 auto lowIter =
1169 rewriter.getArrayAttr(value: adjustIter(op.getIteratorTypes(), iterIndex));
1170 // Unroll into a series of lower dimensional vector.contract ops.
1171 // By feeding the initial accumulator into the first contraction,
1172 // and the result of each contraction into the next, eventually
1173 // the sum of all reductions is computed.
1174 Value result = op.getAcc();
1175 for (int64_t d = 0; d < dimSize; ++d) {
1176 auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1177 auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1178 Value newMask;
1179 if (mask)
1180 newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
1181 iterIndex, d, rewriter);
1182
1183 Operation *newContract = rewriter.create<vector::ContractionOp>(
1184 loc, lhs, rhs, result, lowAffine, lowIter);
1185 result = maskOperation(builder&: rewriter, maskableOp: newContract, mask: newMask)->getResult(idx: 0);
1186 }
1187 return result;
1188}
1189
1190/// Progressive lowering of OuterProductOp.
1191/// One:
1192/// %x = vector.outerproduct %lhs, %rhs, %acc
1193/// is replaced by:
1194/// %z = zero-result
1195/// %0 = vector.extract %lhs[0]
1196/// %1 = vector.broadcast %0
1197/// %2 = vector.extract %acc[0]
1198/// %3 = vector.fma %1, %rhs, %2
1199/// %4 = vector.insert %3, %z[0]
1200/// ..
1201/// %x = vector.insert %.., %..[N-1]
1202///
1203class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
1204public:
1205 using OpRewritePattern::OpRewritePattern;
1206
1207 LogicalResult matchAndRewrite(vector::OuterProductOp op,
1208 PatternRewriter &rewriter) const override {
1209 VectorType resType = op.getResultVectorType();
1210 if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
1211 return failure();
1212
1213 auto loc = op.getLoc();
1214
1215 VectorType lhsType = op.getOperandVectorTypeLHS();
1216 VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
1217 Type eltType = resType.getElementType();
1218 bool isInt = isa<IntegerType, IndexType>(Val: eltType);
1219 Value acc = op.getAcc();
1220 vector::CombiningKind kind = op.getKind();
1221
1222 // Vector mask setup.
1223 OpBuilder::InsertionGuard guard(rewriter);
1224 auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
1225 Operation *rootOp;
1226 Value mask;
1227 if (maskableOp.isMasked()) {
1228 rewriter.setInsertionPoint(maskableOp.getMaskingOp());
1229 rootOp = maskableOp.getMaskingOp();
1230 mask = maskableOp.getMaskingOp().getMask();
1231 } else {
1232 rootOp = op;
1233 }
1234
1235 if (!rhsType) {
1236 // Special case: AXPY operation.
1237 Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
1238 std::optional<Value> mult = createContractArithOp(
1239 loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
1240 if (!mult.has_value())
1241 return failure();
1242 rewriter.replaceOp(op: rootOp, newValues: *mult);
1243 return success();
1244 }
1245
1246 Value result = rewriter.create<arith::ConstantOp>(
1247 loc, resType, rewriter.getZeroAttr(resType));
1248 for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
1249 Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), d);
1250 Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
1251 Value r = nullptr;
1252 if (acc)
1253 r = rewriter.create<vector::ExtractOp>(loc, acc, d);
1254 Value extrMask;
1255 if (mask)
1256 extrMask = rewriter.create<vector::ExtractOp>(loc, mask, d);
1257
1258 std::optional<Value> m = createContractArithOp(
1259 loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
1260 if (!m.has_value())
1261 return failure();
1262 result = rewriter.create<vector::InsertOp>(loc, *m, result, d);
1263 }
1264
1265 rewriter.replaceOp(op: rootOp, newValues: result);
1266 return success();
1267 }
1268};
1269
1270/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1271/// semantics to:
1272/// ```
1273/// %mta = maybe_transpose
1274/// %mtb = maybe_transpose
1275/// %flattened_a = vector.shape_cast %mta
1276/// %flattened_b = vector.shape_cast %mtb
1277/// %flattened_d = vector.matmul %flattened_a, %flattened_b
1278/// %mtd = vector.shape_cast %flattened_d
1279/// %d = maybe_untranspose %mtd
1280/// %e = add %c, %d
1281/// ```
1282/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
1283//
1284/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
1285/// vector.transpose operations are inserted if the vector.contract op is not a
1286/// row-major matrix multiply.
1287FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
1288 vector::ContractionOp op, MaskingOpInterface maskOp,
1289 PatternRewriter &rew) const {
1290 // TODO: Support vector.mask.
1291 if (maskOp)
1292 return failure();
1293
1294 if (vectorTransformOptions.vectorContractLowering !=
1295 vector::VectorContractLowering::Matmul)
1296 return failure();
1297 if (failed(filter(op)))
1298 return failure();
1299
1300 auto iteratorTypes = op.getIteratorTypes().getValue();
1301 if (!isParallelIterator(iteratorTypes[0]) ||
1302 !isParallelIterator(iteratorTypes[1]) ||
1303 !isReductionIterator(iteratorTypes[2]))
1304 return failure();
1305
1306 Type elementType = op.getLhsType().getElementType();
1307 if (!elementType.isIntOrFloat())
1308 return failure();
1309
1310 Type dstElementType = op.getType();
1311 if (auto vecType = dyn_cast<VectorType>(dstElementType))
1312 dstElementType = vecType.getElementType();
1313 if (elementType != dstElementType)
1314 return failure();
1315
1316 // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
1317 // Bail out if the contraction cannot be put in this form.
1318 MLIRContext *ctx = op.getContext();
1319 Location loc = op.getLoc();
1320 AffineExpr m, n, k;
1321 bindDims(ctx: rew.getContext(), exprs&: m, exprs&: n, exprs&: k);
1322 // LHS must be A(m, k) or A(k, m).
1323 Value lhs = op.getLhs();
1324 auto lhsMap = op.getIndexingMapsArray()[0];
1325 if (lhsMap == AffineMap::get(dimCount: 3, symbolCount: 0, results: {k, m}, context: ctx))
1326 lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
1327 else if (lhsMap != AffineMap::get(dimCount: 3, symbolCount: 0, results: {m, k}, context: ctx))
1328 return failure();
1329
1330 // RHS must be B(k, n) or B(n, k).
1331 Value rhs = op.getRhs();
1332 auto rhsMap = op.getIndexingMapsArray()[1];
1333 if (rhsMap == AffineMap::get(dimCount: 3, symbolCount: 0, results: {n, k}, context: ctx))
1334 rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
1335 else if (rhsMap != AffineMap::get(dimCount: 3, symbolCount: 0, results: {k, n}, context: ctx))
1336 return failure();
1337
1338 // At this point lhs and rhs are in row-major.
1339 VectorType lhsType = cast<VectorType>(lhs.getType());
1340 VectorType rhsType = cast<VectorType>(rhs.getType());
1341 int64_t lhsRows = lhsType.getDimSize(0);
1342 int64_t lhsColumns = lhsType.getDimSize(1);
1343 int64_t rhsColumns = rhsType.getDimSize(1);
1344
1345 Type flattenedLHSType =
1346 VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
1347 lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
1348
1349 Type flattenedRHSType =
1350 VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
1351 rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
1352
1353 Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
1354 rhsColumns);
1355 mul = rew.create<vector::ShapeCastOp>(
1356 loc,
1357 VectorType::get({lhsRows, rhsColumns},
1358 getElementTypeOrSelf(op.getAcc().getType())),
1359 mul);
1360
1361 // ACC must be C(m, n) or C(n, m).
1362 auto accMap = op.getIndexingMapsArray()[2];
1363 if (accMap == AffineMap::get(dimCount: 3, symbolCount: 0, results: {n, m}, context: ctx))
1364 mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
1365 else if (accMap != AffineMap::get(dimCount: 3, symbolCount: 0, results: {m, n}, context: ctx))
1366 llvm_unreachable("invalid contraction semantics");
1367
1368 Value res =
1369 isa<IntegerType>(elementType)
1370 ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
1371 : static_cast<Value>(
1372 rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
1373
1374 return res;
1375}
1376} // namespace
1377
1378void mlir::vector::populateVectorContractLoweringPatterns(
1379 RewritePatternSet &patterns, VectorTransformsOptions options,
1380 PatternBenefit benefit, bool disableOuterProductLowering) {
1381 if (!disableOuterProductLowering)
1382 patterns.add<OuterProductOpLowering>(arg: patterns.getContext(), args&: benefit);
1383 patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
1384 ContractionOpToOuterProductOpLowering>(
1385 arg&: options, args: patterns.getContext(), args&: benefit);
1386}
1387
1388void mlir::vector::populateVectorOuterProductLoweringPatterns(
1389 RewritePatternSet &patterns, PatternBenefit benefit) {
1390 patterns.add<OuterProductOpLowering>(arg: patterns.getContext(), args&: benefit);
1391}
1392

source code of mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp