//===- DataLayoutPropagation.cpp -----------------------------------------===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Linalg/Passes.h" |
10 | |
11 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
12 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
13 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
14 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
15 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
16 | #include "mlir/Dialect/Tensor/Utils/Utils.h" |
17 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
18 | #include "mlir/IR/Dominance.h" |
19 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
20 | #include "llvm/ADT/TypeSwitch.h" |
21 | #include "llvm/Support/Debug.h" |
22 | #include <optional> |
23 | |
24 | namespace mlir { |
25 | #define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION |
26 | #include "mlir/Dialect/Linalg/Passes.h.inc" |
27 | } // namespace mlir |
28 | |
29 | using namespace mlir; |
30 | using namespace mlir::linalg; |
31 | |
32 | #define DEBUG_TYPE "linalg-data-layout-propagation" |
33 | |
34 | namespace { |
35 | |
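/// Returns true if the generic op is treated as having gather semantics for
/// propagation purposes, i.e., its body contains a tensor.extract or
/// linalg.index op, so element accesses depend on computed indices and not
/// only on the op's indexing maps.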
36 | static bool hasGatherSemantics(linalg::GenericOp genericOp) { |
37 | for (Operation &op : genericOp.getBody()->getOperations()) |
38 | if (isa<tensor::ExtractOp, linalg::IndexOp>(op)) |
39 | return true; |
40 | return false; |
41 | } |
42 | |
// This struct contains the information needed to map the packing configuration
// of a pack/unpack op onto the iteration domain of Linalg ops.
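//
// For illustration (an assumed example, not drawn from an existing test):
// given a generic op with iteration domain (d0, d1) and an identity result
// indexing map, packed by a tensor.pack with inner_dims_pos = [0] and
// inner_tiles = [8], the constructed PackInfo would be:
//   tiledDimsPos            = [0]
//   domainDimAndTileMapping = {d0 -> 8}
//   tileToPointMapping      = {d0 -> d2}  (d2 is the appended point loop)
//   outerDimsOnDomainPerm   = []          (the pack has no outer_dims_perm)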
45 | struct PackInfo { |
46 | int64_t getNumTiledLoops() const { return tileToPointMapping.size(); }; |
  // The tiled dimension positions on the iteration domain, following the order
  // of the pack op's inner_dims_pos.
48 | SmallVector<int64_t> tiledDimsPos; |
  // The mapping from a dimension of the iteration domain to its tile size.
50 | llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping; |
51 | // The mapping from a dimension of iteration domain to the corresponding inner |
52 | // tiling dimension on iteration domain. |
53 | llvm::DenseMap<int64_t, int64_t> tileToPointMapping; |
54 | // The permutation of outer dims (on domain). |
55 | SmallVector<int64_t> outerDimsOnDomainPerm; |
56 | }; |
57 | |
58 | template <typename OpTy> |
59 | static FailureOr<PackInfo> |
60 | getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp, |
61 | OpTy packOrUnPackOp) { |
62 | static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value, |
63 | "applies to only pack or unpack operations" ); |
  LLVM_DEBUG(
      { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });
66 | |
67 | AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); |
68 | SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray(); |
69 | SmallVector<utils::IteratorType> iterators = |
70 | genericOp.getIteratorTypesArray(); |
71 | |
72 | PackInfo packInfo; |
73 | int64_t origNumDims = indexingMap.getNumDims(); |
74 | SmallVector<AffineExpr> exprs(indexingMap.getResults()); |
75 | ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos(); |
76 | for (auto [index, innerDimPos, tileSize] : |
77 | llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()), |
78 | innerDimsPos, packOrUnPackOp.getMixedTiles())) { |
79 | auto expr = exprs[innerDimPos]; |
80 | if (!isa<AffineDimExpr>(expr)) |
81 | return failure(); |
82 | int64_t domainDimPos = |
83 | cast<AffineDimExpr>(exprs[innerDimPos]).getPosition(); |
84 | if (!isParallelIterator(iterators[domainDimPos])) |
85 | return failure(); |
86 | packInfo.tiledDimsPos.push_back(domainDimPos); |
87 | packInfo.domainDimAndTileMapping[domainDimPos] = tileSize; |
88 | packInfo.tileToPointMapping[domainDimPos] = origNumDims + index; |
89 | LLVM_DEBUG({ |
90 | llvm::dbgs() << "map innerDimPos=" << innerDimPos |
91 | << " to iteration dimension (d" << domainDimPos << ", d" |
92 | << packInfo.tileToPointMapping[domainDimPos] |
93 | << "), which has size=(" |
                   << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
95 | }); |
96 | } |
97 | |
98 | // Bail out if a tiled dimension is present in a map but not as an affine dim |
99 | // expression. |
100 | auto areAllAffineDimExpr = [&](int dim) { |
101 | for (AffineMap map : indexingMaps) { |
102 | if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) { |
103 | return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr); |
104 | })) { |
105 | return false; |
106 | } |
107 | } |
108 | return true; |
109 | }; |
110 | for (int64_t i : packInfo.tiledDimsPos) |
111 | if (!areAllAffineDimExpr(i)) |
112 | return failure(); |
113 | |
114 | // Get the outer dims perm on the iteration domain. Start by identifying the |
115 | // set of domain dims affected by the outer permutation along with the |
116 | // permuted ordering for those dims. Then the full outer dims permutation can |
117 | // be constructed by replacing the affected dims with the permuted result in a |
118 | // numLoops-rank identity. e.g. |
119 | // outerDimsPerm = [1, 2, 0] |
120 | // indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3) |
121 | // |
122 | // permutedOuterDims = [4, 3, 1] |
123 | // outerDimsOnDomainPerm = [0, 4, 2, 3, 1] |
124 | // |
125 | // Non-affine dim expressions must not be permuted by the outer dims |
126 | // permutation. |
127 | SmallVector<int64_t> permutedOuterDims; |
128 | for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) { |
    auto permutedExpr = indexingMap.getResult(dim);
130 | if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) { |
131 | permutedOuterDims.push_back(dimExpr.getPosition()); |
132 | continue; |
133 | } |
134 | |
135 | // TODO: Allow propagation with transposes on non affine dim expressions, |
136 | // e.g. d0 + d1 which implies transposing both dims simultaneously while |
137 | // maintaining the relative position between them. |
138 | if (static_cast<int64_t>(index) != dim) |
139 | return failure(); |
140 | } |
141 | if (!permutedOuterDims.empty()) { |
142 | int64_t outerDimIndex = 0; |
143 | llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(), |
144 | permutedOuterDims.end()); |
145 | for (int i = 0, e = indexingMap.getNumDims(); i < e; i++) |
146 | packInfo.outerDimsOnDomainPerm.push_back( |
147 | permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++] |
148 | : i); |
149 | LLVM_DEBUG({ |
150 | llvm::dbgs() << "map outer dimsDimsPerm to " ; |
151 | for (auto dim : packInfo.outerDimsOnDomainPerm) |
152 | llvm::dbgs() << dim << " " ; |
153 | llvm::dbgs() << "\n" ; |
154 | }); |
155 | } |
156 | |
157 | return packInfo; |
158 | } |
159 | |
160 | static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm, |
161 | ArrayRef<AffineExpr> exprs) { |
162 | // Compute `outer_dims_perm`. See example: |
163 | // current exprs : (d0, d1, d2, d3) -> (d2, d3) |
164 | // perm : [0, 3, 1, 2] |
165 | // First map d2, d3 with their position in the array as: |
166 | // currentPositionTileLoops: dim | pos |
167 | // d2 | 0 |
168 | // d3 | 1 |
169 | // then scan `perm` in order and get the `outer_dims_perm` |
170 | // to be used, here it would be [1, 0]. |
  assert(!perm.empty() && "expect perm not to be empty");
  assert(!exprs.empty() && "expect exprs not to be empty");
173 | if (exprs.size() == 1) |
174 | return {}; |
175 | SmallVector<int64_t> outerDimsPerm; |
176 | DenseMap<int64_t, int64_t> currentPositionTileLoops; |
177 | for (auto [pos, expr] : llvm::enumerate(exprs)) { |
178 | // Here we rely on the assumption that the outer dims permutation |
179 | // when propagating currently requires that non-affine dim expressions |
180 | // are not permuted, thus allowing the identity assignment below. |
181 | if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) |
182 | currentPositionTileLoops[dimExpr.getPosition()] = pos; |
183 | else |
184 | currentPositionTileLoops[pos] = pos; |
185 | } |
186 | for (int64_t loopIdx : perm) { |
187 | if (currentPositionTileLoops.count(loopIdx)) |
188 | outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx)); |
189 | } |
190 | return outerDimsPerm; |
191 | } |
192 | |
193 | /// Returns a tuple for packed operand and indexing_map with the assumptions: |
194 | /// 1) The generic op is the producer of the pack op. |
195 | /// 2) The generic op has only one result. |
196 | /// If the operand is a scalar or packing dimensions are all irrelevant to the |
197 | /// operand, the operand and the updated indexing map will be returned. |
198 | /// Otherwise, it returns the packed operand and the updated indexing map. E.g., |
199 | /// |
200 | /// #map0 = affine_map<(d0, d1) -> (d0, d1)> |
201 | /// #map1 = affine_map<(d0, d1) -> (d0)> |
202 | /// #map2 = affine_map<(d0, d1) -> (d1)> |
203 | /// %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0], |
204 | /// iterator_types = ["parallel", "parallel"]} |
205 | /// ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>) |
206 | /// outs(%init : tensor<?x?xf32>) { |
207 | /// ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): |
208 | /// %4 = arith.addf %arg3, %arg4 : f32 |
209 | /// linalg.yield %4 : f32 |
210 | /// } -> tensor<?x?xf32> |
211 | /// %1 = tensor.pack %0 |
212 | /// inner_dims_pos = [0, 1] |
213 | /// inner_tiles = [8, 2] |
214 | /// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32> |
215 | /// |
/// Taking the first input operand as an example, the inner tile size of d0 is
/// 8. Thus, the below operation and `affine_map<(d0, d1, d2, d3) -> (d0, d2)>`
/// will be returned.
219 | /// |
220 | /// %pack = tensor.pack %arg0 |
221 | /// inner_dims_pos = [0] |
222 | /// inner_tiles = [8] |
223 | /// into %init : tensor<?xf32> -> tensor<?x8xf32> |
224 | static std::tuple<Value, AffineMap> |
225 | getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo, |
226 | GenericOp genericOp, OpOperand *opOperand) { |
227 | int64_t numOrigLoops = genericOp.getNumLoops(); |
228 | int64_t numInnerLoops = packInfo.getNumTiledLoops(); |
229 | int64_t numLoops = numOrigLoops + numInnerLoops; |
230 | AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand); |
231 | llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim; |
232 | SmallVector<AffineExpr> exprs(origIndexingMap.getResults()); |
233 | |
234 | // If the OpOperand is a scalar or a zero-rank tensor, no need to pack. |
235 | if (genericOp.isScalar(opOperand) || exprs.empty()) |
236 | return std::make_tuple(opOperand->get(), |
237 | AffineMap::get(numLoops, 0, exprs, b.getContext())); |
238 | |
239 | // Step 1. Construct the information of packing data dimensions; append inner |
240 | // dimensions to the indexing maps for the operand. |
241 | for (auto [index, expr] : llvm::enumerate(exprs)) { |
242 | if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) { |
243 | int64_t dimPos = dimExpr.getPosition(); |
244 | domainDimToOperandDim[dimPos] = index; |
245 | continue; |
246 | } |
247 | } |
248 | SmallVector<int64_t> innerDimsPos; |
249 | SmallVector<OpFoldResult> innerTileSizes; |
250 | for (auto dimPos : packInfo.tiledDimsPos) { |
251 | if (!domainDimToOperandDim.count(dimPos)) |
252 | continue; |
253 | int64_t index = domainDimToOperandDim[dimPos]; |
254 | innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]); |
255 | innerDimsPos.push_back(index); |
256 | exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos])); |
257 | } |
258 | |
259 | // Step 2. Handle outer dim permutations. |
260 | SmallVector<int64_t> outerDimsPerm; |
261 | if (!packInfo.outerDimsOnDomainPerm.empty()) { |
262 | outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs); |
263 | |
264 | // Step 2.1: Fold transpose into the linalg.generic. |
265 | SmallVector<int64_t> inversedOuterPerm = |
266 | invertPermutationVector(packInfo.outerDimsOnDomainPerm); |
267 | for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) { |
268 | if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) { |
269 | int64_t dimPos = dimExpr.getPosition(); |
270 | exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]); |
271 | continue; |
272 | } |
      assert(isa<AffineConstantExpr>(exprs[i]) &&
             "Attempted to permute non-constant and non-affine dim expression");
275 | } |
276 | // Step 2.2: Undo the transposition on `exprs` and propagate the |
277 | // transposition on the pack using outerDimsPerm. |
278 | if (!outerDimsPerm.empty()) { |
279 | SmallVector<AffineExpr> auxVec = exprs; |
280 | for (const auto &en : enumerate(outerDimsPerm)) |
281 | auxVec[en.index()] = exprs[en.value()]; |
282 | exprs = auxVec; |
283 | } |
284 | } |
285 | auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext()); |
286 | |
  // The operand does not have any dimension that relates to the pack op.
288 | if (innerDimsPos.empty() && outerDimsPerm.empty()) |
289 | return std::make_tuple(opOperand->get(), indexingMap); |
290 | |
291 | auto empty = tensor::PackOp::createDestinationTensor( |
292 | b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm); |
293 | auto packedOperand = b.create<tensor::PackOp>( |
294 | loc, opOperand->get(), empty, innerDimsPos, innerTileSizes, |
295 | /*padding=*/std::nullopt, outerDimsPerm); |
296 | return std::make_tuple(packedOperand, indexingMap); |
297 | } |
298 | |
/// Pack a genericOp and return it. Packed views are created for all input
/// operands, `parallel` iterator types are appended for the point loops, and
/// the original body is cloned into the new op.
300 | static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp, |
301 | Value dest, AffineMap packedOutIndexingMap, |
302 | const PackInfo &packInfo) { |
303 | Location loc = genericOp.getLoc(); |
304 | SmallVector<Value> inputOperands; |
305 | SmallVector<AffineMap> indexingMaps; |
306 | for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) { |
307 | auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand( |
308 | rewriter, loc, packInfo, genericOp, inputOperand); |
309 | inputOperands.push_back(packedOperand); |
310 | indexingMaps.push_back(packedIndexingMap); |
311 | } |
312 | |
313 | int64_t numInnerLoops = packInfo.getNumTiledLoops(); |
314 | SmallVector<utils::IteratorType> iterTypes = |
315 | genericOp.getIteratorTypesArray(); |
316 | iterTypes.append(numInnerLoops, utils::IteratorType::parallel); |
317 | |
318 | indexingMaps.push_back(packedOutIndexingMap); |
319 | |
320 | auto newGenericOp = rewriter.create<linalg::GenericOp>( |
321 | loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes, |
322 | /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp)); |
323 | rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(), |
324 | newGenericOp.getRegion().begin()); |
325 | return newGenericOp; |
326 | } |
327 | |
/// Bubbles up a tensor.pack op through a producer generic op. This swaps
/// pack(generic) into generic(pack). The new generic op works on the packed
/// domain; pack ops are created for input and output operands. E.g.,
331 | /// |
332 | /// #map0 = affine_map<(d0, d1) -> (d0, d1)> |
333 | /// %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32> |
334 | /// %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32> |
335 | /// %2 = tensor.empty(%0, %1) : tensor<?x?xf32> |
336 | /// %3 = linalg.generic {indexing_maps = [#map0, #map0], |
337 | /// iterator_types = ["parallel", "parallel"]} |
338 | /// ins(%arg0 : tensor<?x?xf32>) |
339 | /// outs(%2 : tensor<?x?xf32>) { |
340 | /// ^bb0(%arg3: f32, %arg4: f32): |
341 | /// %4 = arith.addf %arg3, %arg3 : f32 |
342 | /// linalg.yield %4 : f32 |
343 | /// } -> tensor<?x?xf32> |
344 | /// %4 = tensor.pack %3 |
345 | /// inner_dims_pos = [0, 1] |
346 | /// inner_tiles = [8, 2] |
347 | /// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32> |
348 | /// |
349 | /// will be converted to |
350 | /// |
351 | /// #map = affine_map<()[s0] -> (s0 ceildiv 8)> |
352 | /// #map1 = affine_map<()[s0] -> (s0 ceildiv 2)> |
353 | /// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> |
354 | /// %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32> |
355 | /// %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32> |
356 | /// %0 = affine.apply #map()[%dim] |
357 | /// %1 = affine.apply #map1()[%dim_0] |
358 | /// %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32> |
359 | /// %pack = tensor.pack %arg0 |
360 | /// inner_dims_pos = [0, 1] |
361 | /// inner_tiles = [8, 2] |
362 | /// into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32> |
363 | /// %3 = linalg.generic {indexing_maps = [#map2, #map2], |
364 | /// iterator_types = ["parallel", "parallel", "parallel", "parallel"]} |
365 | /// ins(%pack : tensor<?x?x8x2xf32>) |
366 | /// outs(%arg1 : tensor<?x?x8x2xf32>) { |
367 | /// ^bb0(%in: f32, %out: f32): |
368 | /// %4 = arith.addf %in, %in : f32 |
369 | /// linalg.yield %4 : f32 |
370 | /// } -> tensor<?x?x8x2xf32> |
371 | static FailureOr<GenericOp> |
372 | bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp, |
373 | const ControlPropagationFn &controlFn) { |
374 | auto genericOp = packOp.getSource().getDefiningOp<GenericOp>(); |
375 | if (!genericOp) |
376 | return failure(); |
377 | |
378 | // User controlled propagation function. |
379 | if (!controlFn(genericOp)) |
380 | return failure(); |
381 | |
382 | // TODO: Enable propagation in the presence of linalg.index and |
383 | // tensor.extract, likely as a separate pattern as the pack information and |
384 | // propagation decision needs to be inferred from the region of the generic. |
385 | if (hasGatherSemantics(genericOp)) |
386 | return failure(); |
387 | |
388 | // TODO: Relax the restriction. We are able to bubble up the pack op through |
389 | // multi-result generic op. It just needs more work. |
390 | if (genericOp.getNumResults() != 1) |
391 | return failure(); |
392 | |
393 | // Bail-out if the result of the generic has multiple uses, as bubbling up |
394 | // creates recomputation if the generic has multiple users. |
395 | // TODO: Enable the case where every use is an identical pack op as no |
396 | // recomputation is needed in that case. |
397 | if (!genericOp->getResult(0).hasOneUse()) |
398 | return failure(); |
399 | |
400 | // We want to move the pack not the generic. |
401 | OpBuilder::InsertionGuard guard(rewriter); |
402 | rewriter.setInsertionPoint(genericOp); |
403 | |
404 | // We need to handle two cases: |
405 | // 1) The tensor.pack destination is a tensor.empty. If this is the case, we |
406 | // create a new tensor.empty to avoid breaking dominance, as we are moving the |
407 | // tensor.pack above the linalg.generic. |
408 | // 2) The destination is not a tensor.empty. In this case we can replace only |
409 | // if the destination of the tensor.pack dominates the linalg.generic. |
410 | Value packOpDest = packOp.getDest(); |
411 | if (!packOpDest.hasOneUse()) |
412 | return failure(); |
413 | if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) { |
414 | packOpDest = rewriter.create<tensor::EmptyOp>( |
415 | genericOp->getLoc(), emptyOp.getMixedSizes(), |
416 | emptyOp.getType().getElementType()); |
417 | } else { |
418 | DominanceInfo dom(genericOp); |
419 | if (!dom.properlyDominates(packOpDest, genericOp)) |
420 | return failure(); |
421 | } |
422 | |
423 | // TODO: Add an option for allowing padding values. It could introduce |
424 | // undefined behavior if we unconditionally propagate pack op through all |
425 | // the ops. E.g., if the padding value is zero and there are division ops in |
426 | // a generic op. Some values of padding area could be NaN (0/0). |
427 | if (packOp.getPaddingValue()) |
428 | return failure(); |
429 | |
430 | OpOperand *opOperand = genericOp.getDpsInitOperand(0); |
431 | auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp); |
432 | if (failed(packInfo)) |
433 | return failure(); |
434 | |
435 | // Rebuild the indexing map for the corresponding init operand. |
436 | auto [packedOutOperand, packedOutIndexingMap] = |
437 | getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo, |
438 | genericOp, opOperand); |
439 | |
  // If the dps init operand of the generic is a tensor.empty, forward the pack
  // op destination.
442 | Value dest = packedOutOperand; |
443 | if (auto initTensor = genericOp.getDpsInitOperand(0) |
444 | ->get() |
445 | .getDefiningOp<tensor::EmptyOp>()) { |
446 | dest = packOpDest; |
447 | } |
448 | return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, |
449 | *packInfo); |
450 | } |
451 | |
452 | /// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method. |
453 | struct BubbleUpPackOpThroughGenericOpPattern |
454 | : public OpRewritePattern<tensor::PackOp> { |
455 | public: |
456 | BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context, |
457 | ControlPropagationFn fun) |
458 | : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {} |
459 | |
460 | LogicalResult matchAndRewrite(tensor::PackOp packOp, |
461 | PatternRewriter &rewriter) const override { |
462 | auto genericOp = |
463 | bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn); |
464 | if (failed(genericOp)) |
465 | return failure(); |
466 | rewriter.replaceOp(packOp, genericOp->getResults()); |
467 | return success(); |
468 | } |
469 | |
470 | private: |
471 | ControlPropagationFn controlFn; |
472 | }; |
473 | |
/// Propagate a tensor.pack operation up through a tensor.pad. The idea is to
/// append zero-padding entries to `high` and `low`, one for each point loop
/// introduced by the pack.
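///
/// A sketch of the rewrite, with assumed shapes and names (requires the pack
/// destination to be a tensor.empty and the padded dim to not be tiled):
///
/// %padded = tensor.pad %src low[0, 0] high[0, 2] {
///   ^bb0(%i: index, %j: index):
///     tensor.yield %cst : f32
/// } : tensor<64x126xf32> to tensor<64x128xf32>
/// %pack = tensor.pack %padded
///   inner_dims_pos = [0]
///   inner_tiles = [8]
///   into %dest : tensor<64x128xf32> -> tensor<8x128x8xf32>
///
/// becomes
///
/// %pack = tensor.pack %src
///   inner_dims_pos = [0]
///   inner_tiles = [8]
///   into %empty : tensor<64x126xf32> -> tensor<8x126x8xf32>
/// %padded = tensor.pad %pack low[0, 0, 0] high[0, 2, 0] {
///   ^bb0(%i: index, %j: index, %k: index):
///     tensor.yield %cst : f32
/// } : tensor<8x126x8xf32> to tensor<8x128x8xf32>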
477 | class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> { |
478 | public: |
479 | BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun) |
480 | : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {} |
481 | |
482 | LogicalResult matchAndRewrite(tensor::PackOp packOp, |
483 | PatternRewriter &rewriter) const override { |
484 | auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>(); |
485 | if (!padOp) |
486 | return failure(); |
487 | |
488 | // User controlled propagation function. |
489 | if (!controlFn(padOp)) |
490 | return failure(); |
491 | |
492 | if (!padOp.getResult().hasOneUse()) |
493 | return failure(); |
494 | |
495 | // TODO: Enable padding when the padding values are the same. |
496 | if (packOp.getPaddingValue()) |
497 | return failure(); |
498 | |
499 | // Fail for non-constant padding values. The body of the pad could |
500 | // depend on the padding indices and/or properties of the padded |
501 | // tensor so for now we fail. |
502 | // TODO: Support non-constant padding values. |
503 | Value paddingVal = padOp.getConstantPaddingValue(); |
504 | if (!paddingVal) |
505 | return failure(); |
506 | |
507 | if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>()) |
508 | return failure(); |
509 | |
510 | ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos(); |
511 | ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm(); |
512 | |
    // Bail out if one of the padded dimensions is a tiled one.
514 | llvm::SmallBitVector paddedDims = padOp.getPaddedDims(); |
515 | llvm::SmallBitVector innerDims(paddedDims.size()); |
516 | for (int64_t dim : innerDimsPos) |
517 | innerDims.flip(dim); |
    if (paddedDims.anyCommon(innerDims))
519 | return failure(); |
520 | |
521 | Location loc = padOp->getLoc(); |
522 | OpBuilder::InsertionGuard guard(rewriter); |
523 | rewriter.setInsertionPoint(padOp); |
524 | |
525 | auto empty = tensor::PackOp::createDestinationTensor( |
526 | rewriter, loc, padOp.getSource(), packOp.getMixedTiles(), innerDimsPos, |
527 | outerDimsPerm); |
528 | Value packedSource = rewriter.create<tensor::PackOp>( |
529 | loc, padOp.getSource(), empty, innerDimsPos, packOp.getMixedTiles(), |
530 | /*padding=*/std::nullopt, outerDimsPerm); |
531 | |
    // If we have an `outer_dims_perm`, we need to adjust the padded dims.
533 | SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad(); |
534 | SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad(); |
535 | if (!outerDimsPerm.empty()) { |
536 | applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm); |
537 | applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm); |
538 | } |
539 | // The tiled dimensions were verified to be unpadded above, so here we |
540 | // just append 0 for the inner tile dimensions. |
541 | size_t pointLoopsSize = innerDimsPos.size(); |
542 | lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); |
543 | highPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); |
544 | |
545 | auto newPadOp = rewriter.create<tensor::PadOp>( |
546 | loc, /*result=*/Type(), packedSource, lowPad, highPad, paddingVal, |
547 | padOp.getNofold()); |
548 | rewriter.replaceOp(packOp, newPadOp.getResult()); |
549 | return success(); |
550 | } |
551 | |
552 | private: |
553 | ControlPropagationFn controlFn; |
554 | }; |
555 | |
556 | /// Project dimsPos to the inner-most non-unit dim pos with reassocIndices. |
557 | /// |
/// For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and
/// targetShape [16, 16, 32, 1], it returns [1, 2]: for pos 0, the inner-most
/// non-unit dim among positions [0, 1] is 1, and for pos 2, the inner-most
/// non-unit dim among positions [2, 3] is 2.
562 | /// |
563 | /// If all candidates in a reassociation are unit dims, it chooses the |
564 | /// inner-most dim pos. |
565 | static SmallVector<int64_t> |
566 | projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos, |
567 | ArrayRef<ReassociationIndices> reassocIndices, |
568 | ArrayRef<int64_t> targetShape) { |
569 | SmallVector<int64_t> projectedDimsPos; |
570 | for (auto pos : dimsPos) { |
571 | // In the case all dims are unit, this will return the inner-most one. |
572 | int64_t projectedPos = reassocIndices[pos].back(); |
573 | for (auto i : llvm::reverse(reassocIndices[pos])) { |
574 | int64_t dim = targetShape[i]; |
575 | if (dim > 1 || ShapedType::isDynamic(dim)) { |
576 | projectedPos = i; |
577 | break; |
578 | } |
579 | } |
580 | projectedDimsPos.push_back(projectedPos); |
581 | } |
582 | return projectedDimsPos; |
583 | } |
584 | |
585 | /// Check if all dims in dimsPos are divisible by the corresponding tile sizes. |
586 | static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos, |
587 | ArrayRef<int64_t> shape, |
588 | ArrayRef<int64_t> tileSizes) { |
589 | for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) { |
590 | int64_t dim = shape[pos]; |
591 | if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0) |
592 | return false; |
593 | } |
594 | return true; |
595 | } |
596 | |
/// Permute the reassociation indices and reindex them in sequential order.
/// Returns the next available dim position after reindexing.
599 | /// |
600 | /// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it |
601 | /// applies the permutation to get [[2], [0, 1]] and reindexes the indices into |
602 | /// [[0], [1, 2]]. |
603 | static int64_t applyPermutationAndReindexReassoc( |
604 | SmallVector<ReassociationIndices> &reassocIndices, |
605 | ArrayRef<int64_t> permutation) { |
606 | applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation); |
607 | int64_t nextPos = 0; |
608 | for (ReassociationIndices &indices : reassocIndices) { |
609 | for (auto &index : indices) { |
610 | index = nextPos; |
611 | nextPos += 1; |
612 | } |
613 | } |
614 | return nextPos; |
615 | } |
616 | |
617 | /// Bubble up pack op through collapse shape op when the packed dims can be |
618 | /// projected to the dims before collapsing. This is possible when the inner |
619 | /// tile sizes can divide the projected dims. |
620 | /// |
621 | /// For example: |
622 | /// |
/// %collapsed = tensor.collapse_shape %in [[0, 1], [2]]
624 | /// : tensor<?x16x4xf32> into tensor<?x4xf32> |
625 | /// %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] |
626 | /// inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty |
627 | /// : tensor<?x4xf32> -> tensor<?x4x8x1xf32> |
628 | /// |
629 | /// can be transformed into: |
630 | /// |
/// %pack = tensor.pack %in outer_dims_perm = [0, 1, 2]
632 | /// inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty |
633 | /// : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32> |
/// %collapsed = tensor.collapse_shape %pack [[0, 1], [2], [3], [4]]
///     : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
636 | static LogicalResult |
637 | bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp, |
638 | tensor::PackOp packOp, |
639 | PatternRewriter &rewriter) { |
640 | SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles(); |
641 | ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos(); |
642 | ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm(); |
643 | |
644 | ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape(); |
645 | SmallVector<ReassociationIndices> reassocIndices = |
646 | collapseOp.getReassociationIndices(); |
  // Project inner tile pos to the dim pos before collapsing. For example, if
  // dims [x, y] are collapsed into [z], packing on dim z can be projected back
  // to packing on dim y.
650 | // |
651 | // Project to inner-most non-unit dims to increase the chance that they can be |
652 | // divided by the inner tile sizes. This is correct because for [..., x, 1], |
653 | // packing on dim 1 is equivalent to packing on dim x. |
654 | SmallVector<int64_t> projectedInnerDimsPos = |
655 | projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape); |
656 | |
657 | if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape, |
658 | innerTileSizes)) { |
659 | return failure(); |
660 | } |
661 | // Expand the outer dims permutation with the associated source dims for the |
662 | // new permutation after bubbling. This is because moving a collapsed dim is |
663 | // equivalent to moving the associated source dims together. |
664 | SmallVector<int64_t> newOuterDimsPerm; |
665 | for (auto outerPos : outerDimsPerm) { |
666 | newOuterDimsPerm.insert(newOuterDimsPerm.end(), |
667 | reassocIndices[outerPos].begin(), |
668 | reassocIndices[outerPos].end()); |
669 | } |
670 | |
671 | auto emptyOp = tensor::PackOp::createDestinationTensor( |
672 | rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(), |
673 | projectedInnerDimsPos, newOuterDimsPerm); |
674 | auto newPackOp = rewriter.create<tensor::PackOp>( |
675 | packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos, |
676 | packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm); |
677 | |
678 | SmallVector<ReassociationIndices> newReassocIndices = reassocIndices; |
679 | // First apply the permutation on the reassociations of the outer dims. |
680 | // For example given the permutation [1, 0], the reassociations [[0, 1], [2]] |
681 | // -> [[0], [1, 2]] |
682 | int64_t nextPos = |
683 | applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm); |
684 | // Then add direct mapping for the inner tile dims. |
685 | for (size_t i = 0; i < innerDimsPos.size(); ++i) { |
686 | newReassocIndices.push_back({nextPos}); |
687 | nextPos += 1; |
688 | } |
689 | |
690 | auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>( |
691 | collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices); |
692 | rewriter.replaceOp(packOp, newCollapseOp); |
693 | |
694 | return success(); |
695 | } |
696 | |
697 | class BubbleUpPackOpThroughReshapeOp final |
698 | : public OpRewritePattern<tensor::PackOp> { |
699 | public: |
700 | BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun) |
701 | : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {} |
702 | |
703 | LogicalResult matchAndRewrite(tensor::PackOp packOp, |
704 | PatternRewriter &rewriter) const override { |
705 | Operation *srcOp = packOp.getSource().getDefiningOp(); |
706 | // Currently only support when the pack op is the only user. |
    if (!srcOp || srcOp->getNumResults() != 1 ||
        !srcOp->getResult(0).hasOneUse()) {
709 | return failure(); |
710 | } |
711 | // Currently only support static inner tile sizes. |
712 | if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) { |
713 | return ShapedType::isDynamic(size); |
714 | })) { |
715 | return failure(); |
716 | } |
717 | |
718 | // User controlled propagation function. |
719 | if (!controlFn(srcOp)) |
720 | return failure(); |
721 | |
722 | return TypeSwitch<Operation *, LogicalResult>(srcOp) |
723 | .Case([&](tensor::CollapseShapeOp op) { |
724 | return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter); |
725 | }) |
726 | .Default([](Operation *) { return failure(); }); |
727 | } |
728 | |
729 | private: |
730 | ControlPropagationFn controlFn; |
731 | }; |
732 | |
733 | /// Push down unpack op through expand shape op when the packed dims can be |
734 | /// projected to the dims after expanding. This is possible when the inner tile |
735 | /// sizes can divide the projected dims. |
736 | /// |
737 | /// For example: |
738 | /// |
739 | /// %unpack = tensor.unpack %in outer_dims_perm = [0, 1] |
740 | /// inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty |
741 | /// : tensor<?x32x8x8xf32> -> tensor<?x256xf32> |
742 | /// %expanded = tensor.expand_shape %unpack [[0, 1], [2]] |
743 | /// : tensor<?x256xf32> into tensor<?x256x256xf32> |
744 | /// |
745 | /// can be transformed into: |
746 | /// |
/// %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
748 | /// : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32> |
749 | /// %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2] |
750 | /// inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty |
751 | /// : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32> |
752 | static LogicalResult |
753 | pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp, |
754 | tensor::ExpandShapeOp expandOp, |
755 | PatternRewriter &rewriter) { |
756 | SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles(); |
757 | ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos(); |
758 | ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm(); |
759 | |
760 | ArrayRef<int64_t> dstShape = expandOp.getType().getShape(); |
761 | SmallVector<ReassociationIndices> reassocIndices = |
762 | expandOp.getReassociationIndices(); |
  // Project inner tile pos to the dim pos after expanding. For example, if dim
  // z is expanded into dims [x, y], unpacking on dim z can be projected to
  // unpacking on dim y.
766 | // |
767 | // Project to inner-most non-unit dims to increase the chance that they can be |
768 | // divided by the inner tile sizes. This is correct because for [..., x, 1], |
769 | // unpacking on dim 1 is equivalent to unpacking on dim x. |
770 | SmallVector<int64_t> projectedInnerDimsPos = |
771 | projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape); |
772 | |
773 | if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape, |
774 | innerTileSizes)) { |
775 | return failure(); |
776 | } |
777 | // Expand the outer dims permutation with the associated expanded dims for the |
778 | // new permutation after pushing. This is because moving a source dim is |
779 | // equivalent to moving the associated expanded dims together. |
780 | SmallVector<int64_t> newOuterDimsPerm; |
781 | for (auto outerPos : outerDimsPerm) { |
782 | newOuterDimsPerm.insert(newOuterDimsPerm.end(), |
783 | reassocIndices[outerPos].begin(), |
784 | reassocIndices[outerPos].end()); |
785 | } |
786 | |
787 | SmallVector<ReassociationIndices> newReassocIndices = reassocIndices; |
788 | // First apply the permutation on the reassociations of the outer dims. |
789 | // For example given the permutation [1, 0], the reassociations [[0, 1], [2]] |
790 | // -> [[0], [1, 2]] |
791 | int64_t nextPos = |
792 | applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm); |
793 | // Then add direct mapping for the inner tile dims. |
794 | for (size_t i = 0; i < innerDimsPos.size(); ++i) { |
795 | newReassocIndices.push_back({nextPos}); |
796 | nextPos += 1; |
797 | } |
798 | |
799 | RankedTensorType newExpandType = |
800 | tensor::PackOp::inferPackedType(expandOp.getType(), innerTileSizes, |
801 | projectedInnerDimsPos, newOuterDimsPerm); |
802 | auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>( |
803 | expandOp.getLoc(), newExpandType, unPackOp.getSource(), |
804 | newReassocIndices); |
805 | |
806 | auto emptyOp = tensor::UnPackOp::createDestinationTensor( |
807 | rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(), |
808 | projectedInnerDimsPos, newOuterDimsPerm); |
809 | auto newUnPackOp = rewriter.create<tensor::UnPackOp>( |
810 | unPackOp.getLoc(), newExpandOp.getResult(), emptyOp, |
811 | projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm); |
812 | rewriter.replaceOp(expandOp, newUnPackOp); |
813 | |
814 | return success(); |
815 | } |
816 | |
817 | class PushDownUnPackOpThroughReshapeOp final |
818 | : public OpRewritePattern<tensor::UnPackOp> { |
819 | public: |
820 | PushDownUnPackOpThroughReshapeOp(MLIRContext *context, |
821 | ControlPropagationFn fun) |
822 | : OpRewritePattern<tensor::UnPackOp>(context), controlFn(std::move(fun)) { |
823 | } |
824 | |
825 | LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp, |
826 | PatternRewriter &rewriter) const override { |
827 | Value result = unPackOp.getResult(); |
    // Currently only support unpack ops with a single user.
829 | if (!result.hasOneUse()) { |
830 | return failure(); |
831 | } |
832 | // Currently only support static inner tile sizes. |
833 | if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) { |
834 | return ShapedType::isDynamic(size); |
835 | })) { |
836 | return failure(); |
837 | } |
838 | |
839 | Operation *consumerOp = *result.user_begin(); |
840 | // User controlled propagation function. |
841 | if (!controlFn(consumerOp)) |
842 | return failure(); |
843 | |
844 | return TypeSwitch<Operation *, LogicalResult>(consumerOp) |
845 | .Case([&](tensor::ExpandShapeOp op) { |
846 | return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter); |
847 | }) |
848 | .Default([](Operation *) { return failure(); }); |
849 | } |
850 | |
851 | private: |
852 | ControlPropagationFn controlFn; |
853 | }; |
854 | |
855 | // TODO: Relax this restriction. We should unpack a generic op also |
856 | // in the presence of multiple unpack ops as producers. |
857 | /// Return the unpacked operand, if present, for the current generic op. |
858 | static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) { |
859 | OpOperand *unPackedOperand = nullptr; |
860 | for (OpOperand &operand : genericOp->getOpOperands()) { |
861 | auto unPackOp = operand.get().getDefiningOp<tensor::UnPackOp>(); |
862 | if (!unPackOp) |
863 | continue; |
864 | if (unPackedOperand) |
865 | return failure(); |
866 | unPackedOperand = &operand; |
867 | } |
868 | if (!unPackedOperand) |
869 | return failure(); |
870 | return unPackedOperand; |
871 | } |
872 | |
873 | /// Push down a tensor.unpack op through a generic op. |
874 | /// The new generic op works on packed domain; pack ops are created for input |
875 | /// and output operands. A tensor.unpack op is inserted right after the packed |
876 | /// generic. E.g. |
877 | /// |
878 | /// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> |
879 | /// |
880 | /// %arg0 = tensor<12x2x56x56x32xf32> // packed arg. |
881 | /// |
882 | /// %0 = tensor.empty() : tensor<12x56x56x64xf32> |
883 | /// %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] |
884 | /// inner_dims_pos = [3] inner_tiles = [32] into %0 |
885 | /// %2 = linalg.generic {indexing_maps = [#map], |
886 | /// iterator_types = ["parallel", "parallel", "parallel", "parallel"]} |
887 | /// outs(%1 : tensor<12x56x56x64xf32>) { |
888 | /// ^bb0(%out : f32): |
889 | /// linalg.yield %out : f32 |
890 | /// } -> tensor<12x56x56x64xf32> |
891 | /// |
892 | /// will be converted to |
893 | /// |
894 | /// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> |
895 | /// |
896 | /// %0 = tensor.empty() : tensor<12x56x56x64xf32> |
897 | /// %1 = linalg.generic {indexing_maps = [#map], |
898 | /// iterator_types = ["parallel", "parallel", "parallel", |
899 | /// "parallel", "parallel"]} |
900 | /// outs(%arg0 : tensor<12x2x56x56x32xf32>) { |
901 | /// ^bb0(%out : f32): |
902 | /// linalg.yield %out : f32 |
903 | /// } -> tensor<12x2x56x56x32xf32> |
904 | /// %2 = tensor.unpack %1 outer_dims_perm = [0, 3, 1, 2] |
905 | /// inner_dims_pos = [3] inner_tiles = [32] into %0 |
906 | /// |
907 | static FailureOr<std::tuple<GenericOp, Value>> |
908 | pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) { |
909 | if (genericOp.getNumResults() != 1) |
910 | return failure(); |
911 | |
912 | if (hasGatherSemantics(genericOp)) |
913 | return failure(); |
914 | |
915 | // Collect the unPacked operand, if present. |
916 | auto maybeUnPackedOperand = getUnPackedOperand(genericOp); |
917 | if (failed(maybeUnPackedOperand)) |
918 | return failure(); |
919 | OpOperand *unPackedOperand = *(maybeUnPackedOperand); |
920 | |
921 | // Extract packing information. |
922 | tensor::UnPackOp producerUnPackOp = |
923 | unPackedOperand->get().getDefiningOp<tensor::UnPackOp>(); |
  assert(producerUnPackOp && "expect a valid UnPackOp");
925 | auto packInfo = |
926 | getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp); |
927 | if (failed(packInfo)) |
928 | return failure(); |
929 | |
930 | // Rebuild the indexing map for the corresponding init operand. |
931 | auto [packedOutOperand, packedOutIndexingMap] = |
932 | getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo, |
933 | genericOp, genericOp.getDpsInitOperand(0)); |
934 | auto destPack = packedOutOperand.getDefiningOp<tensor::PackOp>(); |
935 | |
936 | // If the dps init operand of the generic is a tensor.empty, do not pack it |
937 | // and forward the new tensor.empty as a destination. |
938 | Value dest = packedOutOperand; |
939 | if (auto initTensor = genericOp.getDpsInitOperand(0) |
940 | ->get() |
941 | .getDefiningOp<tensor::EmptyOp>()) { |
942 | if (destPack) |
943 | dest = destPack.getDest(); |
944 | } |
945 | |
946 | // Pack the genericOp. |
947 | GenericOp newGenericOp = |
948 | packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo); |
949 | Value newResult = |
950 | newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)); |
951 | |
952 | // If the output is unaffected, no need to unpack. |
953 | if (!destPack) |
954 | return std::make_tuple(newGenericOp, newResult); |
955 | |
956 | auto mixedTiles = destPack.getMixedTiles(); |
957 | auto innerDimsPos = destPack.getInnerDimsPos(); |
958 | auto outerDimsPerm = destPack.getOuterDimsPerm(); |
959 | |
960 | // If the output type for the generic differs from the source |
961 | // unpack op, we need to create a new destination tensor. In the |
962 | // dynamic case we always need a new destination. |
963 | auto loc = genericOp.getLoc(); |
964 | Value unPackDest = producerUnPackOp.getDest(); |
965 | auto genericOutType = |
966 | cast<RankedTensorType>(genericOp.getDpsInitOperand(0)->get().getType()); |
967 | if (producerUnPackOp.getDestType() != genericOutType || |
968 | !genericOutType.hasStaticShape()) { |
969 | unPackDest = tensor::UnPackOp::createDestinationTensor( |
970 | rewriter, loc, newResult, mixedTiles, innerDimsPos, outerDimsPerm); |
971 | } |
972 | |
973 | // Insert an unPackOp right after the packed generic. |
974 | Value unPackOpRes = |
975 | rewriter |
976 | .create<tensor::UnPackOp>(loc, newResult, unPackDest, innerDimsPos, |
977 | mixedTiles, outerDimsPerm) |
978 | .getResult(); |
979 | |
980 | return std::make_tuple(newGenericOp, unPackOpRes); |
981 | } |
982 | |
983 | // Wrapper pattern that applies pushDownUnPackOpThroughGenericOp method. |
984 | struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> { |
985 | public: |
986 | PushDownUnPackOpThroughGenericOp(MLIRContext *context, |
987 | ControlPropagationFn fun) |
988 | : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {} |
989 | |
990 | LogicalResult matchAndRewrite(GenericOp genericOp, |
991 | PatternRewriter &rewriter) const override { |
992 | if (!controlFn(genericOp)) |
993 | return failure(); |
994 | |
995 | auto genericAndRepl = pushDownUnPackOpThroughGenericOp(rewriter, genericOp); |
996 | if (failed(genericAndRepl)) |
997 | return failure(); |
998 | rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl)); |
999 | return success(); |
1000 | } |
1001 | |
1002 | private: |
1003 | ControlPropagationFn controlFn; |
1004 | }; |
1005 | |
/// Propagate a tensor.unpack operation through a tensor.pad. The idea is to
/// append zero-padding entries to `high` and `low`, one for each point loop of
/// the unpack.
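///
/// A sketch of the rewrite, with assumed shapes and names (requires a constant
/// padding value and that the padded dim is not a tiled one):
///
/// %unpack = tensor.unpack %src
///   inner_dims_pos = [0]
///   inner_tiles = [8]
///   into %empty : tensor<8x126x8xf32> -> tensor<64x126xf32>
/// %padded = tensor.pad %unpack low[0, 0] high[0, 2] {
///   ^bb0(%i: index, %j: index):
///     tensor.yield %cst : f32
/// } : tensor<64x126xf32> to tensor<64x128xf32>
///
/// becomes
///
/// %padded = tensor.pad %src low[0, 0, 0] high[0, 2, 0] {
///   ^bb0(%i: index, %j: index, %k: index):
///     tensor.yield %cst : f32
/// } : tensor<8x126x8xf32> to tensor<8x128x8xf32>
/// %unpack = tensor.unpack %padded
///   inner_dims_pos = [0]
///   inner_tiles = [8]
///   into %out : tensor<8x128x8xf32> -> tensor<64x128xf32>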
1009 | struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> { |
1010 | PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun) |
1011 | : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {} |
1012 | |
1013 | LogicalResult matchAndRewrite(tensor::PadOp padOp, |
1014 | PatternRewriter &rewriter) const override { |
1015 | tensor::UnPackOp unpackOp = |
1016 | padOp.getSource().getDefiningOp<tensor::UnPackOp>(); |
1017 | if (!unpackOp) |
1018 | return failure(); |
1019 | |
1020 | if (!controlFn(padOp)) |
1021 | return failure(); |
1022 | |
1023 | Location loc = padOp.getLoc(); |
    // Bail out if one of the padded dimensions is a tiled one.
1025 | llvm::SmallBitVector paddedDims = padOp.getPaddedDims(); |
1026 | ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos(); |
1027 | llvm::SmallBitVector innerDims(paddedDims.size()); |
1028 | for (int64_t dim : innerDimsPos) |
1029 | innerDims.flip(dim); |
    if (paddedDims.anyCommon(innerDims))
1031 | return failure(); |
1032 | |
1033 | Value paddingVal = padOp.getConstantPaddingValue(); |
1034 | if (!paddingVal) |
1035 | return failure(); |
1036 | |
    // If we have an `outer_dims_perm`, we need to adjust the padded dims.
1038 | ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm(); |
1039 | SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad(); |
1040 | SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad(); |
1041 | if (!outerDimsPerm.empty()) { |
1042 | applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm); |
1043 | applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm); |
1044 | } |
1045 | // Add zero padding for the point loops. |
1046 | size_t pointLoopsSize = innerDimsPos.size(); |
1047 | lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); |
1048 | highPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); |
1049 | |
1050 | auto newPadOp = rewriter.create<tensor::PadOp>( |
1051 | loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad, |
1052 | paddingVal, padOp.getNofold()); |
1053 | |
1054 | // Inject the tensor.unpack right after the packed padOp. |
1055 | Value outputUnPack = rewriter.create<tensor::EmptyOp>( |
1056 | loc, padOp.getResultType().getShape(), |
1057 | padOp.getResultType().getElementType()); |
1058 | |
1059 | Value replacement = rewriter.create<tensor::UnPackOp>( |
1060 | loc, newPadOp.getResult(), outputUnPack, innerDimsPos, |
1061 | unpackOp.getMixedTiles(), outerDimsPerm); |
1062 | rewriter.replaceOp(padOp, replacement); |
1063 | return success(); |
1064 | } |
1065 | |
1066 | private: |
1067 | ControlPropagationFn controlFn; |
1068 | }; |
1069 | |
1070 | } // namespace |
1071 | |
1072 | void mlir::linalg::populateDataLayoutPropagationPatterns( |
1073 | RewritePatternSet &patterns, |
1074 | const ControlPropagationFn &controlPackUnPackPropagation) { |
1075 | patterns |
1076 | .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp, |
1077 | BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp, |
1078 | PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>( |
          patterns.getContext(), controlPackUnPackPropagation);
1080 | } |
1081 | |