//=- VectorEmulateMaskedLoadStore.cpp - Emulate 'vector.maskedload/store' op =//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to emulate the
// 'vector.maskedload' and 'vector.maskedstore' operations.
//
//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/MemRef/IR/MemRef.h"
15#include "mlir/Dialect/SCF/IR/SCF.h"
16#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
17
18using namespace mlir;
19
20namespace {
21
22/// Convert vector.maskedload
23///
24/// Before:
25///
26/// vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
27///
28/// After:
29///
30/// %ivalue = %pass_thru
31/// %m = vector.extract %mask[0]
32/// %result0 = scf.if %m {
33/// %v = memref.load %base[%idx_0, %idx_1]
34/// %combined = vector.insert %v, %ivalue[0]
35/// scf.yield %combined
36/// } else {
37/// scf.yield %ivalue
38/// }
39/// %m = vector.extract %mask[1]
40/// %result1 = scf.if %m {
41/// %v = memref.load %base[%idx_0, %idx_1 + 1]
42/// %combined = vector.insert %v, %result0[1]
43/// scf.yield %combined
44/// } else {
45/// scf.yield %result0
46/// }
47/// ...
48///
49struct VectorMaskedLoadOpConverter final
50 : OpRewritePattern<vector::MaskedLoadOp> {
51 using OpRewritePattern::OpRewritePattern;
52
53 LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp,
54 PatternRewriter &rewriter) const override {
55 VectorType maskVType = maskedLoadOp.getMaskVectorType();
56 if (maskVType.getShape().size() != 1)
57 return rewriter.notifyMatchFailure(
58 maskedLoadOp, "expected vector.maskedstore with 1-D mask");
59
60 Location loc = maskedLoadOp.getLoc();
61 int64_t maskLength = maskVType.getShape()[0];
62
63 Type indexType = rewriter.getIndexType();
64 Value mask = maskedLoadOp.getMask();
65 Value base = maskedLoadOp.getBase();
66 Value iValue = maskedLoadOp.getPassThru();
67 auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
68 Value one = rewriter.create<arith::ConstantOp>(
69 loc, indexType, IntegerAttr::get(indexType, 1));
70 for (int64_t i = 0; i < maskLength; ++i) {
71 auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i);
72
73 auto ifOp = rewriter.create<scf::IfOp>(
74 loc, maskBit,
75 [&](OpBuilder &builder, Location loc) {
76 auto loadedValue =
77 builder.create<memref::LoadOp>(loc, base, indices);
78 auto combinedValue =
79 builder.create<vector::InsertOp>(loc, loadedValue, iValue, i);
80 builder.create<scf::YieldOp>(loc, combinedValue.getResult());
81 },
82 [&](OpBuilder &builder, Location loc) {
83 builder.create<scf::YieldOp>(loc, iValue);
84 });
85 iValue = ifOp.getResult(0);
86
87 indices.back() = rewriter.create<arith::AddIOp>(loc, indices.back(), one);
88 }
89
90 rewriter.replaceOp(maskedLoadOp, iValue);
91
92 return success();
93 }
94};
95
96/// Convert vector.maskedstore
97///
98/// Before:
99///
100/// vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
101///
102/// After:
103///
104/// %m = vector.extract %mask[0]
105/// scf.if %m {
106/// %extracted = vector.extract %value[0]
107/// memref.store %extracted, %base[%idx_0, %idx_1]
108/// }
109/// %m = vector.extract %mask[1]
110/// scf.if %m {
111/// %extracted = vector.extract %value[1]
112/// memref.store %extracted, %base[%idx_0, %idx_1 + 1]
113/// }
114/// ...
115///
116struct VectorMaskedStoreOpConverter final
117 : OpRewritePattern<vector::MaskedStoreOp> {
118 using OpRewritePattern::OpRewritePattern;
119
120 LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp,
121 PatternRewriter &rewriter) const override {
122 VectorType maskVType = maskedStoreOp.getMaskVectorType();
123 if (maskVType.getShape().size() != 1)
124 return rewriter.notifyMatchFailure(
125 maskedStoreOp, "expected vector.maskedstore with 1-D mask");
126
127 Location loc = maskedStoreOp.getLoc();
128 int64_t maskLength = maskVType.getShape()[0];
129
130 Type indexType = rewriter.getIndexType();
131 Value mask = maskedStoreOp.getMask();
132 Value base = maskedStoreOp.getBase();
133 Value value = maskedStoreOp.getValueToStore();
134 auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
135 Value one = rewriter.create<arith::ConstantOp>(
136 loc, indexType, IntegerAttr::get(indexType, 1));
137 for (int64_t i = 0; i < maskLength; ++i) {
138 auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i);
139
140 auto ifOp = rewriter.create<scf::IfOp>(loc, maskBit, /*else=*/false);
141 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
142 auto extractedValue = rewriter.create<vector::ExtractOp>(loc, value, i);
143 rewriter.create<memref::StoreOp>(loc, extractedValue, base, indices);
144
145 rewriter.setInsertionPointAfter(ifOp);
146 indices.back() = rewriter.create<arith::AddIOp>(loc, indices.back(), one);
147 }
148
149 rewriter.eraseOp(op: maskedStoreOp);
150
151 return success();
152 }
153};
154
155} // namespace
156
157void mlir::vector::populateVectorMaskedLoadStoreEmulationPatterns(
158 RewritePatternSet &patterns, PatternBenefit benefit) {
159 patterns.add<VectorMaskedLoadOpConverter, VectorMaskedStoreOpConverter>(
160 arg: patterns.getContext(), args&: benefit);
161}
162

source code of mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp