1//===- LowerVectorContract.cpp - Lower 'vector.contract' operation --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent rewrites and utilities to lower the
10// 'vector.contract' operation.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/MemRef/IR/MemRef.h"
16#include "mlir/Dialect/Utils/IndexingUtils.h"
17#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
18#include "mlir/Dialect/Vector/IR/VectorOps.h"
19#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
20#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
21#include "mlir/IR/BuiltinTypes.h"
22#include "mlir/IR/Location.h"
23#include "mlir/IR/PatternMatch.h"
24#include "mlir/IR/TypeUtilities.h"
25
26#define DEBUG_TYPE "vector-contract-lowering"
27
28using namespace mlir;
29using namespace mlir::vector;
30
31//===----------------------------------------------------------------------===//
32// Helper functions
33//===----------------------------------------------------------------------===//
34// Helper to find an index in an affine map.
35static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
36 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
37 int64_t idx = map.getDimPosition(idx: i);
38 if (idx == index)
39 return i;
40 }
41 return std::nullopt;
42}
43
44// Helper to construct iterator types with one index removed.
45static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes,
46 int64_t index) {
47 SmallVector<Attribute> results;
48 for (const auto &it : llvm::enumerate(First&: iteratorTypes)) {
49 int64_t idx = it.index();
50 if (idx == index)
51 continue;
52 results.push_back(Elt: it.value());
53 }
54 return results;
55}
56
57// Helper to construct an affine map with one index removed.
58static AffineMap adjustMap(AffineMap map, int64_t index,
59 PatternRewriter &rewriter) {
60 auto *ctx = rewriter.getContext();
61 SmallVector<AffineExpr> results;
62 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
63 int64_t idx = map.getDimPosition(idx: i);
64 if (idx == index)
65 continue;
66 // Re-insert remaining indices, but renamed when occurring
67 // after the removed index.
68 auto targetExpr = getAffineDimExpr(position: idx < index ? idx : idx - 1, context: ctx);
69 results.push_back(Elt: targetExpr);
70 }
71 return AffineMap::get(dimCount: map.getNumDims() - 1, symbolCount: 0, results, context: ctx);
72}
73
// Helper method to possibly drop a dimension in a load.
// Extracts position `pos` along dimension `index` of `val` by recursively
// unrolling all dimensions in front of `index` and extracting from the inner
// vectors. `index == -1` means "nothing to drop" and returns `val` unchanged.
// `type` must be the vector type of `val`.
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;

  // At extraction dimension?
  if (index == 0)
    return rewriter.create<vector::ExtractOp>(location: loc, args&: val, args&: pos);

  // Unroll leading dimensions: peel dim 0, recurse on each sub-vector, and
  // reassemble the results into a vector with dimension `index` dropped.
  VectorType vType = VectorType::Builder(type).dropDim(pos: 0);
  VectorType resType = VectorType::Builder(type).dropDim(pos: index);
  Value result = rewriter.create<arith::ConstantOp>(
      location: loc, args&: resType, args: rewriter.getZeroAttr(type: resType));
  for (int64_t d = 0, e = resType.getDimSize(idx: 0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(location: loc, args&: val, args&: d);
    Value load = reshapeLoad(loc, val: ext, type: vType, index: index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(location: loc, args&: load, args&: result, args&: d);
  }
  return result;
}
98
// Helper method to possibly drop a dimension in a store.
// Inverse of reshapeLoad: inserts `val` at position `pos` along dimension
// `index` of `result` by recursively unrolling the leading dimensions.
// `index == -1` means "nothing to drop" and returns `val` unchanged.
// `type` must be the vector type of `result`.
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0)
    return rewriter.create<vector::InsertOp>(location: loc, args&: val, args&: result, args&: pos);

  // Unroll leading dimensions: recurse on each (sub-result, sub-val) pair and
  // write the updated sub-result back into `result`.
  VectorType vType = VectorType::Builder(type).dropDim(pos: 0);
  for (int64_t d = 0, e = type.getDimSize(idx: 0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(location: loc, args&: result, args&: d);
    Value ins = rewriter.create<vector::ExtractOp>(location: loc, args&: val, args&: d);
    Value sto = reshapeStore(loc, val: ins, result: ext, type: vType, index: index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(location: loc, args&: sto, args&: result, args&: d);
  }
  return result;
}
121
/// Helper to create arithmetic operation associated with a kind of contraction.
/// Computes the product of `x` and `y` (integer or float flavor per `isInt`)
/// and combines it with `acc` using combining kind `kind`. Returns
/// std::nullopt when `kind` is incompatible with the element type (float-only
/// kinds on integers and vice versa). When `acc` is null, just the product is
/// returned. An optional `mask` keeps the previous `acc` value on masked-out
/// lanes.
static std::optional<Value>
createContractArithOp(Location loc, Value x, Value y, Value acc,
                      vector::CombiningKind kind, PatternRewriter &rewriter,
                      bool isInt, Value mask = Value()) {
  using vector::CombiningKind;
  Value mul;

  if (isInt) {
    if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
        kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
      // Only valid for floating point types.
      return std::nullopt;
    mul = rewriter.create<arith::MulIOp>(location: loc, args&: x, args&: y);
  } else {
    // Float case.
    if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
        kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
        kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
        kind == CombiningKind::XOR)
      // Only valid for integer types.
      return std::nullopt;
    // Special case for fused multiply-add.
    if (acc && isa<VectorType>(Val: acc.getType()) && kind == CombiningKind::ADD) {
      Value fma = rewriter.create<vector::FMAOp>(location: loc, args&: x, args&: y, args&: acc);
      if (mask)
        // The fma op doesn't need explicit masking. However, fma ops used in
        // reductions must preserve previous 'acc' values for masked-out lanes.
        fma = selectPassthru(builder&: rewriter, mask, newValue: fma, passthru: acc);
      return fma;
    }
    mul = rewriter.create<arith::MulFOp>(location: loc, args&: x, args&: y);
  }

  // No accumulator: the bare product is the result.
  if (!acc)
    return std::optional<Value>(mul);

  return makeArithReduction(b&: rewriter, loc, kind, v1: mul, acc,
                            /*fastmath=*/nullptr, mask);
}
162
163/// Return the positions of the reductions in the given map.
164static SmallVector<int64_t> getReductionIndex(AffineMap map,
165 ArrayAttr iteratorTypes) {
166 SmallVector<int64_t> dimsIdx;
167 for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
168 if (isReductionIterator(attr: iteratorTypes[map.getDimPosition(idx: i)]))
169 dimsIdx.push_back(Elt: i);
170 }
171 return dimsIdx;
172}
173
174/// Look for a given dimension in an affine map and return its position. Return
175/// std::nullopt if the dimension is not in the map results.
176static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
177 for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
178 if (map.getDimPosition(idx: i) == dim)
179 return i;
180 }
181 return std::nullopt;
182}
183
184/// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
185/// operands `x` and `y`.
186static Value createAdd(Location loc, Value x, Value y, bool isInt,
187 PatternRewriter &rewriter) {
188 if (isInt)
189 return rewriter.create<arith::AddIOp>(location: loc, args&: x, args&: y);
190 return rewriter.create<arith::AddFOp>(location: loc, args&: x, args&: y);
191}
192
193/// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
194/// operands `x and `y`.
195static Value createMul(Location loc, Value x, Value y, bool isInt,
196 PatternRewriter &rewriter) {
197 if (isInt)
198 return rewriter.create<arith::MulIOp>(location: loc, args&: x, args&: y);
199 return rewriter.create<arith::MulFOp>(location: loc, args&: x, args&: y);
200}
201
202namespace {
203
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
/// %flattened_a = vector.shape_cast %a
/// %flattened_b = vector.shape_cast %b
/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
/// %d = vector.shape_cast %flattened_d
/// %e = add %c, %d
/// ```
/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
///
/// This only kicks in when vectorContractLowering is set to Matmul and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
    : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accept every contraction op.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  /// `constraint` lets clients restrict which contraction ops the pattern
  /// applies to, on top of the built-in layout checks.
  ContractionOpToMatmulOpLowering(
      vector::VectorContractLowering vectorContractLowering,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorContractLowering(vectorContractLowering),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorContractLowering vectorContractLowering;
  /// Client-provided constraint checked before rewriting.
  FilterConstraintType filter;
};
246
247/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
248/// semantics to a reduction_size-unrolled sequence:
249/// ```
250/// %at = vector.transpose %a, [1, 0]
251/// %bRow0 = vector.extract %b[0]
252/// %atRow0 = vector.extract %at[0]
253/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
254/// ...
255/// %bRowK = vector.extract %b[K]
256/// %atRowK = vector.extract %at[K]
257/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
258/// ```
259///
260/// This only kicks in when vectorContractLowering is set to OuterProduct and
261/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
    : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accept every contraction op.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  /// `constraint` lets clients restrict which contraction ops the pattern
  /// applies to, on top of the built-in layout checks.
  ContractionOpToOuterProductOpLowering(
      vector::VectorContractLowering vectorContractLowering,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorContractLowering(vectorContractLowering),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorContractLowering vectorContractLowering;
  /// Client-provided constraint checked before rewriting.
  FilterConstraintType filter;
};
291
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to an output-size-unrolled sequence:
/// ```
/// %out = arith.constant ... : vector<MxNxelt_type>
/// %bt = vector.transpose %b, [1, 0]
/// %aRow0 = vector.extract %a[0]
/// %btRow0 = vector.extract %bt[0]
/// %c00 = vector.reduce %aRow0, %btRow0
/// %out00 = vector.insert %c00, %out[0, 0]
/// ...
/// %aRowLast = vector.extract %a[M-1]
/// %btRowLast = vector.extract %bt[N-1]
/// %cLastLast = vector.reduce %aRowLast, %btRowLast
/// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
/// ```
///
/// This only kicks in when vectorContractLowering is set to Dot and
/// the vector.contract op is a row-major matmul or matvec.
310class ContractionOpToDotLowering
311 : public MaskableOpRewritePattern<vector::ContractionOp> {
312public:
313 using MaskableOpRewritePattern::MaskableOpRewritePattern;
314
315 using FilterConstraintType =
316 std::function<LogicalResult(vector::ContractionOp op)>;
317
318 static LogicalResult defaultFilter(vector::ContractionOp op) {
319 return success();
320 }
321
322 ContractionOpToDotLowering(
323 vector::VectorContractLowering vectorContractLowering,
324 MLIRContext *context, PatternBenefit benefit = 1,
325 const FilterConstraintType &constraint = defaultFilter)
326 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
327 vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
328
329 FailureOr<Value>
330 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
331 PatternRewriter &rewriter) const override;
332
333private:
334 /// Options to control the vector patterns.
335 vector::VectorContractLowering vectorContractLowering;
336 FilterConstraintType filter;
337};
338
339/// Progressive lowering of ContractionOp.
340///
341/// One:
342/// %x = vector.contract with at least one free/batch dimension
343/// is replaced by:
344/// %a = vector.contract with one less free/batch dimension
345/// %b = vector.contract with one less free/batch dimension
346/// ..
347/// %x = combine %a %b ..
348/// until a pure contraction is reached (no free/batch dimensions),
349/// which is replaced by a dot-product.
350///
/// This only kicks in when either vectorContractLoweringOption is set
/// to Dot or when other contraction patterns fail.
class ContractionOpLowering
    : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accept every contraction op.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  /// `constraint` lets clients restrict which contraction ops the pattern
  /// applies to, on top of the built-in layout checks.
  ContractionOpLowering(
      vector::VectorContractLowering vectorContractLoweringOption,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorContractLoweringOption(vectorContractLoweringOption),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorContractLowering vectorContractLoweringOption;
  /// Client-provided constraint checked before rewriting.
  FilterConstraintType filter;
  // Lower one parallel dimension.
  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
                                 vector::ContractionOp op, int64_t lhsIndex,
                                 int64_t rhsIndex, Value mask) const;
  // Lower one reduction dimension.
  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
                                  vector::ContractionOp op, Value mask) const;
};
388
389/// Generate a vector implementation for matmat, matvec and tmatvec.
390/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
  UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
        kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
        res(op.getAcc()), lhsType(op.getLhsType()) {
    // Remember the surrounding mask (if any) so it can be applied to each
    // unrolled outer product.
    auto maskableOp = cast<MaskableOpInterface>(Val: op.getOperation());
    if (maskableOp.isMasked())
      mask = maskableOp.getMaskingOp().getMask();
  }

  /// Transpose `v` with permutation `perm` (default: 2-D transpose). A null
  /// value is passed through unchanged, so optional masks can be threaded in.
  Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
    if (!v)
      return v;
    return rewriter.create<vector::TransposeOp>(location: loc, args&: v, args&: perm);
  }

  /// Extend `v` (scalar or vector) to `dstElementType`: extf for float
  /// destinations, extsi otherwise. No-op when the element types match.
  Value promote(Value v, Type dstElementType) {
    Type elementType = v.getType();
    auto vecType = dyn_cast<VectorType>(Val&: elementType);
    if (vecType)
      elementType = vecType.getElementType();
    if (elementType == dstElementType)
      return v;
    Type promotedType = dstElementType;
    if (vecType)
      promotedType = vecType.clone(elementType: promotedType);
    if (isa<FloatType>(Val: dstElementType))
      return rewriter.create<arith::ExtFOp>(location: loc, args&: promotedType, args&: v);
    return rewriter.create<arith::ExtSIOp>(location: loc, args&: promotedType, args&: v);
  }

  /// Emit `reductionSize` chained vector.outerproduct ops accumulating into
  /// `res`. `lhs`, `rhs` and `maybeMask` must already have the reduction
  /// dimension outermost so each step can extract slice `k`.
  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
                             VectorType lhsType, int reductionSize,
                             std::optional<Value> maybeMask = std::nullopt) {
    // Incremental support for masking.
    if (mask && !maybeMask.has_value())
      return failure();

    Type resElementType = cast<VectorType>(Val: res.getType()).getElementType();
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value extractA = rewriter.create<vector::ExtractOp>(location: loc, args&: lhs, args&: k);
      Value extractB = rewriter.create<vector::ExtractOp>(location: loc, args&: rhs, args&: k);
      // Promote to the accumulator's element type before multiplying.
      extractA = promote(v: extractA, dstElementType: resElementType);
      extractB = promote(v: extractB, dstElementType: resElementType);
      Value extractMask;
      if (maybeMask.has_value() && maybeMask.value())
        extractMask =
            rewriter.create<vector::ExtractOp>(location: loc, args&: maybeMask.value(), args&: k);

      Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
          location: loc, args: res.getType(), args&: extractA, args&: extractB, args&: res, args&: kind);
      res = maskOperation(builder&: rewriter, maskableOp: outerProdOp, mask: extractMask)->getResult(idx: 0);
    }
    return res;
  }

  /// Helper function for `matmat`, `matvec`, `tmatvec`. Returns the size of
  /// dimension `reductionDim`. If the dimension is a scalable dimension,
  /// returns "nullopt".
  std::optional<int64_t> getReductionSize(VectorType vecType,
                                          int64_t reductionDim) {
    // Cannot unroll scalable dimension.
    if (vecType.getScalableDims()[reductionDim])
      return std::nullopt;
    int64_t reductionSize = vecType.getDimSize(idx: reductionDim);
    assert(reductionSize > 0 &&
           "Reduction dim must be a known static size to allow unrolling");
    return reductionSize;
  }

  /// Two outer parallel, one inner reduction (matmat flavor).
  /// Each `layout` case below permutes lhs/rhs/mask so that the reduction
  /// dimension becomes outermost, as required by `outerProd`.
  FailureOr<Value> matmat() {
    if (!iters(its: {Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(ctx: rewriter.getContext(), exprs&: m, exprs&: n, exprs&: k);

    // Classical row-major matmul: Just permute the lhs.
    if (layout(l: {{m, k}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(vecType: lhsType, reductionDim: 1)) {
        // Note: `t` creates new IR. It must be nested within this `if` check
        // so that no IR is created when then pattern returns "failure".
        Value tLhs = t(v: lhs);
        Value tMask = t(v: mask, perm: {2, 0, 1});
        return outerProd(lhs: tLhs, rhs, res, lhsType, reductionSize: *reductionSize, maybeMask: tMask);
      }
    }
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout(l: {{m, k}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(vecType: lhsType, reductionDim: 1)) {
        Value tLhs = t(v: lhs);
        Value tRhs = t(v: rhs);
        Value tMask = t(v: mask, perm: {2, 0, 1});
        return outerProd(lhs: tLhs, rhs: tRhs, res, lhsType, reductionSize: *reductionSize, maybeMask: tMask);
      }
    }
    // No need to permute anything.
    if (layout(l: {{k, m}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(vecType: lhsType, reductionDim: 0)) {
        Value tMask = t(v: mask, perm: {2, 0, 1});
        return outerProd(lhs, rhs, res, lhsType, reductionSize: *reductionSize, maybeMask: tMask);
      }
    }
    // Just permute the rhs.
    if (layout(l: {{k, m}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(vecType: lhsType, reductionDim: 0)) {
        Value tRhs = t(v: rhs);
        Value tMask = t(v: mask, perm: {2, 0, 1});
        return outerProd(lhs, rhs: tRhs, res, lhsType, reductionSize: *reductionSize, maybeMask: tMask);
      }
    }
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout(l: {{m, k}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(vecType: lhsType, reductionDim: 1)) {
        Value tLhs = t(v: lhs);
        Value tMask = t(v: mask, perm: {2, 0, 1});
        return outerProd(lhs: rhs, rhs: tLhs, res, lhsType, reductionSize: *reductionSize, maybeMask: tMask);
      }
    }
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout(l: {{m, k}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(vecType: lhsType, reductionDim: 1)) {
        Value tRhs = t(v: rhs);
        Value tLhs = t(v: lhs);
        Value tMask = t(v: mask, perm: {2, 0, 1});
        return outerProd(lhs: tRhs, rhs: tLhs, res, lhsType, reductionSize: *reductionSize, maybeMask: tMask);
      }
    }
    if (layout(l: {{k, m}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(vecType: lhsType, reductionDim: 0)) {
        Value tMask = t(v: mask, perm: {2, 0, 1});
        return outerProd(lhs: rhs, rhs: lhs, res, lhsType, reductionSize: *reductionSize, maybeMask: tMask);
      }
    }
    if (layout(l: {{k, m}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(vecType: lhsType, reductionDim: 0)) {
        Value tRhs = t(v: rhs);
        Value tMask = t(v: mask, perm: {2, 0, 1});
        return outerProd(lhs: tRhs, rhs: lhs, res, lhsType, reductionSize: *reductionSize, maybeMask: tMask);
      }
    }
    return failure();
  }

  //
  // One outer parallel, one inner reduction (matvec flavor).
  // Mask needs to be transposed everywhere to turn the reduction dimension
  // outermost as required by outerproduct.
  //
  FailureOr<Value> matvec() {
    if (!iters(its: {Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(ctx: rewriter.getContext(), exprs&: m, exprs&: k);

    // Case mat-vec: transpose.
    if (layout(l: {{m, k}, {k}, {m}})) {
      if (auto reductionSize = getReductionSize(vecType: lhsType, reductionDim: 1)) {
        Value tLhs = t(v: lhs);
        Value tMask = t(v: mask);
        return outerProd(lhs: tLhs, rhs, res, lhsType, reductionSize: *reductionSize, maybeMask: tMask);
      }
    }
    // Case mat-trans-vec: ready to go.
    if (layout(l: {{k, m}, {k}, {m}})) {
      if (auto reductionSize = getReductionSize(vecType: lhsType, reductionDim: 0)) {
        Value tMask = t(v: mask);
        return outerProd(lhs, rhs, res, lhsType, reductionSize: *reductionSize, maybeMask: tMask);
      }
    }
    // Case vec-mat: swap and transpose.
    if (layout(l: {{k}, {m, k}, {m}})) {
      if (auto reductionSize = getReductionSize(vecType: lhsType, reductionDim: 0)) {
        Value tRhs = t(v: rhs);
        Value tMask = t(v: mask);
        return outerProd(lhs: tRhs, rhs: lhs, res, lhsType, reductionSize: *reductionSize, maybeMask: tMask);
      }
    }
    // Case vec-mat-trans: swap and ready to go.
    if (layout(l: {{k}, {k, m}, {m}})) {
      if (auto reductionSize = getReductionSize(vecType: lhsType, reductionDim: 0)) {
        Value tMask = t(v: mask);
        return outerProd(lhs: rhs, rhs: lhs, res, lhsType, reductionSize: *reductionSize, maybeMask: tMask);
      }
    }
    return failure();
  }

  //
  // One outer reduction, one inner parallel (tmatvec flavor).
  // Mask already has the shape of the outer product.
  //
  FailureOr<Value> tmatvec() {
    if (!iters(its: {Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(ctx: rewriter.getContext(), exprs&: k, exprs&: m);

    // Case mat-vec: transpose.
    if (layout(l: {{m, k}, {k}, {m}}))
      if (auto reductionSize = getReductionSize(vecType: lhsType, reductionDim: 1))
        return outerProd(lhs: t(v: lhs), rhs, res, lhsType, reductionSize: *reductionSize, maybeMask: mask);
    // Case mat-trans-vec: ready to go.
    if (layout(l: {{k, m}, {k}, {m}}))
      if (auto reductionSize = getReductionSize(vecType: lhsType, reductionDim: 0))
        return outerProd(lhs, rhs, res, lhsType, reductionSize: *reductionSize, maybeMask: mask);
    // Case vec-mat: swap and transpose.
    if (layout(l: {{k}, {m, k}, {m}}))
      if (auto reductionSize = getReductionSize(vecType: lhsType, reductionDim: 0))
        return outerProd(lhs: t(v: rhs), rhs: lhs, res, lhsType, reductionSize: *reductionSize, maybeMask: mask);
    // Case vec-mat-trans: swap and ready to go.
    if (layout(l: {{k}, {k, m}, {m}}))
      if (auto reductionSize = getReductionSize(vecType: lhsType, reductionDim: 0))
        return outerProd(lhs: rhs, rhs: lhs, res, lhsType, reductionSize: *reductionSize, maybeMask: mask);
    return failure();
  }

private:
  // Contraction operands/attributes captured at construction time.
  vector::CombiningKind kind;
  Value lhs, rhs, res, mask;
  VectorType lhsType;
};
616
617/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
618/// semantics to a reduction_size-unrolled sequence:
619/// ```
620/// %at = vector.transpose %a, [1, 0]
621/// %bRow0 = vector.extract %b[0]
622/// %atRow0 = vector.extract %at[0]
623/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
624/// ...
625/// %bRowK = vector.extract %b[K]
626/// %atRowK = vector.extract %at[K]
627/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
628/// ```
629///
630/// This only kicks in when vectorContractLowering is set to OuterProduct but
631/// otherwise supports any layout permutation of the matrix-multiply.
632FailureOr<Value>
633ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
634 vector::ContractionOp op, MaskingOpInterface maskOp,
635 PatternRewriter &rewriter) const {
636 if (vectorContractLowering != vector::VectorContractLowering::OuterProduct)
637 return failure();
638
639 if (failed(Result: filter(op)))
640 return failure();
641
642 UnrolledOuterProductGenerator e(rewriter, op);
643 FailureOr<Value> matmatRes = e.matmat();
644 if (succeeded(Result: matmatRes)) {
645 return matmatRes;
646 }
647 FailureOr<Value> matvecRes = e.matvec();
648 if (succeeded(Result: matvecRes)) {
649 return matvecRes;
650 }
651
652 FailureOr<Value> tmatvecRes = e.tmatvec();
653 return tmatvecRes;
654}
655
/// Lowers a matching vector.contract to an output-size-unrolled sequence of
/// elementwise multiplies and vector.reduction ops (see the class comment).
FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rewriter) const {
  // TODO: Support vector.mask.
  if (maskOp)
    return failure();

  if (failed(Result: filter(op)))
    return failure();

  if (vectorContractLowering != vector::VectorContractLowering::Dot)
    return failure();

  auto iteratorTypes = op.getIteratorTypes().getValue();
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  Location loc = op.getLoc();
  Value lhs = op.getLhs(), rhs = op.getRhs();

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [&](MapList m) {
    return AffineMap::inferFromExprList(exprsList: m, context: op.getContext());
  };
  AffineExpr m, n, k;
  bindDims(ctx: rewriter.getContext(), exprs&: m, exprs&: n, exprs&: k);
  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
  //
  // In the following we wish to make the reduction dimension innermost so we
  // can load vectors and just fmul + reduce into a scalar.
  // Each branch normalizes lhs/rhs so that lhs rows and rhs "columns" are
  // both indexed (parallel-dim, reduction-dim), i.e. the {{m,k},{n,k},...}
  // form that needs no permutation.
  //
  if (isParallelIterator(attr: iteratorTypes[0]) &&
      isParallelIterator(attr: iteratorTypes[1]) &&
      isReductionIterator(attr: iteratorTypes[2])) {
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(location: loc, args&: rhs, args: perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // No need to permute anything.
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(location: loc, args&: lhs, args: perm);
      rhs = rewriter.create<vector::TransposeOp>(location: loc, args&: rhs, args: perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(location: loc, args&: lhs, args: perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // Transposed output: the new lhs is rhs^T and the new rhs is the old
      // lhs.
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(location: loc, args&: rhs, args: perm);
      rhs = tmp;
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(a&: lhs, b&: rhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(location: loc, args&: rhs, args: perm);
      rhs = rewriter.create<vector::TransposeOp>(location: loc, args&: tmp, args: perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(location: loc, args&: lhs, args: perm);
      lhs = tmp;
    } else {
      return failure();
    }
  } else if (isParallelIterator(attr: iteratorTypes[0]) &&
             isReductionIterator(attr: iteratorTypes[1])) {
    //
    // One outer parallel, one inner reduction (matvec flavor)
    //
    if (maps == infer({{m, n}, {n}, {m}})) {
      // No need to permute anything.
    } else if (maps == infer({{n, m}, {n}, {m}})) {
      lhs = rewriter.create<vector::TransposeOp>(location: loc, args&: lhs, args: perm);
    } else if (maps == infer({{n}, {m, n}, {m}})) {
      std::swap(a&: lhs, b&: rhs);
    } else if (maps == infer({{n}, {n, m}, {m}})) {
      std::swap(a&: lhs, b&: rhs);
      lhs = rewriter.create<vector::TransposeOp>(location: loc, args&: lhs, args: perm);
    } else {
      return failure();
    }
  } else {
    return failure();
  }

  VectorType dstType = cast<VectorType>(Val: op.getResultType());
  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
         "Expected dst type of rank 1 or 2");

  unsigned rank = dstType.getRank();
  unsigned dstRows = dstType.getShape()[0];
  // Rank-1 results (matvec) behave like a single output column.
  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];

  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
  Value res = rewriter.create<arith::ConstantOp>(location: loc, args&: dstType,
                                                 args: rewriter.getZeroAttr(type: dstType));
  bool isInt = isa<IntegerType>(Val: dstType.getElementType());
  llvm::SmallVector<Value> extractedCols;
  extractedCols.reserve(N: dstColumns);
  for (unsigned r = 0; r < dstRows; ++r) {
    Value rowLhs = rewriter.create<vector::ExtractOp>(location: op.getLoc(), args&: lhs, args&: r);
    for (unsigned c = 0; c < dstColumns; ++c) {
      // Extract each respective row and column of the LHS and RHS once to
      // avoid having duplicate SSA values pointing to the same rows/columns.
      if (r == 0) {
        Value colRhs =
            rank == 1 ? rhs
                      : rewriter.create<vector::ExtractOp>(location: op.getLoc(), args&: rhs, args&: c);
        extractedCols.push_back(Elt: colRhs);
      }
      Value extractedColRhs = extractedCols[c];
      // Dot product of one output element: elementwise multiply + add-reduce.
      Value product =
          createMul(loc: op.getLoc(), x: rowLhs, y: extractedColRhs, isInt, rewriter);
      Value sum = rewriter.create<vector::ReductionOp>(
          location: op.getLoc(), args: vector::CombiningKind::ADD, args&: product);

      SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
                                              : SmallVector<int64_t, 2>{r, c};
      res = rewriter.create<vector::InsertOp>(location: op.getLoc(), args&: sum, args&: res, args&: pos);
    }
  }
  // Fold in the accumulator, if present.
  if (auto acc = op.getAcc())
    res = createAdd(loc: op.getLoc(), x: res, y: acc, isInt, rewriter);
  return res;
}
779
780/// Lower vector.contract with all size one reduction dimensions to
781/// elementwise ops when possible.
782struct ContractOpToElementwise
783 : public MaskableOpRewritePattern<vector::ContractionOp> {
784 using MaskableOpRewritePattern::MaskableOpRewritePattern;
785 using FilterConstraintType =
786 std::function<LogicalResult(vector::ContractionOp op)>;
787 static LogicalResult defaultFilter(vector::ContractionOp op) {
788 return success();
789 }
790 ContractOpToElementwise(
791 vector::VectorContractLowering vectorContractLowering,
792 MLIRContext *context, PatternBenefit benefit = 1,
793 const FilterConstraintType &constraint = defaultFilter)
794 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
795 vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
796
797 FailureOr<Value>
798 matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
799 MaskingOpInterface maskOp,
800 PatternRewriter &rewriter) const override {
801 // TODO: Support vector.mask.
802 if (maskOp)
803 return failure();
804
805 if (failed(Result: filter(contractOp)))
806 return failure();
807
808 if (vectorContractLowering != vector::VectorContractLowering::ParallelArith)
809 return failure();
810
811 ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
812 ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
813 AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
814 AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
815 SmallVector<int64_t> lhsReductionDims =
816 getReductionIndex(map: lhsMap, iteratorTypes: contractOp.getIteratorTypes());
817 SmallVector<int64_t> rhsReductionDims =
818 getReductionIndex(map: rhsMap, iteratorTypes: contractOp.getIteratorTypes());
819 // All the reduction dimensions must be a size 1.
820 for (int64_t dim : lhsReductionDims) {
821 if (lhsShape[dim] != 1)
822 return failure();
823 }
824 for (int64_t dim : rhsReductionDims) {
825 if (rhsShape[dim] != 1)
826 return failure();
827 }
828 AffineMap accMap = contractOp.getIndexingMapsArray()[2];
829 unsigned numParallelDims = accMap.getNumResults();
830 unsigned numLhsDimToBroadcast =
831 numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
832 unsigned numRhsDimToBroadcast =
833 numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
834 SmallVector<int64_t> lhsDims;
835 SmallVector<int64_t> lhsTranspose;
836 SmallVector<int64_t> rhsDims;
837 SmallVector<int64_t> rhsTranspose;
838 for (int64_t dim : lhsReductionDims)
839 lhsTranspose.push_back(Elt: numLhsDimToBroadcast + dim);
840 for (int64_t dim : rhsReductionDims)
841 rhsTranspose.push_back(Elt: numRhsDimToBroadcast + dim);
842 // Loop through the parallel dimensions to calculate the dimensions to
843 // broadcast and to permute in order to extract only parallel dimensions.
844 for (unsigned i = 0; i < numParallelDims; i++) {
845 std::optional<unsigned> lhsDim =
846 getDimPosition(map: lhsMap, dim: accMap.getDimPosition(idx: i));
847 if (lhsDim) {
848 lhsTranspose.push_back(Elt: numLhsDimToBroadcast + *lhsDim);
849 } else {
850 // If the parallel dimension doesn't exist we will have to broadcast it.
851 lhsDims.push_back(
852 Elt: cast<VectorType>(Val: contractOp.getResultType()).getDimSize(idx: i));
853 lhsTranspose.push_back(Elt: lhsDims.size() - 1);
854 }
855 std::optional<unsigned> rhsDim =
856 getDimPosition(map: rhsMap, dim: accMap.getDimPosition(idx: i));
857 if (rhsDim) {
858 rhsTranspose.push_back(Elt: numRhsDimToBroadcast + *rhsDim);
859 } else {
860 // If the parallel dimension doesn't exist we will have to broadcast it.
861 rhsDims.push_back(
862 Elt: cast<VectorType>(Val: contractOp.getResultType()).getDimSize(idx: i));
863 rhsTranspose.push_back(Elt: rhsDims.size() - 1);
864 }
865 }
866 Value newLhs = contractOp.getLhs();
867 Value newRhs = contractOp.getRhs();
868 Location loc = contractOp.getLoc();
869 if (!lhsDims.empty()) {
870 lhsDims.append(in_start: lhsShape.begin(), in_end: lhsShape.end());
871 auto expandedType =
872 VectorType::get(shape: lhsDims, elementType: contractOp.getLhsType().getElementType());
873 newLhs = rewriter.create<vector::BroadcastOp>(location: loc, args&: expandedType, args&: newLhs);
874 }
875 if (!rhsDims.empty()) {
876 rhsDims.append(in_start: rhsShape.begin(), in_end: rhsShape.end());
877 auto expandedType =
878 VectorType::get(shape: rhsDims, elementType: contractOp.getRhsType().getElementType());
879 newRhs = rewriter.create<vector::BroadcastOp>(location: loc, args&: expandedType, args&: newRhs);
880 }
881 bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
882 newLhs = rewriter.create<vector::TransposeOp>(location: loc, args&: newLhs, args&: lhsTranspose);
883 newRhs = rewriter.create<vector::TransposeOp>(location: loc, args&: newRhs, args&: rhsTranspose);
884 SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
885 SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
886 newLhs = rewriter.create<vector::ExtractOp>(location: loc, args&: newLhs, args&: lhsOffsets);
887 newRhs = rewriter.create<vector::ExtractOp>(location: loc, args&: newRhs, args&: rhsOffsets);
888 std::optional<Value> result =
889 createContractArithOp(loc, x: newLhs, y: newRhs, acc: contractOp.getAcc(),
890 kind: contractOp.getKind(), rewriter, isInt);
891 if (result)
892 return *result;
893
894 return failure();
895 }
896
897private:
898 /// Options to control the vector patterns.
899 vector::VectorContractLowering vectorContractLowering;
900 FilterConstraintType filter;
901};
902
/// Progressive lowering of ContractionOp.
/// One:
/// %x = vector.contract with at least one free/batch dimension
/// is replaced by:
/// %a = vector.contract with one less free/batch dimension
/// %b = vector.contract with one less free/batch dimension
/// ..
/// %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either vectorContractLoweringOption is set
/// to DOT or when other contraction patterns fail.
//
// TODO: break down into transpose/reshape/cast ops
// when they become available to avoid code dup
// TODO: investigate lowering order impact on performance
FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rewriter) const {
  // Let the user-provided filter veto the rewrite.
  if (failed(Result: filter(op)))
    return failure();

  // TODO: support mixed mode contract lowering.
  if (op.getLhsType().getElementType() !=
          getElementTypeOrSelf(type: op.getAccType()) ||
      op.getRhsType().getElementType() != getElementTypeOrSelf(type: op.getAccType()))
    return failure();

  // TODO: the code below assumes the default contraction, make sure it supports
  // other kinds before enabling this lowering.
  if (op.getKind() != vector::CombiningKind::ADD) {
    return rewriter.notifyMatchFailure(
        arg&: op, msg: "contractions other than 'add' not supported");
  }

  // TODO: implement benefits, cost models.
  MLIRContext *ctx = op.getContext();

  // Try the specialized lowerings first, in a fixed order; the first one that
  // succeeds produces the replacement value.
  ContractionOpToMatmulOpLowering pat1(vectorContractLoweringOption, ctx);
  FailureOr<Value> newVal1 =
      pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
  if (!failed(Result: newVal1))
    return newVal1;

  ContractionOpToOuterProductOpLowering pat2(vectorContractLoweringOption, ctx);
  FailureOr<Value> newVal2 =
      pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
  if (!failed(Result: newVal2))
    return newVal2;

  ContractionOpToDotLowering pat3(vectorContractLoweringOption, ctx);
  FailureOr<Value> newVal3 =
      pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
  if (!failed(Result: newVal3))
    return newVal3;

  ContractOpToElementwise pat4(vectorContractLoweringOption, ctx);
  FailureOr<Value> newVal4 =
      pat4.matchAndRewriteMaskableOp(contractOp: op, maskOp, rewriter);
  if (!failed(Result: newVal4))
    return newVal4;

  // None of the specialized lowerings applied: fall back to progressive
  // unrolling, peeling off one dimension per recursive step.

  // Vector mask setup.

  Value mask;
  if (maskOp)
    mask = maskOp.getMask();
  // Find first batch dimension in LHS/RHS, and lower when found.
  std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
  if (!batchDimMap.empty()) {
    int64_t lhsIndex = batchDimMap[0].first;
    int64_t rhsIndex = batchDimMap[0].second;
    auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
    if (failed(Result: newOp))
      return failure();
    return newOp;
  }

  // Collect contracting dimensions.
  std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
      op.getContractingDimMap();
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(V: dimPair.first);
    rhsContractingDimSet.insert(V: dimPair.second);
  }

  // Find first free dimension in LHS, and lower when found.
  VectorType lhsType = op.getLhsType();
  for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
    if (lhsContractingDimSet.count(V: lhsIndex) == 0) {
      auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
      if (failed(Result: newOp))
        return failure();
      return newOp;
    }
  }

  // Find first free dimension in RHS, and lower when found.
  VectorType rhsType = op.getRhsType();
  for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
    if (rhsContractingDimSet.count(V: rhsIndex) == 0) {
      auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
      if (failed(Result: newOp))
        return failure();
      return newOp;
    }
  }

  // Lower the first remaining reduction dimension.
  if (!contractingDimMap.empty()) {
    auto newOp = lowerReduction(rewriter, op, mask);
    if (failed(Result: newOp))
      return failure();
    return newOp;
  }

  return failure();
}
1024
// Lower one parallel dimension.
// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
// TODO: consider reusing existing contract unrolling
//
// Unrolls the contraction along a single parallel (free/batch) dimension,
// emitting one lower-rank vector.contract per slice and inserting each
// partial result back into an initially-zero accumulator vector.
FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
                                                      vector::ContractionOp op,
                                                      int64_t lhsIndex,
                                                      int64_t rhsIndex,
                                                      Value mask) const {
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  VectorType resType = cast<VectorType>(Val: op.getResultType());
  // Find the iterator type index and result index.
  // iterIndex is the position of the unrolled dimension in the iteration
  // space; dimSize is the number of slices to generate.
  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
  int64_t iterIndex = -1;
  int64_t dimSize = -1;
  if (lhsIndex >= 0) {
    iterIndex = iMap[0].getDimPosition(idx: lhsIndex);
    // When the dimension appears on both sides, LHS and RHS must agree on
    // which iteration-space dimension it is.
    if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(idx: rhsIndex))
      return rewriter.notifyMatchFailure(op, reasonCallback: [&](Diagnostic &diag) {
        diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
             << " to map to the same dimension";
      });
    // Scalable dimensions have no static size, so they cannot be unrolled.
    if (lhsType.getScalableDims()[lhsIndex])
      return rewriter.notifyMatchFailure(op, reasonCallback: [&](Diagnostic &diag) {
        diag << "Unrolling scalable dimension (lhsIndex=" << lhsIndex
             << ") is not supported yet";
      });
    dimSize = lhsType.getDimSize(idx: lhsIndex);
  } else if (rhsIndex >= 0) {
    iterIndex = iMap[1].getDimPosition(idx: rhsIndex);
    if (rhsType.getScalableDims()[rhsIndex])
      return rewriter.notifyMatchFailure(op, reasonCallback: [&](Diagnostic &diag) {
        diag << "Unrolling scalable dimension (rhsIndex=" << rhsIndex
             << ") is not supported yet";
      });
    dimSize = rhsType.getDimSize(idx: rhsIndex);
  }
  if (iterIndex < 0)
    return rewriter.notifyMatchFailure(op, reasonCallback: [&](Diagnostic &diag) {
      diag << "expected either lhsIndex=" << lhsIndex
           << " or rhsIndex=" << rhsIndex << " to be nonnegative";
    });
  // value_or(-1) means that we tolerate a dimension not appearing
  // in the result map. That can't happen for actual parallel iterators, but
  // the caller ContractionOpLowering::matchAndRewrite is currently calling
  // lowerParallel also for the case of unit-size reduction dims appearing only
  // on one of LHS or RHS, not both. At the moment, such cases are created by
  // CastAwayContractionLeadingOneDim, so we need to either support that or
  // modify that pattern.
  int64_t resIndex = getResultIndex(map: iMap[2], index: iterIndex).value_or(u: -1);
  if (resIndex == -1 && dimSize != 1)
    return rewriter.notifyMatchFailure(op, reasonCallback: [&](Diagnostic &diag) {
      diag << "expected the dimension for iterIndex=" << iterIndex
           << " to either appear in the result map, or to be a unit dimension";
    });

  // Construct new iterator types and affine map array attribute.
  // The unrolled dimension is removed from all three maps and from the
  // iterator type list.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(map: iMap[0], index: iterIndex, rewriter),
      adjustMap(map: iMap[1], index: iterIndex, rewriter),
      adjustMap(map: iMap[2], index: iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(values: lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(value: adjustIter(iteratorTypes: op.getIteratorTypes(), index: iterIndex));
  // Unroll into a series of lower dimensional vector.contract ops.
  Location loc = op.getLoc();
  // Start from an all-zero result and insert each slice's contraction.
  Value result = rewriter.create<arith::ConstantOp>(
      location: loc, args&: resType, args: rewriter.getZeroAttr(type: resType));

  for (int64_t d = 0; d < dimSize; ++d) {
    auto lhs = reshapeLoad(loc, val: op.getLhs(), type: lhsType, index: lhsIndex, pos: d, rewriter);
    auto rhs = reshapeLoad(loc, val: op.getRhs(), type: rhsType, index: rhsIndex, pos: d, rewriter);
    auto acc = reshapeLoad(loc, val: op.getAcc(), type: resType, index: resIndex, pos: d, rewriter);

    // Slice the mask along the same dimension, if present.
    Value lowMask;
    if (mask)
      lowMask = reshapeLoad(loc, val: mask, type: cast<VectorType>(Val: mask.getType()),
                            index: iterIndex, pos: d, rewriter);

    Operation *lowContract = rewriter.create<vector::ContractionOp>(
        location: loc, args&: lhs, args&: rhs, args&: acc, args&: lowAffine, args&: lowIter);
    lowContract = maskOperation(builder&: rewriter, maskableOp: lowContract, mask: lowMask);
    result = reshapeStore(loc, val: lowContract->getResult(idx: 0), result, type: resType,
                          index: resIndex, pos: d, rewriter);
  }
  return result;
}
1112
1113// Lower one reduction dimension.
1114FailureOr<Value> ContractionOpLowering::lowerReduction(
1115 PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
1116 auto loc = op.getLoc();
1117 VectorType lhsType = op.getLhsType();
1118 VectorType rhsType = op.getRhsType();
1119 Type resType = op.getResultType();
1120 if (isa<VectorType>(Val: resType))
1121 return rewriter.notifyMatchFailure(arg&: op,
1122 msg: "did not expect a VectorType result");
1123 bool isInt = isa<IntegerType>(Val: resType);
1124 // Use iterator index 0.
1125 int64_t iterIndex = 0;
1126 SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
1127 std::optional<int64_t> lookupLhs = getResultIndex(map: iMap[0], index: iterIndex);
1128 std::optional<int64_t> lookupRhs = getResultIndex(map: iMap[1], index: iterIndex);
1129 if (!lookupLhs.has_value())
1130 return rewriter.notifyMatchFailure(op, reasonCallback: [&](Diagnostic &diag) {
1131 diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension";
1132 });
1133 if (!lookupRhs.has_value())
1134 return rewriter.notifyMatchFailure(op, reasonCallback: [&](Diagnostic &diag) {
1135 diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension";
1136 });
1137 int64_t lhsIndex = *lookupLhs;
1138 int64_t rhsIndex = *lookupRhs;
1139 int64_t dimSize = lhsType.getDimSize(idx: lhsIndex);
1140 if (dimSize != rhsType.getDimSize(idx: rhsIndex))
1141 return rewriter.notifyMatchFailure(op, reasonCallback: [&](Diagnostic &diag) {
1142 diag << "expect LHS dimension " << lhsIndex
1143 << " to have the same size as RHS dimension " << rhsIndex;
1144 });
1145 // Base case.
1146 if (lhsType.getRank() == 1) {
1147 if (rhsType.getRank() != 1)
1148 return rewriter.notifyMatchFailure(
1149 arg&: op, msg: "When LHS has rank 1, expected also RHS to have rank 1");
1150 Value m = createMul(loc, x: op.getLhs(), y: op.getRhs(), isInt, rewriter);
1151 auto kind = vector::CombiningKind::ADD;
1152
1153 Value acc = op.getAcc();
1154 Operation *reductionOp =
1155 acc ? rewriter.create<vector::ReductionOp>(location: loc, args&: kind, args&: m, args&: acc)
1156 : rewriter.create<vector::ReductionOp>(location: loc, args&: kind, args&: m);
1157 return maskOperation(builder&: rewriter, maskableOp: reductionOp, mask)->getResult(idx: 0);
1158 }
1159 // Construct new iterator types and affine map array attribute.
1160 std::array<AffineMap, 3> lowIndexingMaps = {
1161 adjustMap(map: iMap[0], index: iterIndex, rewriter),
1162 adjustMap(map: iMap[1], index: iterIndex, rewriter),
1163 adjustMap(map: iMap[2], index: iterIndex, rewriter)};
1164 auto lowAffine = rewriter.getAffineMapArrayAttr(values: lowIndexingMaps);
1165 auto lowIter =
1166 rewriter.getArrayAttr(value: adjustIter(iteratorTypes: op.getIteratorTypes(), index: iterIndex));
1167 // Unroll into a series of lower dimensional vector.contract ops.
1168 // By feeding the initial accumulator into the first contraction,
1169 // and the result of each contraction into the next, eventually
1170 // the sum of all reductions is computed.
1171 Value result = op.getAcc();
1172 for (int64_t d = 0; d < dimSize; ++d) {
1173 auto lhs = reshapeLoad(loc, val: op.getLhs(), type: lhsType, index: lhsIndex, pos: d, rewriter);
1174 auto rhs = reshapeLoad(loc, val: op.getRhs(), type: rhsType, index: rhsIndex, pos: d, rewriter);
1175 Value newMask;
1176 if (mask)
1177 newMask = reshapeLoad(loc, val: mask, type: cast<VectorType>(Val: mask.getType()),
1178 index: iterIndex, pos: d, rewriter);
1179
1180 Operation *newContract = rewriter.create<vector::ContractionOp>(
1181 location: loc, args&: lhs, args&: rhs, args&: result, args&: lowAffine, args&: lowIter);
1182 result = maskOperation(builder&: rewriter, maskableOp: newContract, mask: newMask)->getResult(idx: 0);
1183 }
1184 return result;
1185}
1186
/// Progressive lowering of OuterProductOp.
/// One:
/// %x = vector.outerproduct %lhs, %rhs, %acc
/// is replaced by:
/// %z = zero-result
/// %0 = vector.extract %lhs[0]
/// %1 = vector.broadcast %0
/// %2 = vector.extract %acc[0]
/// %3 = vector.fma %1, %rhs, %2
/// %4 = vector.insert %3, %z[0]
/// ..
/// %x = vector.insert %.., %..[N-1]
///
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::OuterProductOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resType = op.getResultVectorType();
    // Bail out when every result dimension is scalable: the unrolling below
    // requires a static extent for dimension 0.
    if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
      return failure();

    auto loc = op.getLoc();

    VectorType lhsType = op.getOperandVectorTypeLHS();
    // rhsType is null when the RHS is a scalar (the AXPY case below).
    VectorType rhsType = dyn_cast<VectorType>(Val: op.getOperandTypeRHS());
    Type eltType = resType.getElementType();
    bool isInt = isa<IntegerType, IndexType>(Val: eltType);
    Value acc = op.getAcc();
    vector::CombiningKind kind = op.getKind();

    // Vector mask setup.
    // If the op is masked, new ops are created before the enclosing mask op,
    // and the mask op (not the outerproduct) is the op to replace.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp = cast<vector::MaskableOpInterface>(Val: op.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = op;
    }

    if (!rhsType) {
      // Special case: AXPY operation.
      // Broadcast the scalar RHS to the LHS vector type and combine directly.
      Value b = rewriter.create<vector::BroadcastOp>(location: loc, args&: lhsType, args: op.getRhs());
      std::optional<Value> mult = createContractArithOp(
          loc, x: op.getLhs(), y: b, acc, kind, rewriter, isInt, mask);
      if (!mult.has_value())
        return failure();
      rewriter.replaceOp(op: rootOp, newValues: *mult);
      return success();
    }

    // General case: build the result one row at a time, combining the
    // broadcast d-th LHS element with the full RHS vector (and the d-th
    // accumulator/mask rows when present).
    Value result = rewriter.create<arith::ConstantOp>(
        location: loc, args&: resType, args: rewriter.getZeroAttr(type: resType));
    for (int64_t d = 0, e = resType.getDimSize(idx: 0); d < e; ++d) {
      Value x = rewriter.create<vector::ExtractOp>(location: loc, args: op.getLhs(), args&: d);
      Value a = rewriter.create<vector::BroadcastOp>(location: loc, args&: rhsType, args&: x);
      Value r = nullptr;
      if (acc)
        r = rewriter.create<vector::ExtractOp>(location: loc, args&: acc, args&: d);
      Value extrMask;
      if (mask)
        extrMask = rewriter.create<vector::ExtractOp>(location: loc, args&: mask, args&: d);

      std::optional<Value> m = createContractArithOp(
          loc, x: a, y: op.getRhs(), acc: r, kind, rewriter, isInt, mask: extrMask);
      if (!m.has_value())
        return failure();
      result = rewriter.create<vector::InsertOp>(location: loc, args&: *m, args&: result, args&: d);
    }

    rewriter.replaceOp(op: rootOp, newValues: result);
    return success();
  }
};
1266
/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
/// %mta = maybe_transpose
/// %mtb = maybe_transpose
/// %flattened_a = vector.shape_cast %mta
/// %flattened_b = vector.shape_cast %mtb
/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
/// %mtd = vector.shape_cast %flattened_d
/// %d = maybe_untranspose %mtd
/// %e = add %c, %d
/// ```
/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
//
/// This only kicks in when vectorContractLowering is set to `Matmul`.
/// vector.transpose operations are inserted if the vector.contract op is not a
/// row-major matrix multiply.
///
/// Scalable vectors are not supported.
FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rew) const {
  // TODO: Support vector.mask.
  if (maskOp)
    return failure();

  if (vectorContractLowering != vector::VectorContractLowering::Matmul)
    return failure();
  if (failed(Result: filter(op)))
    return failure();

  // Expect exactly the (parallel, parallel, reduction) iterator structure of
  // a plain matmul.
  auto iteratorTypes = op.getIteratorTypes().getValue();
  if (!isParallelIterator(attr: iteratorTypes[0]) ||
      !isParallelIterator(attr: iteratorTypes[1]) ||
      !isReductionIterator(attr: iteratorTypes[2]))
    return failure();

  Type opResType = op.getType();
  VectorType vecType = dyn_cast<VectorType>(Val&: opResType);
  if (vecType && vecType.isScalable()) {
    // Note - this is sufficient to reject all cases with scalable vectors.
    return failure();
  }

  Type elementType = op.getLhsType().getElementType();
  if (!elementType.isIntOrFloat())
    return failure();

  // Mixed element types are not supported by vector.matrix_multiply.
  Type dstElementType = vecType ? vecType.getElementType() : opResType;
  if (elementType != dstElementType)
    return failure();

  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
  // Bail out if the contraction cannot be put in this form.
  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();
  AffineExpr m, n, k;
  bindDims(ctx: rew.getContext(), exprs&: m, exprs&: n, exprs&: k);
  // LHS must be A(m, k) or A(k, m).
  Value lhs = op.getLhs();
  auto lhsMap = op.getIndexingMapsArray()[0];
  if (lhsMap == AffineMap::get(dimCount: 3, symbolCount: 0, results: {k, m}, context: ctx))
    lhs = rew.create<vector::TransposeOp>(location: loc, args&: lhs, args: ArrayRef<int64_t>{1, 0});
  else if (lhsMap != AffineMap::get(dimCount: 3, symbolCount: 0, results: {m, k}, context: ctx))
    return failure();

  // RHS must be B(k, n) or B(n, k).
  Value rhs = op.getRhs();
  auto rhsMap = op.getIndexingMapsArray()[1];
  if (rhsMap == AffineMap::get(dimCount: 3, symbolCount: 0, results: {n, k}, context: ctx))
    rhs = rew.create<vector::TransposeOp>(location: loc, args&: rhs, args: ArrayRef<int64_t>{1, 0});
  else if (rhsMap != AffineMap::get(dimCount: 3, symbolCount: 0, results: {k, n}, context: ctx))
    return failure();

  // At this point lhs and rhs are in row-major.
  VectorType lhsType = cast<VectorType>(Val: lhs.getType());
  VectorType rhsType = cast<VectorType>(Val: rhs.getType());
  int64_t lhsRows = lhsType.getDimSize(idx: 0);
  int64_t lhsColumns = lhsType.getDimSize(idx: 1);
  int64_t rhsColumns = rhsType.getDimSize(idx: 1);

  // vector.matrix_multiply operates on flattened (rank-1) operands.
  Type flattenedLHSType =
      VectorType::get(shape: lhsType.getNumElements(), elementType: lhsType.getElementType());
  lhs = rew.create<vector::ShapeCastOp>(location: loc, args&: flattenedLHSType, args&: lhs);

  Type flattenedRHSType =
      VectorType::get(shape: rhsType.getNumElements(), elementType: rhsType.getElementType());
  rhs = rew.create<vector::ShapeCastOp>(location: loc, args&: flattenedRHSType, args&: rhs);

  Value mul = rew.create<vector::MatmulOp>(location: loc, args&: lhs, args&: rhs, args&: lhsRows, args&: lhsColumns,
                                           args&: rhsColumns);
  // Restore the 2-D shape of the product before combining with the acc.
  mul = rew.create<vector::ShapeCastOp>(
      location: loc,
      args: VectorType::get(shape: {lhsRows, rhsColumns},
                      elementType: getElementTypeOrSelf(type: op.getAcc().getType())),
      args&: mul);

  // ACC must be C(m, n) or C(n, m).
  auto accMap = op.getIndexingMapsArray()[2];
  if (accMap == AffineMap::get(dimCount: 3, symbolCount: 0, results: {n, m}, context: ctx))
    mul = rew.create<vector::TransposeOp>(location: loc, args&: mul, args: ArrayRef<int64_t>{1, 0});
  else if (accMap != AffineMap::get(dimCount: 3, symbolCount: 0, results: {m, n}, context: ctx))
    llvm_unreachable("invalid contraction semantics");

  // Final combine: result = acc + (A x B), with the add matching the element
  // type (integer vs floating point).
  Value res =
      isa<IntegerType>(Val: elementType)
          ? static_cast<Value>(rew.create<arith::AddIOp>(location: loc, args: op.getAcc(), args&: mul))
          : static_cast<Value>(
                rew.create<arith::AddFOp>(location: loc, args: op.getAcc(), args&: mul));

  return res;
}
1379} // namespace
1380
1381void mlir::vector::populateVectorContractLoweringPatterns(
1382 RewritePatternSet &patterns,
1383 VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit,
1384 bool disableOuterProductLowering) {
1385 if (!disableOuterProductLowering)
1386 patterns.add<OuterProductOpLowering>(arg: patterns.getContext(), args&: benefit);
1387 patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
1388 ContractionOpToOuterProductOpLowering>(
1389 arg&: vectorContractLoweringOption, args: patterns.getContext(), args&: benefit);
1390}
1391
1392void mlir::vector::populateVectorOuterProductLoweringPatterns(
1393 RewritePatternSet &patterns, PatternBenefit benefit) {
1394 patterns.add<OuterProductOpLowering>(arg: patterns.getContext(), args&: benefit);
1395}
1396

source code of mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp