//===- LowerVectorContract.cpp - Lower 'vector.contract' operation --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.contract' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#define DEBUG_TYPE "vector-contract-lowering"

using namespace mlir;
using namespace mlir::vector;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//
// Helper to find an index in an affine map.
static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      return i;
  }
  return std::nullopt;
}
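
// For example, given the map (d0, d1, d2) -> (d2, d0), getResultIndex(map, 0)
// returns 1 (d0 is the second result), while getResultIndex(map, 1) returns
// std::nullopt (d1 does not appear in the results).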

// Helper to construct iterator types with one index removed.
static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes,
                                         int64_t index) {
  SmallVector<Attribute> results;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    int64_t idx = it.index();
    if (idx == index)
      continue;
    results.push_back(it.value());
  }
  return results;
}

// Helper to construct an affine map with one index removed.
static AffineMap adjustMap(AffineMap map, int64_t index,
                           PatternRewriter &rewriter) {
  auto *ctx = rewriter.getContext();
  SmallVector<AffineExpr> results;
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      continue;
    // Re-insert remaining indices, but renamed when occurring
    // after the removed index.
    auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
    results.push_back(targetExpr);
  }
  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}
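
// For example, removing index 1 from the map (d0, d1, d2) -> (d0, d2) yields
// (d0, d1) -> (d0, d1): d0 is kept as-is, and d2 is renamed to d1 because it
// occurs after the removed dimension.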

// Helper method to possibly drop a dimension in a load.
// TODO
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  if (index == -1)
    return val;

  // At extraction dimension?
  if (index == 0)
    return rewriter.create<vector::ExtractOp>(loc, val, pos);

  // Unroll leading dimensions.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  VectorType resType = VectorType::Builder(type).dropDim(index);
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));
  for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(loc, val, d);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, load, result, d);
  }
  return result;
}
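
// For example, with val : vector<2x3x4xf32>, index = 1 and pos = 0, the
// recursion above unrolls the leading dimension and emits, roughly:
//   %e0 = vector.extract %val[0]     // vector<3x4xf32>
//   %l0 = vector.extract %e0[0]      // vector<4xf32>, the slice at pos 0
// (and likewise for %val[1]), reassembling the slices into a vector<2x4xf32>
// with vector.insert.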

// Helper method to possibly drop a dimension in a store.
// TODO
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0)
    return rewriter.create<vector::InsertOp>(loc, val, result, pos);

  // Unroll leading dimensions.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    Value ext = rewriter.create<vector::ExtractOp>(loc, result, d);
    Value ins = rewriter.create<vector::ExtractOp>(loc, val, d);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, sto, result, d);
  }
  return result;
}
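
// reshapeStore is the dual of reshapeLoad: with result : vector<2x3x4xf32>,
// index = 1 and pos = 0, it unrolls the leading dimension, inserts each
// vector<4xf32> slice of `val` at position 0 of the corresponding
// vector<3x4xf32> slice, and rebuilds the full vector<2x3x4xf32>.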

/// Helper to create arithmetic operation associated with a kind of contraction.
static std::optional<Value>
createContractArithOp(Location loc, Value x, Value y, Value acc,
                      vector::CombiningKind kind, PatternRewriter &rewriter,
                      bool isInt, Value mask = Value()) {
  using vector::CombiningKind;
  Value mul;

  if (isInt) {
    if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
        kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
      // Only valid for floating point types.
      return std::nullopt;
    mul = rewriter.create<arith::MulIOp>(loc, x, y);
  } else {
    // Float case.
    if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
        kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
        kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
        kind == CombiningKind::XOR)
      // Only valid for integer types.
      return std::nullopt;
    // Special case for fused multiply-add.
    if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) {
      Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
      if (mask)
        // The fma op doesn't need explicit masking. However, fma ops used in
        // reductions must preserve previous 'acc' values for masked-out lanes.
        fma = selectPassthru(rewriter, mask, fma, acc);
      return fma;
    }
    mul = rewriter.create<arith::MulFOp>(loc, x, y);
  }

  if (!acc)
    return std::optional<Value>(mul);

  return makeArithReduction(rewriter, loc, kind, mul, acc,
                            /*fastmath=*/nullptr, mask);
}
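
// For example, for kind = CombiningKind::ADD on float vectors with a non-null
// vector accumulator, the helper above emits a single vector.fma, whereas for
// kind = CombiningKind::MUL on integers it emits arith.muli followed by the
// matching combining op on the accumulator via makeArithReduction.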

/// Return the positions of the reductions in the given map.
static SmallVector<int64_t> getReductionIndex(AffineMap map,
                                              ArrayAttr iteratorTypes) {
  SmallVector<int64_t> dimsIdx;
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
      dimsIdx.push_back(i);
  }
  return dimsIdx;
}

/// Look for a given dimension in an affine map and return its position. Return
/// std::nullopt if the dimension is not in the map results.
static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    if (map.getDimPosition(i) == dim)
      return i;
  }
  return std::nullopt;
}

/// Creates an arith::AddIOp if `isInt` is true, otherwise an arith::AddFOp,
/// using operands `x` and `y`.
static Value createAdd(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::AddIOp>(loc, x, y);
  return rewriter.create<arith::AddFOp>(loc, x, y);
}

/// Creates an arith::MulIOp if `isInt` is true, otherwise an arith::MulFOp,
/// using operands `x` and `y`.
static Value createMul(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::MulIOp>(loc, x, y);
  return rewriter.create<arith::MulFOp>(loc, x, y);
}

namespace {

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
///    %flattened_a = vector.shape_cast %a
///    %flattened_b = vector.shape_cast %b
///    %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
///    %d = vector.shape_cast %flattened_d
///    %e = add %c, %d
/// ```
/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
//
/// This only kicks in when vectorContractLowering is set to Matmul and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
    : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToMatmulOpLowering(
      vector::VectorContractLowering vectorContractLowering,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorContractLowering(vectorContractLowering),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorContractLowering vectorContractLowering;
  FilterConstraintType filter;
};

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when vectorContractLowering is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
    : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToOuterProductOpLowering(
      vector::VectorContractLowering vectorContractLowering,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorContractLowering(vectorContractLowering),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorContractLowering vectorContractLowering;
  FilterConstraintType filter;
};

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to an output-size-unrolled sequence:
/// ```
///    %out = arith.constant ... : vector<MxNxelt_type>
///    %bt = vector.transpose %b, [1, 0]
///    %aRow0 = vector.extract %a[0]
///    %btRow0 = vector.extract %bt[0]
///    %c00 = vector.reduce %aRow0, %btRow0
///    %out00 = vector.insert %c00, %out[0, 0]
///    ...
///    %aRowLast = vector.extract %a[M-1]
///    %btRowLast = vector.extract %bt[N-1]
///    %cLastLast = vector.reduce %aRowLast, %btRowLast
///    %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
/// ```
///
/// This only kicks in when vectorContractLowering is set to Dot and
/// the vector.contract op is a row-major matmul or matvec.
class ContractionOpToDotLowering
    : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToDotLowering(
      vector::VectorContractLowering vectorContractLowering,
      MLIRContext *context, PatternBenefit benefit = 1,
      const FilterConstraintType &constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorContractLowering(vectorContractLowering), filter(constraint) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorContractLowering vectorContractLowering;
  FilterConstraintType filter;
};

/// Progressive lowering of ContractionOp.
///
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either vectorContractLoweringOption is set
/// to Dot or when other contraction patterns fail.
class ContractionOpLowering
    : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpLowering(
      vector::VectorContractLowering vectorContractLoweringOption,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorContractLoweringOption(vectorContractLoweringOption),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorContractLowering vectorContractLoweringOption;
  FilterConstraintType filter;
  // Lower one parallel dimension.
  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
                                 vector::ContractionOp op, int64_t lhsIndex,
                                 int64_t rhsIndex, Value mask) const;
  // Lower one reduction dimension.
  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
                                  vector::ContractionOp op, Value mask) const;
};

/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
  UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
        kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
        res(op.getAcc()), lhsType(op.getLhsType()) {
    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
    if (maskableOp.isMasked())
      mask = maskableOp.getMaskingOp().getMask();
  }

  Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
    if (!v)
      return v;
    return rewriter.create<vector::TransposeOp>(loc, v, perm);
  }

  Value promote(Value v, Type dstElementType) {
    Type elementType = v.getType();
    auto vecType = dyn_cast<VectorType>(elementType);
    if (vecType)
      elementType = vecType.getElementType();
    if (elementType == dstElementType)
      return v;
    Type promotedType = dstElementType;
    if (vecType)
      promotedType = vecType.clone(promotedType);
    if (isa<FloatType>(dstElementType))
      return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
    return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
  }

  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
                             VectorType lhsType, int reductionSize,
                             std::optional<Value> maybeMask = std::nullopt) {
    // Incremental support for masking.
    if (mask && !maybeMask.has_value())
      return failure();

    Type resElementType = cast<VectorType>(res.getType()).getElementType();
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
      Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
      extractA = promote(extractA, resElementType);
      extractB = promote(extractB, resElementType);
      Value extractMask;
      if (maybeMask.has_value() && maybeMask.value())
        extractMask =
            rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);

      Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
          loc, res.getType(), extractA, extractB, res, kind);
      res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
    }
    return res;
  }

  /// Helper function for `matmat`, `matvec`, `tmatvec`. Returns the size of
  /// dimension `reductionDim`. If the dimension is a scalable dimension,
  /// returns "nullopt".
  std::optional<int64_t> getReductionSize(VectorType vecType,
                                          int64_t reductionDim) {
    // Cannot unroll scalable dimension.
    if (vecType.getScalableDims()[reductionDim])
      return std::nullopt;
    int64_t reductionSize = vecType.getDimSize(reductionDim);
    assert(reductionSize > 0 &&
           "Reduction dim must be a known static size to allow unrolling");
    return reductionSize;
  }

  /// Two outer parallel, one inner reduction (matmat flavor).
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);

    // Classical row-major matmul: Just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        // Note: `t` creates new IR. It must be nested within this `if` check
        // so that no IR is created when the pattern returns "failure".
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(lhs, tRhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout({{m, k}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(rhs, tLhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tRhs = t(rhs);
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
      }
    }
    if (layout({{k, m}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    if (layout({{k, m}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    return failure();
  }

  //
  // One outer parallel, one inner reduction (matvec flavor).
  // Mask needs to be transposed everywhere to turn the reduction dimension
  // outermost as required by outerproduct.
  //
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(rewriter.getContext(), m, k);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tMask = t(mask);
        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask);
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask);
        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask);
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    return failure();
  }

  //
  // One outer reduction, one inner parallel (tmatvec flavor).
  // Mask already has the shape of the outer product.
  //
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(rewriter.getContext(), k, m);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 1))
        return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
    return failure();
  }

private:
  vector::CombiningKind kind;
  Value lhs, rhs, res, mask;
  VectorType lhsType;
};
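
// For example, a row-major matmat with lhs : vector<2x3xf32> and
// rhs : vector<3x4xf32> (maps {{m, k}, {k, n}, {m, n}}) is generated as,
// roughly:
//   %at = vector.transpose %a, [1, 0]          // vector<3x2xf32>
//   %a0 = vector.extract %at[0]
//   %b0 = vector.extract %b[0]
//   %p0 = vector.outerproduct %a0, %b0, %acc   // vector<2x4xf32>
//   ... repeated for k = 1, 2, each step feeding the previous result in as
//   the accumulator.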

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when vectorContractLowering is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
FailureOr<Value>
ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rewriter) const {
  if (vectorContractLowering != vector::VectorContractLowering::OuterProduct)
    return failure();

  if (failed(filter(op)))
    return failure();

  UnrolledOuterProductGenerator e(rewriter, op);
  FailureOr<Value> matmatRes = e.matmat();
  if (succeeded(matmatRes)) {
    return matmatRes;
  }
  FailureOr<Value> matvecRes = e.matvec();
  if (succeeded(matvecRes)) {
    return matvecRes;
  }

  FailureOr<Value> tmatvecRes = e.tmatvec();
  return tmatvecRes;
}

FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rewriter) const {
  // TODO: Support vector.mask.
  if (maskOp)
    return failure();

  if (failed(filter(op)))
    return failure();

  if (vectorContractLowering != vector::VectorContractLowering::Dot)
    return failure();

  auto iteratorTypes = op.getIteratorTypes().getValue();
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  Location loc = op.getLoc();
  Value lhs = op.getLhs(), rhs = op.getRhs();

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [&](MapList m) {
    return AffineMap::inferFromExprList(m, op.getContext());
  };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
  //
  // In the following we wish to make the reduction dimension innermost so we
  // can load vectors and just fmul + reduce into a scalar.
  //
  if (isParallelIterator(iteratorTypes[0]) &&
      isParallelIterator(iteratorTypes[1]) &&
      isReductionIterator(iteratorTypes[2])) {
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // No need to permute anything.
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // Transposed output: swap LHS and RHS, transposing the new LHS.
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = tmp;
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      lhs = tmp;
    } else {
      return failure();
    }
  } else if (isParallelIterator(iteratorTypes[0]) &&
             isReductionIterator(iteratorTypes[1])) {
    //
    // One outer parallel, one inner reduction (matvec flavor).
    //
    if (maps == infer({{m, n}, {n}, {m}})) {
      // No need to permute anything.
    } else if (maps == infer({{n, m}, {n}, {m}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{n}, {m, n}, {m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{n}, {n, m}, {m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else {
      return failure();
    }
  } else {
    return failure();
  }

  VectorType dstType = cast<VectorType>(op.getResultType());
  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
         "Expected dst type of rank 1 or 2");

  unsigned rank = dstType.getRank();
  unsigned dstRows = dstType.getShape()[0];
  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];

  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
                                                 rewriter.getZeroAttr(dstType));
  bool isInt = isa<IntegerType>(dstType.getElementType());
  llvm::SmallVector<Value> extractedCols;
  extractedCols.reserve(dstColumns);
  for (unsigned r = 0; r < dstRows; ++r) {
    Value rowLhs = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
    for (unsigned c = 0; c < dstColumns; ++c) {
      // Extract each respective row and column of the LHS and RHS once to
      // avoid having duplicate SSA values pointing to the same rows/columns.
      if (r == 0) {
        Value colRhs =
            rank == 1 ? rhs
                      : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
        extractedCols.push_back(colRhs);
      }
      Value extractedColRhs = extractedCols[c];
      Value product =
          createMul(op.getLoc(), rowLhs, extractedColRhs, isInt, rewriter);
      Value sum = rewriter.create<vector::ReductionOp>(
          op.getLoc(), vector::CombiningKind::ADD, product);

      SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
                                              : SmallVector<int64_t, 2>{r, c};
      res = rewriter.create<vector::InsertOp>(op.getLoc(), sum, res, pos);
    }
  }
  if (auto acc = op.getAcc())
    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
  return res;
}
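
// For example, for a matmat with maps {{m, k}, {n, k}, {m, n}} (reduction
// already innermost) and lhs : vector<2x3xf32>, rhs : vector<4x3xf32>, the
// loop above emits per (r, c) pair, roughly:
//   %row = vector.extract %lhs[r]
//   %col = vector.extract %rhs[c]
//   %mul = arith.mulf %row, %col
//   %sum = vector.reduction <add>, %mul
//   %res = vector.insert %sum, %res[r, c]
// i.e. M * N = 2 * 4 = 8 scalar reductions in total.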

/// Lower vector.contract with all size one reduction dimensions to
/// elementwise ops when possible.
struct ContractOpToElementwise
    : public MaskableOpRewritePattern<vector::ContractionOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }
  ContractOpToElementwise(
      vector::VectorContractLowering vectorContractLowering,
      MLIRContext *context, PatternBenefit benefit = 1,
      const FilterConstraintType &constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorContractLowering(vectorContractLowering), filter(constraint) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: Support vector.mask.
    if (maskOp)
      return failure();

    if (failed(filter(contractOp)))
      return failure();

    if (vectorContractLowering != vector::VectorContractLowering::ParallelArith)
      return failure();

    ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
    ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
    AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
    AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
    SmallVector<int64_t> lhsReductionDims =
        getReductionIndex(lhsMap, contractOp.getIteratorTypes());
    SmallVector<int64_t> rhsReductionDims =
        getReductionIndex(rhsMap, contractOp.getIteratorTypes());
    // All the reduction dimensions must have size 1.
    for (int64_t dim : lhsReductionDims) {
      if (lhsShape[dim] != 1)
        return failure();
    }
    for (int64_t dim : rhsReductionDims) {
      if (rhsShape[dim] != 1)
        return failure();
    }
    AffineMap accMap = contractOp.getIndexingMapsArray()[2];
    unsigned numParallelDims = accMap.getNumResults();
    unsigned numLhsDimToBroadcast =
        numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
    unsigned numRhsDimToBroadcast =
        numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
    SmallVector<int64_t> lhsDims;
    SmallVector<int64_t> lhsTranspose;
    SmallVector<int64_t> rhsDims;
    SmallVector<int64_t> rhsTranspose;
    for (int64_t dim : lhsReductionDims)
      lhsTranspose.push_back(numLhsDimToBroadcast + dim);
    for (int64_t dim : rhsReductionDims)
      rhsTranspose.push_back(numRhsDimToBroadcast + dim);
    // Loop through the parallel dimensions to calculate the dimensions to
    // broadcast and to permute in order to extract only parallel dimensions.
    for (unsigned i = 0; i < numParallelDims; i++) {
      std::optional<unsigned> lhsDim =
          getDimPosition(lhsMap, accMap.getDimPosition(i));
      if (lhsDim) {
        lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
      } else {
        // If the parallel dimension doesn't exist we will have to broadcast it.
        lhsDims.push_back(
            cast<VectorType>(contractOp.getResultType()).getDimSize(i));
        lhsTranspose.push_back(lhsDims.size() - 1);
      }
      std::optional<unsigned> rhsDim =
          getDimPosition(rhsMap, accMap.getDimPosition(i));
      if (rhsDim) {
        rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
      } else {
        // If the parallel dimension doesn't exist we will have to broadcast it.
        rhsDims.push_back(
            cast<VectorType>(contractOp.getResultType()).getDimSize(i));
        rhsTranspose.push_back(rhsDims.size() - 1);
      }
    }
    Value newLhs = contractOp.getLhs();
    Value newRhs = contractOp.getRhs();
    Location loc = contractOp.getLoc();
    if (!lhsDims.empty()) {
      lhsDims.append(lhsShape.begin(), lhsShape.end());
      auto expandedType =
          VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
      newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
    }
    if (!rhsDims.empty()) {
      rhsDims.append(rhsShape.begin(), rhsShape.end());
      auto expandedType =
          VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
      newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
    }
    bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
    newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
    newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
    SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
    SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
    newLhs = rewriter.create<vector::ExtractOp>(loc, newLhs, lhsOffsets);
    newRhs = rewriter.create<vector::ExtractOp>(loc, newRhs, rhsOffsets);
    std::optional<Value> result =
        createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
                              contractOp.getKind(), rewriter, isInt);
    if (result)
      return *result;

    return failure();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorContractLowering vectorContractLowering;
  FilterConstraintType filter;
};

/// Progressive lowering of ContractionOp.
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either vectorContractLoweringOption is set
/// to Dot or when other contraction patterns fail.
//
// TODO: break down into transpose/reshape/cast ops
//       when they become available to avoid code dup
// TODO: investigate lowering order impact on performance
FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rewriter) const {
  if (failed(filter(op)))
    return failure();

  // TODO: support mixed mode contract lowering.
  if (op.getLhsType().getElementType() !=
          getElementTypeOrSelf(op.getAccType()) ||
      op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
    return failure();

  // TODO: the code below assumes the default contraction, make sure it supports
  // other kinds before enabling this lowering.
  if (op.getKind() != vector::CombiningKind::ADD) {
    return rewriter.notifyMatchFailure(
        op, "contractions other than 'add' not supported");
  }

  // TODO: implement benefits, cost models.
  MLIRContext *ctx = op.getContext();

  ContractionOpToMatmulOpLowering pat1(vectorContractLoweringOption, ctx);
  FailureOr<Value> newVal1 =
      pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
  if (!failed(newVal1))
    return newVal1;

  ContractionOpToOuterProductOpLowering pat2(vectorContractLoweringOption, ctx);
  FailureOr<Value> newVal2 =
      pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
  if (!failed(newVal2))
    return newVal2;

  ContractionOpToDotLowering pat3(vectorContractLoweringOption, ctx);
  FailureOr<Value> newVal3 =
      pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
  if (!failed(newVal3))
    return newVal3;

  ContractOpToElementwise pat4(vectorContractLoweringOption, ctx);
  FailureOr<Value> newVal4 =
      pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
  if (!failed(newVal4))
    return newVal4;

  // Vector mask setup.

  Value mask;
  if (maskOp)
    mask = maskOp.getMask();
  // Find first batch dimension in LHS/RHS, and lower when found.
  std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
  if (!batchDimMap.empty()) {
    int64_t lhsIndex = batchDimMap[0].first;
    int64_t rhsIndex = batchDimMap[0].second;
    auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
    if (failed(newOp))
      return failure();
    return newOp;
  }

  // Collect contracting dimensions.
  std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
      op.getContractingDimMap();
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }

  // Find first free dimension in LHS, and lower when found.
  VectorType lhsType = op.getLhsType();
  for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
    if (lhsContractingDimSet.count(lhsIndex) == 0) {
      auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
      if (failed(newOp))
        return failure();
      return newOp;
    }
  }

  // Find first free dimension in RHS, and lower when found.
  VectorType rhsType = op.getRhsType();
  for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
    if (rhsContractingDimSet.count(rhsIndex) == 0) {
      auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
      if (failed(newOp))
        return failure();
      return newOp;
    }
  }

  // Lower the first remaining reduction dimension.
  if (!contractingDimMap.empty()) {
    auto newOp = lowerReduction(rewriter, op, mask);
    if (failed(newOp))
      return failure();
    return newOp;
  }

  return failure();
}

// Lower one parallel dimension.
// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
// TODO: consider reusing existing contract unrolling
FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
                                                      vector::ContractionOp op,
                                                      int64_t lhsIndex,
                                                      int64_t rhsIndex,
                                                      Value mask) const {
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  VectorType resType = cast<VectorType>(op.getResultType());
  // Find the iterator type index and result index.
  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
  int64_t iterIndex = -1;
  int64_t dimSize = -1;
  if (lhsIndex >= 0) {
    iterIndex = iMap[0].getDimPosition(lhsIndex);
    if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
             << " to map to the same dimension";
      });
    if (lhsType.getScalableDims()[lhsIndex])
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "Unrolling scalable dimension (lhsIndex=" << lhsIndex
             << ") is not supported yet";
      });
    dimSize = lhsType.getDimSize(lhsIndex);
  } else if (rhsIndex >= 0) {
    iterIndex = iMap[1].getDimPosition(rhsIndex);
    if (rhsType.getScalableDims()[rhsIndex])
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "Unrolling scalable dimension (rhsIndex=" << rhsIndex
             << ") is not supported yet";
      });
    dimSize = rhsType.getDimSize(rhsIndex);
  }
  if (iterIndex < 0)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected either lhsIndex=" << lhsIndex
           << " or rhsIndex=" << rhsIndex << " to be nonnegative";
    });
  // value_or(-1) means that we tolerate a dimension not appearing
  // in the result map. That can't happen for actual parallel iterators, but
  // the caller ContractionOpLowering::matchAndRewriteMaskableOp is currently
  // calling lowerParallel also for the case of unit-size reduction dims
  // appearing only on one of LHS or RHS, not both. At the moment, such cases
  // are created by CastAwayContractionLeadingOneDim, so we need to either
  // support that or modify that pattern.
  int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
  if (resIndex == -1 && dimSize != 1)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected the dimension for iterIndex=" << iterIndex
           << " to either appear in the result map, or to be a unit dimension";
    });

  // Construct new iterator types and affine map array attribute.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
  // Unroll into a series of lower dimensional vector.contract ops.
  Location loc = op.getLoc();
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));

  for (int64_t d = 0; d < dimSize; ++d) {
    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
    auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);

    Value lowMask;
    if (mask)
      lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
                            iterIndex, d, rewriter);

    Operation *lowContract = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, acc, lowAffine, lowIter);
    lowContract = maskOperation(rewriter, lowContract, lowMask);
    result = reshapeStore(loc, lowContract->getResult(0), result, resType,
                          resIndex, d, rewriter);
  }
  return result;
}
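
// For example, unrolling the batch dimension of a batched matvec with
// lhs : vector<2x3x4xf32>, rhs : vector<2x4xf32> and acc : vector<2x3xf32>
// produces two rank-reduced vector.contract ops, one per batch index d,
// operating on vector<3x4xf32> / vector<4xf32> / vector<3xf32> slices
// obtained via reshapeLoad, with results recombined via reshapeStore.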

// Lower one reduction dimension.
FailureOr<Value> ContractionOpLowering::lowerReduction(
    PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
  auto loc = op.getLoc();
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  Type resType = op.getResultType();
  if (isa<VectorType>(resType))
    return rewriter.notifyMatchFailure(op,
                                       "did not expect a VectorType result");
  bool isInt = isa<IntegerType>(resType);
  // Use iterator index 0.
  int64_t iterIndex = 0;
  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
  std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
  std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
  if (!lookupLhs.has_value())
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected iterIndex=" << iterIndex
           << " to map to a LHS dimension";
    });
  if (!lookupRhs.has_value())
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected iterIndex=" << iterIndex
           << " to map to a RHS dimension";
    });
  int64_t lhsIndex = *lookupLhs;
  int64_t rhsIndex = *lookupRhs;
  int64_t dimSize = lhsType.getDimSize(lhsIndex);
  if (dimSize != rhsType.getDimSize(rhsIndex))
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expect LHS dimension " << lhsIndex
           << " to have the same size as RHS dimension " << rhsIndex;
    });
  // Base case.
  if (lhsType.getRank() == 1) {
    if (rhsType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "When LHS has rank 1, expected also RHS to have rank 1");
    Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
    auto kind = vector::CombiningKind::ADD;

    Value acc = op.getAcc();
    Operation *reductionOp =
        acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
            : rewriter.create<vector::ReductionOp>(loc, kind, m);
    return maskOperation(rewriter, reductionOp, mask)->getResult(0);
  }
  // Construct new iterator types and affine map array attribute.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
  // Unroll into a series of lower dimensional vector.contract ops.
  // By feeding the initial accumulator into the first contraction,
  // and the result of each contraction into the next, eventually
  // the sum of all reductions is computed.
  Value result = op.getAcc();
  for (int64_t d = 0; d < dimSize; ++d) {
    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
    Value newMask;
    if (mask)
      newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
                            iterIndex, d, rewriter);

    Operation *newContract = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, result, lowAffine, lowIter);
    result = maskOperation(rewriter, newContract, newMask)->getResult(0);
  }
  return result;
}
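
// For example, a dot product with lhs, rhs : vector<4xf32> and a scalar
// accumulator hits the rank-1 base case above and lowers to, roughly:
//   %m = arith.mulf %lhs, %rhs
//   %r = vector.reduction <add>, %m, %acc
// Higher-rank inputs first peel off one reduction dimension into a chain of
// smaller vector.contract ops, each feeding its partial result into the next.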

/// Progressive lowering of OuterProductOp.
/// One:
///   %x = vector.outerproduct %lhs, %rhs, %acc
/// is replaced by:
///   %z = zero-result
///   %0 = vector.extract %lhs[0]
///   %1 = vector.broadcast %0
///   %2 = vector.extract %acc[0]
///   %3 = vector.fma %1, %rhs, %2
///   %4 = vector.insert %3, %z[0]
///   ..
///   %x = vector.insert %.., %..[N-1]
///
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::OuterProductOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resType = op.getResultVectorType();
    if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
      return failure();

    auto loc = op.getLoc();

    VectorType lhsType = op.getOperandVectorTypeLHS();
    VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
    Type eltType = resType.getElementType();
    bool isInt = isa<IntegerType, IndexType>(eltType);
    Value acc = op.getAcc();
    vector::CombiningKind kind = op.getKind();

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = op;
    }

    if (!rhsType) {
      // Special case: AXPY operation.
      Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
      std::optional<Value> mult = createContractArithOp(
          loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
      if (!mult.has_value())
        return failure();
      rewriter.replaceOp(rootOp, *mult);
      return success();
    }

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
      Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), d);
      Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
      Value r = nullptr;
      if (acc)
        r = rewriter.create<vector::ExtractOp>(loc, acc, d);
      Value extrMask;
      if (mask)
        extrMask = rewriter.create<vector::ExtractOp>(loc, mask, d);

      std::optional<Value> m = createContractArithOp(
          loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
      if (!m.has_value())
        return failure();
      result = rewriter.create<vector::InsertOp>(loc, *m, result, d);
    }

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
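
// For example, %x = vector.outerproduct %lhs, %rhs, %acc with
// lhs : vector<2xf32>, rhs : vector<3xf32>, acc : vector<2x3xf32> and the
// default "add" kind becomes, for each row d in [0, 2):
//   %l = vector.extract %lhs[d]                    // f32
//   %b = vector.broadcast %l : f32 to vector<3xf32>
//   %a = vector.extract %acc[d]                    // vector<3xf32>
//   %f = vector.fma %b, %rhs, %a                   // via createContractArithOp
//   %x = vector.insert %f, %x[d]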

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
///    %mta = maybe_transpose
///    %mtb = maybe_transpose
///    %flattened_a = vector.shape_cast %mta
///    %flattened_b = vector.shape_cast %mtb
///    %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
///    %mtd = vector.shape_cast %flattened_d
///    %d = maybe_untranspose %mtd
///    %e = add %c, %d
/// ```
/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
//
/// This only kicks in when vectorContractLowering is set to `Matmul`.
/// vector.transpose operations are inserted if the vector.contract op is not a
/// row-major matrix multiply.
///
/// Scalable vectors are not supported.
FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rew) const {
  // TODO: Support vector.mask.
  if (maskOp)
    return failure();

  if (vectorContractLowering != vector::VectorContractLowering::Matmul)
    return failure();
  if (failed(filter(op)))
    return failure();

  auto iteratorTypes = op.getIteratorTypes().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();

  Type opResType = op.getType();
  VectorType vecType = dyn_cast<VectorType>(opResType);
  if (vecType && vecType.isScalable()) {
    // Note - this is sufficient to reject all cases with scalable vectors.
    return failure();
  }

  Type elementType = op.getLhsType().getElementType();
  if (!elementType.isIntOrFloat())
    return failure();

  Type dstElementType = vecType ? vecType.getElementType() : opResType;
  if (elementType != dstElementType)
    return failure();

  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
  // Bail out if the contraction cannot be put in this form.
  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();
  AffineExpr m, n, k;
  bindDims(rew.getContext(), m, n, k);
  // LHS must be A(m, k) or A(k, m).
  Value lhs = op.getLhs();
  auto lhsMap = op.getIndexingMapsArray()[0];
  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
    return failure();

  // RHS must be B(k, n) or B(n, k).
  Value rhs = op.getRhs();
  auto rhsMap = op.getIndexingMapsArray()[1];
  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
    return failure();

  // At this point lhs and rhs are in row-major.
  VectorType lhsType = cast<VectorType>(lhs.getType());
  VectorType rhsType = cast<VectorType>(rhs.getType());
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);

  Type flattenedLHSType =
      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);

  Type flattenedRHSType =
      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);

  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
                                           rhsColumns);
  mul = rew.create<vector::ShapeCastOp>(
      loc,
      VectorType::get({lhsRows, rhsColumns},
                      getElementTypeOrSelf(op.getAcc().getType())),
      mul);

  // ACC must be C(m, n) or C(n, m).
  auto accMap = op.getIndexingMapsArray()[2];
  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
    llvm_unreachable("invalid contraction semantics");

  Value res =
      isa<IntegerType>(elementType)
          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
          : static_cast<Value>(
                rew.create<arith::AddFOp>(loc, op.getAcc(), mul));

  return res;
}
} // namespace

void mlir::vector::populateVectorContractLoweringPatterns(
    RewritePatternSet &patterns,
    VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit,
    bool disableOuterProductLowering) {
  if (!disableOuterProductLowering)
    patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
  patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
               ContractionOpToOuterProductOpLowering>(
      vectorContractLoweringOption, patterns.getContext(), benefit);
}
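
// A minimal usage sketch (illustrative only, not part of the upstream file):
// a pass would typically collect these patterns and hand them to the greedy
// rewrite driver, e.g.:
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorContractLoweringPatterns(
//       patterns, vector::VectorContractLowering::OuterProduct);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
// where `op` is the enclosing function or module being lowered.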

void mlir::vector::populateVectorOuterProductLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
}