//=- VectorEmulateMaskedLoadStore.cpp - Emulate 'vector.maskedload/store' op =//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to emulate
// the 'vector.maskedload' and 'vector.maskedstore' operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"

using namespace mlir;

namespace {

/// Convert vector.maskedload
///
/// Before:
///
///   vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
///
/// After:
///
///   %ivalue = %pass_thru
///   %m = vector.extract %mask[0]
///   %result0 = scf.if %m {
///     %v = memref.load %base[%idx_0, %idx_1]
///     %combined = vector.insert %v, %ivalue[0]
///     scf.yield %combined
///   } else {
///     scf.yield %ivalue
///   }
///   %m = vector.extract %mask[1]
///   %result1 = scf.if %m {
///     %v = memref.load %base[%idx_0, %idx_1 + 1]
///     %combined = vector.insert %v, %result0[1]
///     scf.yield %combined
///   } else {
///     scf.yield %result0
///   }
///   ...
///
struct VectorMaskedLoadOpConverter final
    : OpRewritePattern<vector::MaskedLoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp,
                                PatternRewriter &rewriter) const override {
    VectorType maskVType = maskedLoadOp.getMaskVectorType();
    if (maskVType.getShape().size() != 1)
      return rewriter.notifyMatchFailure(
          maskedLoadOp, "expected vector.maskedload with 1-D mask");

    Location loc = maskedLoadOp.getLoc();
    int64_t maskLength = maskVType.getShape()[0];

    Type indexType = rewriter.getIndexType();
    Value mask = maskedLoadOp.getMask();
    Value base = maskedLoadOp.getBase();
    Value iValue = maskedLoadOp.getPassThru();
    auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
    Value one = rewriter.create<arith::ConstantOp>(
        loc, indexType, IntegerAttr::get(indexType, 1));
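    // Scalarize the masked load: each lane becomes a scalar load guarded by
    // its mask bit, whose result is inserted into the running result vector
    // (initialized to the pass-through value).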
    for (int64_t i = 0; i < maskLength; ++i) {
      auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i);

      auto ifOp = rewriter.create<scf::IfOp>(
          loc, maskBit,
          [&](OpBuilder &builder, Location loc) {
            auto loadedValue =
                builder.create<memref::LoadOp>(loc, base, indices);
            auto combinedValue =
                builder.create<vector::InsertOp>(loc, loadedValue, iValue, i);
            builder.create<scf::YieldOp>(loc, combinedValue.getResult());
          },
          [&](OpBuilder &builder, Location loc) {
            builder.create<scf::YieldOp>(loc, iValue);
          });
      iValue = ifOp.getResult(0);

      indices.back() = rewriter.create<arith::AddIOp>(loc, indices.back(), one);
    }

    rewriter.replaceOp(maskedLoadOp, iValue);

    return success();
  }
};

/// Convert vector.maskedstore
///
/// Before:
///
///   vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
///
/// After:
///
///   %m = vector.extract %mask[0]
///   scf.if %m {
///     %extracted = vector.extract %value[0]
///     memref.store %extracted, %base[%idx_0, %idx_1]
///   }
///   %m = vector.extract %mask[1]
///   scf.if %m {
///     %extracted = vector.extract %value[1]
///     memref.store %extracted, %base[%idx_0, %idx_1 + 1]
///   }
///   ...
///
struct VectorMaskedStoreOpConverter final
    : OpRewritePattern<vector::MaskedStoreOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp,
                                PatternRewriter &rewriter) const override {
    VectorType maskVType = maskedStoreOp.getMaskVectorType();
    if (maskVType.getShape().size() != 1)
      return rewriter.notifyMatchFailure(
          maskedStoreOp, "expected vector.maskedstore with 1-D mask");

    Location loc = maskedStoreOp.getLoc();
    int64_t maskLength = maskVType.getShape()[0];

    Type indexType = rewriter.getIndexType();
    Value mask = maskedStoreOp.getMask();
    Value base = maskedStoreOp.getBase();
    Value value = maskedStoreOp.getValueToStore();
    auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
    Value one = rewriter.create<arith::ConstantOp>(
        loc, indexType, IntegerAttr::get(indexType, 1));
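    // Scalarize the masked store: each lane becomes a scalar store of the
    // corresponding vector element, guarded by its mask bit.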
    for (int64_t i = 0; i < maskLength; ++i) {
      auto maskBit = rewriter.create<vector::ExtractOp>(loc, mask, i);

      auto ifOp = rewriter.create<scf::IfOp>(loc, maskBit, /*else=*/false);
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      auto extractedValue = rewriter.create<vector::ExtractOp>(loc, value, i);
      rewriter.create<memref::StoreOp>(loc, extractedValue, base, indices);

      rewriter.setInsertionPointAfter(ifOp);
      indices.back() = rewriter.create<arith::AddIOp>(loc, indices.back(), one);
    }

    rewriter.eraseOp(maskedStoreOp);

    return success();
  }
};

} // namespace

void mlir::vector::populateVectorMaskedLoadStoreEmulationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<VectorMaskedLoadOpConverter, VectorMaskedStoreOpConverter>(
      patterns.getContext(), benefit);
}
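
// Example use from a pass (a minimal sketch; the pass wrapper and the greedy
// driver call below are illustrative, not part of this file):
//
//   void MyEmulationPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     vector::populateVectorMaskedLoadStoreEmulationPatterns(patterns);
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }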