1 | //===- DataLayoutPropagation.cpp -----------------------------------------===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Linalg/Passes.h" |
10 | |
11 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
12 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
13 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
14 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
15 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
16 | #include "mlir/Dialect/Tensor/Utils/Utils.h" |
17 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
18 | #include "mlir/IR/Dominance.h" |
19 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
20 | #include "llvm/ADT/SetOperations.h" |
21 | #include "llvm/ADT/SetVector.h" |
22 | #include "llvm/ADT/TypeSwitch.h" |
23 | #include "llvm/Support/Debug.h" |
24 | #include <optional> |
25 | |
26 | namespace mlir { |
27 | #define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION |
28 | #include "mlir/Dialect/Linalg/Passes.h.inc" |
29 | } // namespace mlir |
30 | |
31 | using namespace mlir; |
32 | using namespace mlir::linalg; |
33 | |
34 | #define DEBUG_TYPE "linalg-data-layout-propagation" |
35 | |
36 | namespace { |
37 | |
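// Returns true if the generic op's body gathers from a tensor or depends on
// the loop indices, e.g. (an illustrative sketch, not a specific test case):
//
//   ^bb0(%in: f32, %out: f32):
//     %i = linalg.index 0 : index
//     %v = tensor.extract %src[%i] : tensor<?xf32>
//     linalg.yield %v : f32
//
// Packing such ops would require reasoning about the region, so the
// propagation patterns below bail out on them.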
38 | static bool hasGatherSemantics(linalg::GenericOp genericOp) { |
39 | for (Operation &op : genericOp.getBody()->getOperations()) |
40 | if (isa<tensor::ExtractOp, linalg::IndexOp>(op)) |
41 | return true; |
42 | return false; |
43 | } |
44 | |
45 | // The struct contains the information for mapping a pack/unpack op's packing
46 | // metadata onto the iteration domain of Linalg ops.
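// For example (an illustrative sketch): given a generic with a 2-D iteration
// domain and a pack with inner_dims_pos = [1] and inner_tiles = [8] on an
// operand whose indexing map is (d0, d1) -> (d0, d1), the resulting PackInfo
// would be:
//   tiledDimsPos            = [1]
//   domainDimAndTileMapping = {1 -> 8}
//   tileToPointMapping      = {1 -> 2}  // d2 is the new point loop
//   outerDimsOnDomainPerm   = []        // the pack has no outer_dims_perm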
47 | struct PackInfo { |
48 | int64_t getNumTiledLoops() const { return tileToPointMapping.size(); }; |
49 | // InnerDimsPos on iteration domain, which follows the order in pack ops. |
50 | SmallVector<int64_t> tiledDimsPos; |
51 | // The sizes of tiling data dimensions on iteration domain. |
52 | llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping; |
53 | // The mapping from a dimension of iteration domain to the corresponding inner |
54 | // tiling dimension on iteration domain. |
55 | llvm::DenseMap<int64_t, int64_t> tileToPointMapping; |
56 | // The permutation of outer dims (on domain). |
57 | SmallVector<int64_t> outerDimsOnDomainPerm; |
58 | }; |
59 | |
60 | template <typename OpTy> |
61 | static FailureOr<PackInfo> |
62 | getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp, |
63 | OpTy packOrUnPackOp) { |
64 | static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value, |
65 | "applies to only pack or unpack operations"); |
66 | LLVM_DEBUG( |
67 | { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; }); |
68 | |
69 | AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); |
70 | SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray(); |
71 | SmallVector<utils::IteratorType> iterators = |
72 | genericOp.getIteratorTypesArray(); |
73 | |
74 | PackInfo packInfo; |
75 | int64_t origNumDims = indexingMap.getNumDims(); |
76 | SmallVector<AffineExpr> exprs(indexingMap.getResults()); |
77 | ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos(); |
78 | for (auto [index, innerDimPos, tileSize] : |
79 | llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()), |
80 | innerDimsPos, packOrUnPackOp.getMixedTiles())) { |
81 | auto expr = exprs[innerDimPos]; |
82 | if (!isa<AffineDimExpr>(expr)) |
83 | return failure(); |
84 | int64_t domainDimPos = |
85 | cast<AffineDimExpr>(exprs[innerDimPos]).getPosition(); |
86 | if (!isParallelIterator(iterators[domainDimPos])) |
87 | return failure(); |
88 | packInfo.tiledDimsPos.push_back(domainDimPos); |
89 | packInfo.domainDimAndTileMapping[domainDimPos] = tileSize; |
90 | packInfo.tileToPointMapping[domainDimPos] = origNumDims + index; |
91 | LLVM_DEBUG({ |
92 | llvm::dbgs() << "map innerDimPos=" << innerDimPos
93 | << " to iteration dimension (d" << domainDimPos << ", d"
94 | << packInfo.tileToPointMapping[domainDimPos] |
95 | << "), which has size=(" |
96 | << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n"; |
97 | }); |
98 | } |
99 | |
100 | // Bail out if a tiled dimension is present in a map but not as an affine dim |
101 | // expression. |
102 | auto areAllAffineDimExpr = [&](int dim) { |
103 | for (AffineMap map : indexingMaps) { |
104 | if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) { |
105 | return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr); |
106 | })) { |
107 | return false; |
108 | } |
109 | } |
110 | return true; |
111 | }; |
112 | for (int64_t i : packInfo.tiledDimsPos) |
113 | if (!areAllAffineDimExpr(i)) |
114 | return failure(); |
115 | |
116 | // Get the outer dims perm on the iteration domain. Start by identifying the |
117 | // set of domain dims affected by the outer permutation along with the |
118 | // permuted ordering for those dims. Then the full outer dims permutation can |
119 | // be constructed by replacing the affected dims with the permuted result in a |
120 | // numLoops-rank identity. e.g. |
121 | // outerDimsPerm = [1, 2, 0] |
122 | // indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3) |
123 | // |
124 | // permutedOuterDims = [4, 3, 1] |
125 | // outerDimsOnDomainPerm = [0, 4, 2, 3, 1] |
126 | // |
127 | // Non-affine dim expressions must not be permuted by the outer dims |
128 | // permutation. |
129 | SmallVector<int64_t> permutedOuterDims; |
130 | for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) { |
131 | auto permutedExpr = indexingMap.getResult(dim);
132 | if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) { |
133 | permutedOuterDims.push_back(dimExpr.getPosition()); |
134 | continue; |
135 | } |
136 | |
137 | // TODO: Allow propagation with transposes on non affine dim expressions, |
138 | // e.g. d0 + d1 which implies transposing both dims simultaneously while |
139 | // maintaining the relative position between them. |
140 | if (static_cast<int64_t>(index) != dim) |
141 | return failure(); |
142 | } |
143 | if (!permutedOuterDims.empty()) { |
144 | int64_t outerDimIndex = 0; |
145 | llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(), |
146 | permutedOuterDims.end()); |
147 | for (int i = 0, e = indexingMap.getNumDims(); i < e; i++) |
148 | packInfo.outerDimsOnDomainPerm.push_back( |
149 | permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++] |
150 | : i); |
151 | LLVM_DEBUG({ |
152 | llvm::dbgs() << "map outerDimsPerm to ";
153 | for (auto dim : packInfo.outerDimsOnDomainPerm) |
154 | llvm::dbgs() << dim << " "; |
155 | llvm::dbgs() << "\n"; |
156 | }); |
157 | } |
158 | |
159 | return packInfo; |
160 | } |
161 | |
162 | static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm, |
163 | ArrayRef<AffineExpr> exprs) { |
164 | // Compute `outer_dims_perm`. See example: |
165 | // current exprs : (d0, d1, d2, d3) -> (d2, d3) |
166 | // perm : [0, 3, 1, 2] |
167 | // First map d2, d3 with their position in the array as: |
168 | // currentPositionTileLoops: dim | pos |
169 | // d2 | 0 |
170 | // d3 | 1 |
171 | // then scan `perm` in order and get the `outer_dims_perm` |
172 | // to be used, here it would be [1, 0]. |
173 | assert(!perm.empty() && "expect perm not to be empty"); |
174 | assert(!exprs.empty() && "expect exprs not to be empty"); |
175 | if (exprs.size() == 1) |
176 | return {}; |
177 | SmallVector<int64_t> outerDimsPerm; |
178 | DenseMap<int64_t, int64_t> currentPositionTileLoops; |
179 | for (auto [pos, expr] : llvm::enumerate(exprs)) { |
180 | // Here we rely on the assumption that the outer dims permutation |
181 | // when propagating currently requires that non-affine dim expressions |
182 | // are not permuted, thus allowing the identity assignment below. |
183 | if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) |
184 | currentPositionTileLoops[dimExpr.getPosition()] = pos; |
185 | else |
186 | currentPositionTileLoops[pos] = pos; |
187 | } |
188 | for (int64_t loopIdx : perm) { |
189 | if (currentPositionTileLoops.count(loopIdx)) |
190 | outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx)); |
191 | } |
192 | return outerDimsPerm; |
193 | } |
194 | |
195 | /// Returns a tuple for packed operand and indexing_map with the assumptions: |
196 | /// 1) The generic op is the producer of the pack op. |
197 | /// 2) The generic op has only one result. |
198 | /// If the operand is a scalar or packing dimensions are all irrelevant to the |
199 | /// operand, the operand and the updated indexing map will be returned. |
200 | /// Otherwise, it returns the packed operand and the updated indexing map. E.g., |
201 | /// |
202 | /// #map0 = affine_map<(d0, d1) -> (d0, d1)> |
203 | /// #map1 = affine_map<(d0, d1) -> (d0)> |
204 | /// #map2 = affine_map<(d0, d1) -> (d1)> |
205 | /// %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0], |
206 | /// iterator_types = ["parallel", "parallel"]} |
207 | /// ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>) |
208 | /// outs(%init : tensor<?x?xf32>) { |
209 | /// ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): |
210 | /// %4 = arith.addf %arg3, %arg4 : f32 |
211 | /// linalg.yield %4 : f32 |
212 | /// } -> tensor<?x?xf32> |
213 | /// %1 = linalg.pack %0 |
214 | /// inner_dims_pos = [0, 1] |
215 | /// inner_tiles = [8, 2] |
216 | /// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32> |
217 | /// |
218 | /// Taking the first input operand as an example, it only uses d0, whose inner
219 | /// tile size is 8. Thus, the below operation and
220 | /// `affine_map<(d0, d1, d2, d3) -> (d0, d2)>` will be returned.
221 | /// |
222 | /// %pack = linalg.pack %arg0 |
223 | /// inner_dims_pos = [0] |
224 | /// inner_tiles = [8] |
225 | /// into %init : tensor<?xf32> -> tensor<?x8xf32> |
226 | static std::tuple<Value, AffineMap> |
227 | getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo, |
228 | GenericOp genericOp, OpOperand *opOperand) { |
229 | int64_t numOrigLoops = genericOp.getNumLoops(); |
230 | int64_t numInnerLoops = packInfo.getNumTiledLoops(); |
231 | int64_t numLoops = numOrigLoops + numInnerLoops; |
232 | AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand); |
233 | llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim; |
234 | SmallVector<AffineExpr> exprs(origIndexingMap.getResults()); |
235 | |
236 | // If the OpOperand is a scalar or a zero-rank tensor, no need to pack. |
237 | if (genericOp.isScalar(opOperand) || exprs.empty()) |
238 | return std::make_tuple(opOperand->get(), |
239 | AffineMap::get(numLoops, 0, exprs, b.getContext())); |
240 | |
241 | // Step 1. Construct the information of packing data dimensions; append inner |
242 | // dimensions to the indexing maps for the operand. |
243 | for (auto [index, expr] : llvm::enumerate(exprs)) { |
244 | if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) { |
245 | int64_t dimPos = dimExpr.getPosition(); |
246 | domainDimToOperandDim[dimPos] = index; |
247 | continue; |
248 | } |
249 | } |
250 | SmallVector<int64_t> innerDimsPos; |
251 | SmallVector<OpFoldResult> innerTileSizes; |
252 | for (auto dimPos : packInfo.tiledDimsPos) { |
253 | if (!domainDimToOperandDim.count(dimPos)) |
254 | continue; |
255 | int64_t index = domainDimToOperandDim[dimPos]; |
256 | innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]); |
257 | innerDimsPos.push_back(index); |
258 | exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos])); |
259 | } |
260 | |
261 | // Step 2. Handle outer dim permutations. |
262 | SmallVector<int64_t> outerDimsPerm; |
263 | if (!packInfo.outerDimsOnDomainPerm.empty()) { |
264 | outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs); |
265 | |
266 | // Step 2.1: Fold transpose into the linalg.generic. |
267 | SmallVector<int64_t> inversedOuterPerm = |
268 | invertPermutationVector(packInfo.outerDimsOnDomainPerm); |
269 | for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) { |
270 | if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) { |
271 | int64_t dimPos = dimExpr.getPosition(); |
272 | exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]); |
273 | continue; |
274 | } |
275 | assert(isa<AffineConstantExpr>(exprs[i]) && |
276 | "Attempted to permute non-constant and non-affine dim expression"); |
277 | } |
278 | // Step 2.2: Undo the transposition on `exprs` and propagate the |
279 | // transposition on the pack using outerDimsPerm. |
280 | if (!outerDimsPerm.empty()) { |
281 | SmallVector<AffineExpr> auxVec = exprs; |
282 | for (const auto &en : enumerate(outerDimsPerm)) |
283 | auxVec[en.index()] = exprs[en.value()]; |
284 | exprs = auxVec; |
285 | } |
286 | } |
287 | auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext()); |
288 | |
289 | // The operand does not have dimensions that relates to pack op. |
290 | if (innerDimsPos.empty() && outerDimsPerm.empty()) |
291 | return std::make_tuple(opOperand->get(), indexingMap); |
292 | |
293 | auto empty = linalg::PackOp::createDestinationTensor( |
294 | b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm); |
295 | auto packedOperand = b.create<linalg::PackOp>( |
296 | loc, opOperand->get(), empty, innerDimsPos, innerTileSizes, |
297 | /*padding=*/std::nullopt, outerDimsPerm); |
298 | return std::make_tuple(packedOperand, indexingMap); |
299 | } |
300 | |
301 | /// This function is a helper subroutine to pack a genericOp and return it. It |
302 | /// will create a new generic op with the packed operand and the packed output |
303 | /// according to packInfo when we attempt to push down unpack or bubble up pack |
304 | /// around it. Implicitly this will only work when a packInfo can be obtained. |
305 | /// This makes sure that we only use this function on parallel, permuted
306 | /// dimensions.
307 | static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp, |
308 | Value dest, AffineMap packedOutIndexingMap, |
309 | const PackInfo &packInfo, |
310 | bool isFoldableUnpackPack) { |
311 | Location loc = genericOp.getLoc(); |
312 | SmallVector<Value> inputOperands; |
313 | SmallVector<Value> inputOperandsFromUnpackedSource; |
314 | SmallVector<AffineMap> indexingMaps; |
315 | auto hasEquivalentTiles = [](PackOp packOp, UnPackOp unPackOp) { |
316 | return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm() && |
317 | packOp.getInnerDimsPos() == unPackOp.getInnerDimsPos() && |
318 | llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles()); |
319 | }; |
320 | for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) { |
321 | auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand( |
322 | rewriter, loc, packInfo, genericOp, inputOperand); |
323 | auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>(); |
324 | auto packOp = packedOperand.getDefiningOp<linalg::PackOp>(); |
325 | if (packOp && unpackOp && hasEquivalentTiles(packOp, unpackOp)) { |
326 | inputOperandsFromUnpackedSource.push_back(unpackOp.getSource()); |
327 | } else { |
328 | inputOperandsFromUnpackedSource.push_back(packedOperand); |
329 | } |
330 | inputOperands.push_back(packedOperand); |
331 | indexingMaps.push_back(packedIndexingMap); |
332 | } |
333 | |
334 | // If the unpack->pack sequences can be folded, use the sources of the unpack
335 | // ops for any unpack->pack chains on the generic op operands.
336 | if (isFoldableUnpackPack) { |
337 | inputOperands = inputOperandsFromUnpackedSource; |
338 | if (auto destPack = dest.getDefiningOp<linalg::PackOp>()) { |
339 | auto destUnPack = destPack.getSource().getDefiningOp<linalg::UnPackOp>(); |
340 | if (destUnPack && hasEquivalentTiles(destPack, destUnPack)) { |
341 | dest = destUnPack.getSource(); |
342 | } |
343 | } |
344 | } |
345 | |
346 | int64_t numInnerLoops = packInfo.getNumTiledLoops(); |
347 | SmallVector<utils::IteratorType> iterTypes = |
348 | genericOp.getIteratorTypesArray(); |
349 | iterTypes.append(numInnerLoops, utils::IteratorType::parallel); |
350 | |
351 | indexingMaps.push_back(packedOutIndexingMap); |
352 | |
353 | auto newGenericOp = rewriter.create<linalg::GenericOp>( |
354 | loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes, |
355 | /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp)); |
356 | rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(), |
357 | newGenericOp.getRegion().begin()); |
358 | return newGenericOp; |
359 | } |
360 | |
361 | /// Bubbles up a linalg.pack op through a producer generic op. This
362 | /// swaps pack(generic) to generic(pack). The new generic op works on the packed
363 | /// domain; pack ops are created for input and output operands. E.g., |
364 | /// |
365 | /// #map0 = affine_map<(d0, d1) -> (d0, d1)> |
366 | /// %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32> |
367 | /// %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32> |
368 | /// %2 = tensor.empty(%0, %1) : tensor<?x?xf32> |
369 | /// %3 = linalg.generic {indexing_maps = [#map0, #map0], |
370 | /// iterator_types = ["parallel", "parallel"]} |
371 | /// ins(%arg0 : tensor<?x?xf32>) |
372 | /// outs(%2 : tensor<?x?xf32>) { |
373 | /// ^bb0(%arg3: f32, %arg4: f32): |
374 | /// %4 = arith.addf %arg3, %arg3 : f32 |
375 | /// linalg.yield %4 : f32 |
376 | /// } -> tensor<?x?xf32> |
377 | /// %4 = linalg.pack %3 |
378 | /// inner_dims_pos = [0, 1] |
379 | /// inner_tiles = [8, 2] |
380 | /// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32> |
381 | /// |
382 | /// will be converted to |
383 | /// |
384 | /// #map = affine_map<()[s0] -> (s0 ceildiv 8)> |
385 | /// #map1 = affine_map<()[s0] -> (s0 ceildiv 2)> |
386 | /// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> |
387 | /// %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32> |
388 | /// %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32> |
389 | /// %0 = affine.apply #map()[%dim] |
390 | /// %1 = affine.apply #map1()[%dim_0] |
391 | /// %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32> |
392 | /// %pack = linalg.pack %arg0 |
393 | /// inner_dims_pos = [0, 1] |
394 | /// inner_tiles = [8, 2] |
395 | /// into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32> |
396 | /// %3 = linalg.generic {indexing_maps = [#map2, #map2], |
397 | /// iterator_types = ["parallel", "parallel", "parallel", "parallel"]} |
398 | /// ins(%pack : tensor<?x?x8x2xf32>) |
399 | /// outs(%arg1 : tensor<?x?x8x2xf32>) { |
400 | /// ^bb0(%in: f32, %out: f32): |
401 | /// %4 = arith.addf %in, %in : f32 |
402 | /// linalg.yield %4 : f32 |
403 | /// } -> tensor<?x?x8x2xf32> |
404 | static FailureOr<GenericOp> |
405 | bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp, |
406 | const ControlPropagationFn &controlFn) { |
407 | auto genericOp = packOp.getSource().getDefiningOp<GenericOp>(); |
408 | if (!genericOp) |
409 | return failure(); |
410 | |
411 | // User controlled propagation function. |
412 | if (!controlFn(&packOp.getSourceMutable())) |
413 | return failure(); |
414 | |
415 | // TODO: Enable propagation in the presence of linalg.index and |
416 | // tensor.extract, likely as a separate pattern as the pack information and |
417 | // propagation decision needs to be inferred from the region of the generic. |
418 | if (hasGatherSemantics(genericOp)) |
419 | return failure(); |
420 | |
421 | // TODO: Relax the restriction. We are able to bubble up the pack op through |
422 | // multi-result generic op. It just needs more work. |
423 | if (genericOp.getNumResults() != 1) |
424 | return failure(); |
425 | |
426 | // Bail-out if the result of the generic has multiple uses, as bubbling up |
427 | // creates recomputation if the generic has multiple users. |
428 | // TODO: Enable the case where every use is an identical pack op as no |
429 | // recomputation is needed in that case. |
430 | if (!genericOp->getResult(0).hasOneUse()) |
431 | return failure(); |
432 | |
433 | // TODO: Add an option for allowing padding values. It could introduce |
434 | // undefined behavior if we unconditionally propagate pack op through all |
435 | // the ops. E.g., if the padding value is zero and there are division ops in |
436 | // a generic op. Some values of padding area could be NaN (0/0). |
437 | if (packOp.getPaddingValue()) |
438 | return failure(); |
439 | |
440 | OpOperand *opOperand = genericOp.getDpsInitOperand(0); |
441 | auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp); |
442 | if (failed(packInfo)) |
443 | return failure(); |
444 | |
445 | // We want to move the pack not the generic. |
446 | OpBuilder::InsertionGuard guard(rewriter); |
447 | rewriter.setInsertionPoint(genericOp); |
448 | |
449 | // We need to handle two cases: |
450 | // 1) The linalg.pack destination is a tensor.empty. If this is the case, we |
451 | // create a new tensor.empty to avoid breaking dominance, as we are moving the |
452 | // linalg.pack above the linalg.generic. |
453 | // 2) The destination is not a tensor.empty. In this case we can replace only |
454 | // if the destination of the linalg.pack dominates the linalg.generic. |
455 | Value packOpDest = packOp.getDest(); |
456 | if (!packOpDest.hasOneUse()) |
457 | return failure(); |
458 | if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) { |
459 | packOpDest = rewriter.create<tensor::EmptyOp>( |
460 | genericOp->getLoc(), emptyOp.getMixedSizes(), |
461 | emptyOp.getType().getElementType()); |
462 | } else { |
463 | DominanceInfo dom(genericOp); |
464 | if (!dom.properlyDominates(packOpDest, genericOp)) |
465 | return failure(); |
466 | } |
467 | |
468 | // Rebuild the indexing map for the corresponding init operand. |
469 | auto [packedOutOperand, packedOutIndexingMap] = |
470 | getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo, |
471 | genericOp, opOperand); |
472 | |
473 | // If the dps init operand of the generic is a tensor.empty, forward the pack
474 | // op destination.
475 | Value dest = packedOutOperand; |
476 | if (auto initTensor = genericOp.getDpsInitOperand(0) |
477 | ->get() |
478 | .getDefiningOp<tensor::EmptyOp>()) { |
479 | dest = packOpDest; |
480 | } |
481 | // pack(unpack) isn't naively foldable because the unpack op can be from |
482 | // an arbitrary domain so we need to keep both. |
483 | return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, |
484 | *packInfo, /*isFoldableUnpackPack=*/false); |
485 | } |
486 | |
487 | /// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method. |
488 | struct BubbleUpPackOpThroughGenericOpPattern |
489 | : public OpRewritePattern<linalg::PackOp> { |
490 | public: |
491 | BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context, |
492 | ControlPropagationFn fun) |
493 | : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {} |
494 | |
495 | LogicalResult matchAndRewrite(linalg::PackOp packOp, |
496 | PatternRewriter &rewriter) const override { |
497 | auto genericOp = |
498 | bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn); |
499 | if (failed(genericOp)) |
500 | return failure(); |
501 | rewriter.replaceOp(packOp, genericOp->getResults()); |
502 | return success(); |
503 | } |
504 | |
505 | private: |
506 | ControlPropagationFn controlFn; |
507 | }; |
508 | |
509 | /// Propagate a linalg.pack operation up through a tensor.pad. The idea is to
510 | /// append as many zero padding entries to `high` and `low` as there are inner
511 | /// point loops.
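///
/// A rough sketch of the rewrite (shapes and values are illustrative):
///
///   %padded = tensor.pad %src low[0, 0, 0] high[%h, 0, 0] { ... }
///       : tensor<?x?x?xf32> to tensor<?x?x?xf32>
///   %pack = linalg.pack %padded inner_dims_pos = [1, 2] inner_tiles = [8, 2]
///       into %dest : tensor<?x?x?xf32> -> tensor<?x?x?x8x2xf32>
///
/// becomes, roughly,
///
///   %pack = linalg.pack %src inner_dims_pos = [1, 2] inner_tiles = [8, 2]
///       into %empty : tensor<?x?x?xf32> -> tensor<?x?x?x8x2xf32>
///   %padded = tensor.pad %pack low[0, 0, 0, 0, 0] high[%h, 0, 0, 0, 0] { ... }
///       : tensor<?x?x?x8x2xf32> to tensor<?x?x?x8x2xf32>
///
/// Only the untiled dimension (d0) is padded here; as checked below, the
/// pattern also requires a constant padding value and a tensor.empty pack
/// destination.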
512 | class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> { |
513 | public: |
514 | BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun) |
515 | : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {} |
516 | |
517 | LogicalResult matchAndRewrite(linalg::PackOp packOp, |
518 | PatternRewriter &rewriter) const override { |
519 | auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>(); |
520 | if (!padOp) |
521 | return failure(); |
522 | |
523 | // User controlled propagation function. |
524 | if (!controlFn(&packOp.getSourceMutable())) |
525 | return failure(); |
526 | |
527 | // TODO: Enable padding when the padding values are the same. |
528 | if (packOp.getPaddingValue()) |
529 | return failure(); |
530 | |
531 | // Fail for non-constant padding values. The body of the pad could |
532 | // depend on the padding indices and/or properties of the padded |
533 | // tensor so for now we fail. |
534 | // TODO: Support non-constant padding values. |
535 | Value paddingVal = padOp.getConstantPaddingValue(); |
536 | if (!paddingVal) |
537 | return failure(); |
538 | |
539 | if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>()) |
540 | return failure(); |
541 | |
542 | ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos(); |
543 | |
544 | // Bail out if one of the padded dimensions is a tiled one.
545 | llvm::SmallBitVector paddedDims = padOp.getPaddedDims(); |
546 | llvm::SmallBitVector innerDims(paddedDims.size()); |
547 | for (int64_t dim : innerDimsPos) |
548 | innerDims.flip(dim); |
549 | if (paddedDims.anyCommon(innerDims))
550 | return failure(); |
551 | |
552 | Location loc = padOp->getLoc(); |
553 | OpBuilder::InsertionGuard guard(rewriter); |
554 | rewriter.setInsertionPoint(padOp); |
555 | |
556 | ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm(); |
557 | SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles(); |
558 | auto empty = linalg::PackOp::createDestinationTensor( |
559 | rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos, |
560 | outerDimsPerm); |
561 | auto sourcePack = rewriter.create<linalg::PackOp>( |
562 | loc, padOp.getSource(), empty, innerDimsPos, mixedTiles, |
563 | /*padding=*/std::nullopt, outerDimsPerm); |
564 | |
565 | // If we have `outer_dims_perm`, we need to adjust the padded dimensions.
566 | SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad(); |
567 | SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad(); |
568 | if (!outerDimsPerm.empty()) { |
569 | applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm); |
570 | applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm); |
571 | } |
572 | // The tiled dimensions were verified to be unpadded above, so here we |
573 | // just append 0 for the inner tile dimensions. |
574 | size_t pointLoopsSize = innerDimsPos.size(); |
575 | lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); |
576 | highPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); |
577 | |
578 | auto newPadOp = rewriter.create<tensor::PadOp>( |
579 | loc, /*result=*/Type(), sourcePack, lowPad, highPad, paddingVal, |
580 | padOp.getNofold()); |
581 | |
582 | // If the pad has more than one user, create an unpack on the new pad to |
583 | // replace the other uses. |
584 | if (!padOp->hasOneUse()) { |
585 | auto unpackEmpty = linalg::UnPackOp::createDestinationTensor( |
586 | rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm); |
587 | Value unpackedPad = rewriter.create<linalg::UnPackOp>( |
588 | loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles, outerDimsPerm); |
589 | rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack); |
590 | } |
591 | |
592 | // Replace the pack with the new pad. |
593 | rewriter.replaceOp(packOp, newPadOp.getResult()); |
594 | |
595 | return success(); |
596 | } |
597 | |
598 | private: |
599 | ControlPropagationFn controlFn; |
600 | }; |
601 | |
602 | /// Project dimsPos to the inner-most non-unit dim pos with reassocIndices. |
603 | /// |
604 | /// For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and
605 | /// targetShape [16, 16, 32, 1], it returns [1, 2], because for pos 0 the
606 | /// inner-most non-unit projected dim in [0, 1] is 1, and for pos 2 the
607 | /// inner-most non-unit projected dim in [2, 3] is 2.
608 | /// |
609 | /// If all candidates in a reassociation are unit dims, it chooses the |
610 | /// inner-most dim pos. |
611 | static SmallVector<int64_t> |
612 | projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos, |
613 | ArrayRef<ReassociationIndices> reassocIndices, |
614 | ArrayRef<int64_t> targetShape) { |
615 | SmallVector<int64_t> projectedDimsPos; |
616 | for (auto pos : dimsPos) { |
617 | // In the case all dims are unit, this will return the inner-most one. |
618 | int64_t projectedPos = reassocIndices[pos].back(); |
619 | for (auto i : llvm::reverse(reassocIndices[pos])) { |
620 | int64_t dim = targetShape[i]; |
621 | if (dim > 1 || ShapedType::isDynamic(dim)) { |
622 | projectedPos = i; |
623 | break; |
624 | } |
625 | } |
626 | projectedDimsPos.push_back(projectedPos); |
627 | } |
628 | return projectedDimsPos; |
629 | } |
630 | |
631 | /// Check if all dims in dimsPos are divisible by the corresponding tile sizes. |
632 | static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos, |
633 | ArrayRef<int64_t> shape, |
634 | ArrayRef<int64_t> tileSizes) { |
635 | for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) { |
636 | int64_t dim = shape[pos]; |
637 | if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0) |
638 | return false; |
639 | } |
640 | return true; |
641 | } |
642 | |
643 | /// Permute the reassociation indices and reindex them in sequence order.
644 | /// Returns the next dim pos in the sequence. |
645 | /// |
646 | /// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it |
647 | /// applies the permutation to get [[2], [0, 1]] and reindexes the indices into |
648 | /// [[0], [1, 2]]. |
649 | static int64_t applyPermutationAndReindexReassoc( |
650 | SmallVector<ReassociationIndices> &reassocIndices, |
651 | ArrayRef<int64_t> permutation) { |
652 | if (!permutation.empty()) |
653 | applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation); |
654 | int64_t nextPos = 0; |
655 | for (ReassociationIndices &indices : reassocIndices) { |
656 | for (auto &index : indices) { |
657 | index = nextPos; |
658 | nextPos += 1; |
659 | } |
660 | } |
661 | return nextPos; |
662 | } |
663 | |
664 | /// Bubble up pack op through collapse shape op when the packed dims can be |
665 | /// projected to the dims before collapsing. This is possible when the inner |
666 | /// tile sizes can divide the projected dims. |
667 | /// |
668 | /// For example: |
669 | /// |
670 | /// %collapsed = tensor.collapse_shape %in [[0, 1], 2] |
671 | /// : tensor<?x16x4xf32> into tensor<?x4xf32> |
672 | /// %pack = linalg.pack %collapsed outer_dims_perm = [0, 1] |
673 | /// inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty |
674 | /// : tensor<?x4xf32> -> tensor<?x4x8x1xf32> |
675 | /// |
676 | /// can be transformed into: |
677 | /// |
678 | /// %pack = linalg.pack %in outer_dims_perm = [0, 1, 2]
679 | /// inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty |
680 | /// : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32> |
681 | /// %collapsed = tensor.collapse_shape %pack [[0, 1], 2, 3, 4] |
682 | /// : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
683 | static LogicalResult |
684 | bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp, |
685 | linalg::PackOp packOp, |
686 | PatternRewriter &rewriter) { |
687 | SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles(); |
688 | ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos(); |
689 | ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm(); |
690 | |
691 | ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape(); |
692 | SmallVector<ReassociationIndices> reassocIndices = |
693 | collapseOp.getReassociationIndices(); |
694 | // Project inner tile pos to the dim pos before collapsing. For example, if
695 | // dims [x, y] are collapsed into [z], packing on dim z can be projected back
696 | // to packing on dim y.
697 | // |
698 | // Project to inner-most non-unit dims to increase the chance that they can be |
699 | // divided by the inner tile sizes. This is correct because for [..., x, 1], |
700 | // packing on dim 1 is equivalent to packing on dim x. |
701 | SmallVector<int64_t> projectedInnerDimsPos = |
702 | projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape); |
703 | |
704 | if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape, |
705 | innerTileSizes)) { |
706 | return failure(); |
707 | } |
708 | // Expand the outer dims permutation with the associated source dims for the |
709 | // new permutation after bubbling. This is because moving a collapsed dim is |
710 | // equivalent to moving the associated source dims together. |
711 | SmallVector<int64_t> newOuterDimsPerm; |
712 | for (auto outerPos : outerDimsPerm) |
713 | llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]); |
714 | |
715 | auto emptyOp = linalg::PackOp::createDestinationTensor( |
716 | rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(), |
717 | projectedInnerDimsPos, newOuterDimsPerm); |
718 | auto newPackOp = rewriter.create<linalg::PackOp>( |
719 | packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos, |
720 | packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm); |
721 | |
722 | SmallVector<ReassociationIndices> newReassocIndices = reassocIndices; |
723 | // First apply the permutation on the reassociations of the outer dims. |
724 | // For example given the permutation [1, 0], the reassociations [[0, 1], [2]] |
725 | // -> [[0], [1, 2]] |
726 | int64_t nextPos = |
727 | applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm); |
728 | // Then add direct mapping for the inner tile dims. |
729 | for (size_t i = 0; i < innerDimsPos.size(); ++i) { |
730 | newReassocIndices.push_back({nextPos}); |
731 | nextPos += 1; |
732 | } |
733 | |
734 | auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>( |
735 | collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices); |
736 | rewriter.replaceOp(packOp, newCollapseOp); |
737 | |
738 | return success(); |
739 | } |
740 | |
741 | /// Project dimsPos to their collapsed positions in the reassocIndices. |
742 | /// |
743 | /// For example, given dimsPos [0, 1, 2, 4], and matching reassocIndices |
744 | /// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]. Because for pos 0, |
745 | /// the reassoc dim [0] is 0. For pos 1 and 2, the reassoc dim in pos |
746 | /// [1, 2] is 1. And for pos 4, the reassoc dim [4] is 3. |
747 | static SmallVector<int64_t> |
748 | projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos, |
749 | ArrayRef<ReassociationIndices> reassocIndices) { |
750 | SmallVector<int64_t> projectedPos; |
751 | |
752 | // Map each dimension to the position of corresponding reassociation index. |
753 | for (auto pos : dimsPos) { |
754 | for (auto [idx, indices] : llvm::enumerate(reassocIndices)) { |
755 | // If the dimension is present in the current indices group, the group |
756 | // position within the reassociation map is the desired projected |
757 | // dimension position. |
758 | if (llvm::is_contained(indices, pos)) { |
759 | projectedPos.push_back(idx); |
760 | break; |
761 | } |
762 | } |
763 | } |
764 | assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection"); |
765 | |
766 | return projectedPos; |
767 | } |
768 | |
769 | /// Bubble up pack op through expand shape op. |
770 | /// |
771 | /// For example: |
772 | /// |
773 | /// %expand = tensor.expand_shape %in [[0], [1, 2]] |
774 | /// : tensor<?x64xf32> into tensor<?x4x16xf32> |
775 | /// %pack = linalg.pack %expand outer_dims_perm = [0, 1] |
776 | /// inner_dims_pos = [2] inner_tiles = [8] into %empty |
777 | /// : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32> |
778 | /// |
779 | /// can be transformed into: |
780 | /// |
781 | /// %pack = linalg.pack %in outer_dims_perm = [1, 2] |
782 | /// inner_dims_pos = [1] inner_tiles = [8] into %empty |
783 | /// : tensor<?x64xf32> -> tensor<?x8x8xf32> |
784 | /// %expand = tensor.expand_shape %pack [[0], [1, 2], [3]] |
785 | /// : tensor<?x8x8xf32> into tensor<?x4x2x8xf32> |
786 | static LogicalResult |
787 | bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp, |
788 | linalg::PackOp packOp, |
789 | PatternRewriter &rewriter) { |
790 | // Outer dimensions permutation is not supported currently. |
791 | // TODO: Handle outer_dims_perm variants. |
792 | ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm(); |
793 | if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
794 | return rewriter.notifyMatchFailure(packOp, |
795 | "non-identity outer dims perm NYI"); |
796 | } |
797 | |
798 | // Validate dimensions' relations between shape expansion and packing. |
799 | SmallVector<ReassociationIndices, 4> reassoc = |
800 | expandOp.getReassociationIndices(); |
801 | ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos(); |
802 | llvm::SetVector<int64_t> packDimsPos(llvm::from_range, packInnerDims); |
803 | |
804 | for (auto [idx, indices] : llvm::enumerate(reassoc)) { |
805 | // For each expand_shape reassociation, figure out which dimensions get |
806 | // packed if any. |
807 | llvm::SetVector<int64_t> expandDimPos(llvm::from_range, indices); |
808 | llvm::SetVector<int64_t> packedDims = |
809 | llvm::set_intersection(packDimsPos, expandDimPos); |
810 | |
811 | // The expanded dimension is not packed, so it does not affect moving the pack
812 | // before the shape expansion - simply continue.
813 | if (packedDims.empty()) |
814 | continue; |
815 | // Shape expansion cannot be propagated when multiple expanded dimensions are
816 | // packed - in this case, operation reordering would affect final element
817 | // positions and/or shapes can no longer be projected.
818 | if (packedDims.size() != 1) |
819 | return rewriter.notifyMatchFailure( |
820 | packOp, "only one of the expanded dimensions can be packed"); |
821 | // Only the inner-most expanded dimension should be packed. Otherwise, |
822 | // elements order will be affected after operation reordering. |
823 | if (packedDims.front() != indices.back()) |
824 | return rewriter.notifyMatchFailure( |
825 | packOp, "can only pack the inner-most expanded dimension"); |
826 | } |
827 | |
828 | // Project pack.inner_dims_pos to positions before shape expansion. |
829 | SmallVector<int64_t> projectedInnerDimsPos = |
830 | projectDimsPosIntoReassocPos(packInnerDims, reassoc); |
831 | |
832 | // Project the shape expansion to new packed shape. |
833 | // The pack.outer_dims_perm is restricted to identity so, the permutation can |
834 | // be omitted for simplicity. |
835 | // TODO: Account for outer dimensions permutation. |
836 | // |
837 | // If reassociation is not possible, then reordering cannot happen. |
838 | // This can be caused by pack padding affecting previously expanded |
839 | // dimensions or packing extending dimensions. |
840 | RankedTensorType newPackType = linalg::PackOp::inferPackedType( |
841 | expandOp.getSrcType(), packOp.getStaticInnerTiles(), |
842 | projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{}); |
843 | auto reassocExpand = |
844 | getReassociationIndicesForReshape(newPackType, packOp.getDestType()); |
845 | if (!reassocExpand) |
846 | return rewriter.notifyMatchFailure( |
847 | packOp, "could not reassociate dims after bubbling up"); |
848 | |
849 | Value destTensor = linalg::PackOp::createDestinationTensor( |
850 | rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(), |
851 | projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{}); |
852 | Value packedVal = rewriter.create<linalg::PackOp>( |
853 | packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos, |
854 | packOp.getMixedTiles(), packOp.getPaddingValue(), |
855 | /*outerDimsPerm=*/SmallVector<int64_t>{}); |
856 | |
857 | Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>( |
858 | packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand); |
859 | rewriter.replaceOp(packOp, newExpandOp); |
860 | |
861 | return success(); |
862 | } |
863 | |
864 | class BubbleUpPackOpThroughReshapeOp final |
865 | : public OpRewritePattern<linalg::PackOp> { |
866 | public: |
867 | BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun) |
868 | : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {} |
869 | |
870 | LogicalResult matchAndRewrite(linalg::PackOp packOp, |
871 | PatternRewriter &rewriter) const override { |
872 | Operation *srcOp = packOp.getSource().getDefiningOp(); |
873 | // Currently only support when the pack op is the only user. |
874 | if (!srcOp || !(srcOp->getNumResults() == 1) || |
875 | !srcOp->getResult(0).hasOneUse()) {
876 | return failure(); |
877 | } |
878 | // Currently only support static inner tile sizes. |
879 | if (llvm::any_of(packOp.getStaticTiles(), ShapedType::isDynamic)) |
880 | return failure(); |
881 | |
882 | // User controlled propagation function. |
883 | if (!controlFn(&packOp.getSourceMutable())) |
884 | return failure(); |
885 | |
886 | return TypeSwitch<Operation *, LogicalResult>(srcOp) |
887 | .Case([&](tensor::CollapseShapeOp op) { |
888 | return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter); |
889 | }) |
890 | .Case([&](tensor::ExpandShapeOp op) { |
891 | return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter); |
892 | }) |
893 | .Default([](Operation *) { return failure(); }); |
894 | } |
895 | |
896 | private: |
897 | ControlPropagationFn controlFn; |
898 | }; |
899 | |
900 | /// Push down unpack op through expand shape op when the packed dims can be |
901 | /// projected to the dims after expanding. This is possible when the inner tile |
902 | /// sizes can divide the projected dims. |
903 | /// |
904 | /// For example: |
905 | /// |
906 | /// %unpack = linalg.unpack %in outer_dims_perm = [0, 1] |
907 | /// inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty |
908 | /// : tensor<?x32x8x8xf32> -> tensor<?x256xf32> |
909 | /// %expanded = tensor.expand_shape %unpack [[0, 1], [2]] |
910 | /// : tensor<?x256xf32> into tensor<?x256x256xf32> |
911 | /// |
912 | /// can be transformed into: |
913 | /// |
914 | /// %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
915 | /// : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32> |
916 | /// %unpack = linalg.unpack %expanded outer_dims_perm = [0, 1, 2] |
917 | /// inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty |
918 | /// : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32> |
919 | static LogicalResult pushDownUnPackOpThroughExpandShape( |
920 | linalg::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp, |
921 | PatternRewriter &rewriter, ControlPropagationFn controlFn) { |
922 | // User controlled propagation function. |
923 | if (!controlFn(&expandOp.getSrcMutable())) |
924 | return failure(); |
925 | |
926 | SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles(); |
927 | ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos(); |
928 | ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm(); |
929 | |
930 | auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType()); |
931 | if (!expandTy) |
932 | return failure(); |
933 | ArrayRef<int64_t> dstShape = expandTy.getShape(); |
934 | SmallVector<ReassociationIndices> reassocIndices = |
935 | expandOp.getReassociationIndices(); |
936 | // Project inner tile pos to the dim pos after expanding. For example, if dim
937 | // [z] is expanded into [x, y], unpacking on dim z can be projected to unpacking
938 | // on dim y.
939 | // |
940 | // Project to inner-most non-unit dims to increase the chance that they can be |
941 | // divided by the inner tile sizes. This is correct because for [..., x, 1], |
942 | // unpacking on dim 1 is equivalent to unpacking on dim x. |
943 | SmallVector<int64_t> projectedInnerDimsPos = |
944 | projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape); |
945 | |
946 | if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape, |
947 | innerTileSizes)) { |
948 | return failure(); |
949 | } |
950 | // Expand the outer dims permutation with the associated expanded dims for the |
951 | // new permutation after pushing. This is because moving a source dim is |
952 | // equivalent to moving the associated expanded dims together. |
953 | SmallVector<int64_t> newOuterDimsPerm; |
954 | for (auto outerPos : outerDimsPerm) |
955 | llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]); |
956 | |
957 | SmallVector<ReassociationIndices> newReassocIndices = reassocIndices; |
958 | // First apply the permutation on the reassociations of the outer dims. |
959 | // For example given the permutation [1, 0], the reassociations [[0, 1], [2]] |
960 | // -> [[0], [1, 2]] |
961 | int64_t nextPos = |
962 | applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm); |
963 | // Then add direct mapping for the inner tile dims. |
964 | for (size_t i = 0; i < innerDimsPos.size(); ++i) { |
965 | newReassocIndices.push_back({nextPos}); |
966 | nextPos += 1; |
967 | } |
968 | |
969 | RankedTensorType newExpandType = linalg::PackOp::inferPackedType( |
970 | expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm); |
971 | auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>( |
972 | expandOp.getLoc(), newExpandType, unPackOp.getSource(), |
973 | newReassocIndices); |
974 | |
975 | auto emptyOp = linalg::UnPackOp::createDestinationTensor( |
976 | rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(), |
977 | projectedInnerDimsPos, newOuterDimsPerm); |
978 | auto newUnPackOp = rewriter.create<linalg::UnPackOp>( |
979 | unPackOp.getLoc(), newExpandOp.getResult(), emptyOp, |
980 | projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm); |
981 | rewriter.replaceOp(expandOp, newUnPackOp); |
982 | |
983 | return success(); |
984 | } |
985 | |
986 | class PushDownUnPackOpThroughReshapeOp final |
987 | : public OpRewritePattern<linalg::UnPackOp> { |
988 | public: |
989 | PushDownUnPackOpThroughReshapeOp(MLIRContext *context, |
990 | ControlPropagationFn fun) |
991 | : OpRewritePattern<linalg::UnPackOp>(context), controlFn(std::move(fun)) { |
992 | } |
993 | |
994 | LogicalResult matchAndRewrite(linalg::UnPackOp unPackOp, |
995 | PatternRewriter &rewriter) const override { |
996 | Value result = unPackOp.getResult(); |
997 | // Currently, we only support an unpack op with a single user.
998 | if (!result.hasOneUse()) { |
999 | return failure(); |
1000 | } |
1001 | // Currently only support static inner tile sizes. |
1002 | if (llvm::any_of(unPackOp.getStaticTiles(), ShapedType::isDynamic)) |
1003 | return failure(); |
1004 | |
1005 | Operation *consumerOp = *result.user_begin(); |
1006 | return TypeSwitch<Operation *, LogicalResult>(consumerOp) |
1007 | .Case([&](tensor::ExpandShapeOp op) { |
1008 | return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter, |
1009 | controlFn); |
1010 | }) |
1011 | .Default([](Operation *) { return failure(); }); |
1012 | } |
1013 | |
1014 | private: |
1015 | ControlPropagationFn controlFn; |
1016 | }; |
1017 | |
1018 | // TODO: Relax this restriction. We should unpack a generic op also |
1019 | // in the presence of multiple unpack ops as producers. |
1020 | /// Return the unpacked operand, if present, for the current generic op. |
1021 | static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) { |
1022 | OpOperand *unPackedOperand = nullptr; |
1023 | for (OpOperand &operand : genericOp->getOpOperands()) { |
1024 | auto unPackOp = operand.get().getDefiningOp<linalg::UnPackOp>(); |
1025 | if (!unPackOp) |
1026 | continue; |
1027 | if (unPackedOperand) |
1028 | return failure(); |
1029 | unPackedOperand = &operand; |
1030 | } |
1031 | if (!unPackedOperand) |
1032 | return failure(); |
1033 | return unPackedOperand; |
1034 | } |
1035 | |
1036 | /// Push down a linalg.unpack op through a generic op. |
1037 | /// The new generic op works on packed domain; pack ops are created for input |
1038 | /// and output operands. A linalg.unpack op is inserted right after the packed |
1039 | /// generic. E.g. |
1040 | /// |
1041 | /// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> |
1042 | /// |
1043 | /// %arg0 = tensor<12x2x56x56x32xf32> // packed arg. |
1044 | /// |
1045 | /// %0 = tensor.empty() : tensor<12x56x56x64xf32> |
1046 | /// %1 = linalg.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] |
1047 | /// inner_dims_pos = [3] inner_tiles = [32] into %0 |
1048 | /// %2 = linalg.generic {indexing_maps = [#map], |
1049 | /// iterator_types = ["parallel", "parallel", "parallel", "parallel"]} |
1050 | /// outs(%1 : tensor<12x56x56x64xf32>) { |
1051 | /// ^bb0(%out : f32): |
1052 | /// linalg.yield %out : f32 |
1053 | /// } -> tensor<12x56x56x64xf32> |
1054 | /// |
1055 | /// will be converted to |
1056 | /// |
1057 | /// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> |
1058 | /// |
1059 | /// %0 = tensor.empty() : tensor<12x56x56x64xf32> |
1060 | /// %1 = linalg.generic {indexing_maps = [#map], |
1061 | /// iterator_types = ["parallel", "parallel", "parallel", |
1062 | /// "parallel", "parallel"]} |
1063 | /// outs(%arg0 : tensor<12x2x56x56x32xf32>) { |
1064 | /// ^bb0(%out : f32): |
1065 | /// linalg.yield %out : f32 |
1066 | /// } -> tensor<12x2x56x56x32xf32> |
1067 | /// %2 = linalg.unpack %1 outer_dims_perm = [0, 3, 1, 2] |
1068 | /// inner_dims_pos = [3] inner_tiles = [32] into %0 |
1069 | /// |
1070 | static FailureOr<std::tuple<GenericOp, Value>> |
1071 | pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp, |
1072 | ControlPropagationFn controlFn) { |
1073 | if (genericOp.getNumResults() != 1) |
1074 | return failure(); |
1075 | |
1076 | if (hasGatherSemantics(genericOp)) |
1077 | return failure(); |
1078 | |
1079 | // Collect the unPacked operand, if present. |
1080 | auto maybeUnPackedOperand = getUnPackedOperand(genericOp); |
1081 | if (failed(maybeUnPackedOperand)) |
1082 | return failure(); |
1083 | OpOperand *unPackedOperand = *(maybeUnPackedOperand); |
1084 | |
1085 | // Extract packing information. |
1086 | linalg::UnPackOp producerUnPackOp = |
1087 | unPackedOperand->get().getDefiningOp<linalg::UnPackOp>(); |
1088 | assert(producerUnPackOp && "expect a valid UnPackOp"); |
1089 | |
1090 | if (!controlFn(unPackedOperand)) |
1091 | return failure(); |
1092 | |
1093 | auto packInfo = |
1094 | getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp); |
1095 | if (failed(packInfo)) |
1096 | return failure(); |
1097 | |
1098 | // Rebuild the indexing map for the corresponding init operand. |
1099 | auto [packedOutOperand, packedOutIndexingMap] = |
1100 | getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo, |
1101 | genericOp, genericOp.getDpsInitOperand(0)); |
1102 | auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>(); |
1103 | |
1104 | // If the dps init operand of the generic is a tensor.empty, do not pack it |
1105 | // and forward the new tensor.empty as a destination. |
1106 | Value dest = packedOutOperand; |
1107 | if (auto initTensor = genericOp.getDpsInitOperand(0) |
1108 | ->get() |
1109 | .getDefiningOp<tensor::EmptyOp>()) { |
1110 | if (destPack) |
1111 | dest = destPack.getDest(); |
1112 | } |
1113 | |
1114 | // Pack the genericOp. |
1115 | // pack(unpack) is foldable in this case. This is because, when pushing down the
1116 | // unpack, we populate an additional pack op after the unpack by default, which
1117 | // guarantees that the pair can be folded.
1118 | GenericOp newGenericOp = |
1119 | packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo, |
1120 | /*isFoldableUnpackPack=*/true); |
1121 | Value newResult = |
1122 | newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)); |
1123 | |
1124 | // If the output is unaffected, no need to unpack. |
1125 | if (!destPack) |
1126 | return std::make_tuple(newGenericOp, newResult); |
1127 | |
1128 | auto mixedTiles = destPack.getMixedTiles(); |
1129 | auto innerDimsPos = destPack.getInnerDimsPos(); |
1130 | auto outerDimsPerm = destPack.getOuterDimsPerm(); |
1131 | |
1132 | // Insert an unPackOp right after the packed generic. |
1133 | Value unPackOpRes = |
1134 | rewriter |
1135 | .create<linalg::UnPackOp>(genericOp.getLoc(), newResult, |
1136 | destPack.getSource(), innerDimsPos, |
1137 | mixedTiles, outerDimsPerm) |
1138 | .getResult(); |
1139 | |
1140 | return std::make_tuple(newGenericOp, unPackOpRes); |
1141 | } |
1142 | |
1143 | // Wrapper pattern that applies pushDownUnPackOpThroughGenericOp method. |
1144 | struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> { |
1145 | public: |
1146 | PushDownUnPackOpThroughGenericOp(MLIRContext *context, |
1147 | ControlPropagationFn fun) |
1148 | : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {} |
1149 | |
1150 | LogicalResult matchAndRewrite(GenericOp genericOp, |
1151 | PatternRewriter &rewriter) const override { |
1152 | auto genericAndRepl = |
1153 | pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn); |
1154 | if (failed(genericAndRepl)) |
1155 | return failure(); |
1156 | rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl)); |
1157 | return success(); |
1158 | } |
1159 | |
1160 | private: |
1161 | ControlPropagationFn controlFn; |
1162 | }; |
1163 | |
1164 | /// Propagate a linalg.unpack operation down through a tensor.pad. The idea is
1165 | /// to append as many zero padding entries to `high` and `low` as there are
1166 | /// inner point loops.
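///
/// A rough sketch of the rewrite (shapes and values are illustrative):
///
///   %unpack = linalg.unpack %src inner_dims_pos = [1, 2] inner_tiles = [8, 2]
///       into %dest : tensor<?x?x?x8x2xf32> -> tensor<?x?x?xf32>
///   %padded = tensor.pad %unpack low[0, 0, 0] high[%h, 0, 0] { ... }
///       : tensor<?x?x?xf32> to tensor<?x?x?xf32>
///
/// becomes, roughly,
///
///   %padded = tensor.pad %src low[0, 0, 0, 0, 0] high[%h, 0, 0, 0, 0] { ... }
///       : tensor<?x?x?x8x2xf32> to tensor<?x?x?x8x2xf32>
///   %unpack = linalg.unpack %padded inner_dims_pos = [1, 2]
///       inner_tiles = [8, 2] into %empty
///       : tensor<?x?x?x8x2xf32> -> tensor<?x?x?xf32>
///
/// As in the bubble-up case, only untiled dimensions may be padded and the
/// padding value must be a constant.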
1167 | struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> { |
1168 | PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun) |
1169 | : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {} |
1170 | |
1171 | LogicalResult matchAndRewrite(tensor::PadOp padOp, |
1172 | PatternRewriter &rewriter) const override { |
1173 | linalg::UnPackOp unpackOp = |
1174 | padOp.getSource().getDefiningOp<linalg::UnPackOp>(); |
1175 | if (!unpackOp) |
1176 | return failure(); |
1177 | |
1178 | if (!controlFn(&padOp.getSourceMutable())) |
1179 | return failure(); |
1180 | |
1181 | Location loc = padOp.getLoc(); |
1182 | // Bail out if one of the padded dimensions is a tiled one.
1183 | llvm::SmallBitVector paddedDims = padOp.getPaddedDims(); |
1184 | ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos(); |
1185 | llvm::SmallBitVector innerDims(paddedDims.size()); |
1186 | for (int64_t dim : innerDimsPos) |
1187 | innerDims.flip(dim); |
1188 | if (paddedDims.anyCommon(innerDims))
1189 | return failure(); |
1190 | |
1191 | Value paddingVal = padOp.getConstantPaddingValue(); |
1192 | if (!paddingVal) |
1193 | return failure(); |
1194 | |
1195 | // If we have `outer_dims_perm`, we need to adjust the padded dimensions.
1196 | ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm(); |
1197 | SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad(); |
1198 | SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad(); |
1199 | if (!outerDimsPerm.empty()) { |
1200 | applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm); |
1201 | applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm); |
1202 | } |
1203 | // Add zero padding for the point loops. |
1204 | size_t pointLoopsSize = innerDimsPos.size(); |
1205 | lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); |
1206 | highPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); |
1207 | |
1208 | auto newPadOp = rewriter.create<tensor::PadOp>( |
1209 | loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad, |
1210 | paddingVal, padOp.getNofold()); |
1211 | |
1212 | // Inject the linalg.unpack right after the packed padOp. |
1213 | Value outputUnPack = rewriter.create<tensor::EmptyOp>( |
1214 | loc, padOp.getResultType().getShape(), |
1215 | padOp.getResultType().getElementType()); |
1216 | |
1217 | Value replacement = rewriter.create<linalg::UnPackOp>( |
1218 | loc, newPadOp.getResult(), outputUnPack, innerDimsPos, |
1219 | unpackOp.getMixedTiles(), outerDimsPerm); |
1220 | rewriter.replaceOp(padOp, replacement); |
1221 | return success(); |
1222 | } |
1223 | |
1224 | private: |
1225 | ControlPropagationFn controlFn; |
1226 | }; |
1227 | |
1228 | } // namespace |
1229 | |
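/// A minimal usage sketch, assuming a rewrite-driver-based pass; the pass name
/// and the always-true control callback are illustrative, and the greedy driver
/// entry point may differ across MLIR revisions:
///
///   void MyPropagationPass::runOnOperation() {
///     RewritePatternSet patterns(&getContext());
///     linalg::populateDataLayoutPropagationPatterns(
///         patterns, [](OpOperand *opOperand) { return true; });
///     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
///       return signalPassFailure();
///   }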
1230 | void mlir::linalg::populateDataLayoutPropagationPatterns( |
1231 | RewritePatternSet &patterns, |
1232 | const ControlPropagationFn &controlPackUnPackPropagation) { |
1233 | patterns |
1234 | .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp, |
1235 | BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp, |
1236 | PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>( |
1237 | patterns.getContext(), controlPackUnPackPropagation);
1238 | } |
1239 |