1 | //===- Transforms.cpp - Linalg transformations as patterns ----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements logic and helpers to expose Linalg transforms as rewrite |
10 | // patterns. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
15 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
17 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
18 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
19 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
20 | #include "mlir/Dialect/SCF/Transforms/Transforms.h" |
21 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
22 | #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h" |
23 | #include "mlir/Dialect/Tensor/Utils/Utils.h" |
24 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
25 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
26 | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
27 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
28 | #include "mlir/IR/AffineExpr.h" |
29 | #include "mlir/IR/Matchers.h" |
30 | #include "mlir/Pass/Pass.h" |
31 | #include "mlir/Support/LLVM.h" |
32 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
33 | #include "llvm/ADT/ScopeExit.h" |
34 | #include "llvm/ADT/TypeSwitch.h" |
35 | #include "llvm/Support/Debug.h" |
36 | #include "llvm/Support/raw_ostream.h" |
37 | #include <type_traits> |
38 | #include <utility> |
39 | |
40 | #define DEBUG_TYPE "linalg-transforms" |
41 | |
42 | using namespace mlir; |
43 | using namespace mlir::linalg; |
44 | |
45 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") |
46 | #define DBGSNL() (llvm::dbgs() << "\n") |
47 | |
48 | //===----------------------------------------------------------------------===// |
49 | // Transformations exposed as functional-style API calls. |
50 | //===----------------------------------------------------------------------===// |
51 | |
52 | //===----------------------------------------------------------------------===// |
53 | // peelLoop transformation. |
54 | //===----------------------------------------------------------------------===// |
55 | |
56 | /// Try to peel and canonicalize loop `op` and return the new result. |
57 | /// Also applies affine_min/max bounds simplification on the fly where relevant. |
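/// For example (illustrative): an scf.for over the range [0, 10) with step 4
/// is split into a main loop over [0, 8) and a partial iteration covering the
/// remaining [8, 10).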
58 | // TODO: Add support for scf.parallel and affine.for loops. |
59 | SmallVector<Value> mlir::linalg::peelLoop(RewriterBase &rewriter, |
60 | Operation *op) { |
61 | return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op) |
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
63 | scf::ForOp partialIteration; |
64 | if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp, |
65 | partialIteration))) |
66 | return partialIteration->getResults(); |
67 | assert(!partialIteration && "expected that loop was not peeled" ); |
68 | return forOp->getResults(); |
69 | }) |
      .Default([&](Operation *op) { return op->getResults(); });
71 | } |
72 | |
/// Peel 'loops' and apply affine_min/max bounds simplification on the fly
/// where relevant.
75 | void mlir::linalg::peelLoops(RewriterBase &rewriter, |
76 | ArrayRef<scf::ForOp> loops) { |
77 | for (auto loopOp : loops) |
78 | peelLoop(rewriter, loopOp); |
79 | } |
80 | |
81 | //===----------------------------------------------------------------------===// |
82 | // pack transformation. |
83 | //===----------------------------------------------------------------------===// |
84 | |
85 | #ifndef NDEBUG |
86 | /// Return true if `map` has 0 or 1 result function of AffineDimExpr(dim). |
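/// For example (illustrative): with dim = 0, the map (d0, d1) -> (d0, d1, d0)
/// returns false because two of its results are functions of d0.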
87 | static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) { |
88 | bool found = false; |
89 | for (AffineExpr e : map.getResults()) { |
    if (!e.isFunctionOfDim(dim))
91 | continue; |
92 | if (found) |
93 | return false; |
94 | found = true; |
95 | } |
96 | return true; |
97 | } |
98 | #endif // NDEBUG |
99 | |
/// Return the index of the first result of `map` that is a function of
/// AffineDimExpr(dim), or std::nullopt if no such result exists.
102 | static std::optional<int64_t> getFirstResultIndexFunctionOf(AffineMap map, |
103 | int64_t dim) { |
104 | for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { |
    AffineExpr expr = map.getResult(i);
    if (!expr.isFunctionOfDim(dim))
107 | continue; |
108 | return i; |
109 | } |
110 | return std::nullopt; |
111 | } |
112 | |
113 | /// Perform one step of packing of a LinalgOp's metadata along `dim` into the |
114 | /// `newDim` at `iteratorTypes.size()` by: |
115 | /// 1. Appending `iteratorTypes[newDim]`, equal to `iteratorTypes[dim]`. |
116 | /// 2. Appending a `newDim` to the domain of every indexing map. |
117 | /// 3. For each operand (i.e. for each map in `indexingMaps`), perform packing |
118 | /// by potentially adding a `newDim` result to `map`. |
119 | /// The preserved invariant is that `iteratorTypes.size()` is always equal to |
120 | /// `map.getNumDims()` for every map in `indexingMaps`. |
121 | /// |
122 | /// Update `indexingMaps` and `iteratorTypes` inplace as one step of the update. |
123 | /// Return a vector that records the optional packing for each operand. |
124 | /// Return failure if the packed indexing cannot be represented with a LinalgOp. |
125 | /// |
126 | /// Further details: |
127 | /// ================ |
128 | /// The current implementation of packing (i.e. data tiling) consists of |
129 | /// rewriting a linearized strip-mined form into a higher-dimensional access. |
130 | /// e.g. consider an access `A[I][f(j, k, l)]` and packing by 4; we rewrite |
131 | /// `I` into `4 * i + ii`, where `0 <= ii < 4`. |
132 | /// The access is further rewritten as `A[i][f(j, k, l)][ii]`. |
133 | /// |
134 | /// This rewrite into higher dimensional access is not possible for general |
/// AffineExpr in Linalg at the moment; it is restricted to an AffineDimExpr:
136 | /// e.g. consider an access `A[I + J][f(j, k, l)]` and packing by 4; we |
137 | /// rewrite `I + J` into `4 * i + ii + J`, where `0 <= ii < 4`. |
138 | /// The rewrite of the access would be a form not representable in Linalg: |
139 | /// `A[i + (ii + J) / 4][f(j, k, l)][(ii + J) % 4]`. |
140 | /// Note however that as `J` and `ii` iterate, the accesses do not have a |
/// particular alignment, so packing does not achieve alignment in this case.
142 | /// |
143 | /// In the future, we may want to consider a mixed-form that allows some |
144 | /// alignment in the presence of multiple accesses: |
145 | /// `A[I][f(j, k, l)]` and `B[I + J][f(j, k, l)]` |
146 | /// And would rewrite accesses as: |
147 | /// `A[i][f(j, k, l)][ii]` and `B[4 * i + ii + J][f(j, k, l)]` |
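///
/// Illustrative example (not from the original source): for a matmul with
/// indexing maps
///   (d0, d1, d2) -> (d0, d2), (d0, d1, d2) -> (d2, d1), (d0, d1, d2) -> (d0, d1)
/// packing `dim` = 0 (the parallel `m` dimension) appends a new dim d3 and
/// produces
///   (d0, d1, d2, d3) -> (d0, d2, d3), (d0, d1, d2, d3) -> (d2, d1),
///   (d0, d1, d2, d3) -> (d0, d1, d3)
/// and records the packed result position [0, std::nullopt, 0] per operand.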
148 | static FailureOr<SmallVector<std::optional<int64_t>>> |
149 | packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps, |
150 | SmallVectorImpl<utils::IteratorType> &iteratorTypes, |
151 | int64_t dim) { |
152 | int64_t newDim = iteratorTypes.size(); |
153 | iteratorTypes.push_back(iteratorTypes[dim]); |
154 | |
155 | SmallVector<std::optional<int64_t>> packedDimPerIndexingMap( |
156 | indexingMaps.size(), std::nullopt); |
157 | SmallVector<AffineMap> newMaps; |
158 | for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e; |
159 | ++operandIdx) { |
160 | AffineMap map = indexingMaps[operandIdx]; |
161 | |
162 | // Add the `newDim` to map whatever the case. |
    assert(map.getNumDims() == newDim && "num dims invariant violation");
    map = map.shiftDims(1, newDim);
165 | |
166 | // Get the at-most-1 index of the result that is a function of `dim`. |
167 | // If we can find one, we insert `AffineDimExpr(newDim)` to the map, which |
168 | // logically chunks dimension `dim` into `K * dim + newDim`, where the |
169 | // packing factor `K` is specified separately. |
170 | assert(hasAtMostOneResultFunctionOfDim(map, dim) && |
171 | "num results invariant violation" ); |
172 | auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim); |
173 | if (!maybeOperandDimensionToPack.has_value()) { |
      newMaps.push_back(map);
175 | continue; |
176 | } |
177 | |
    // We can only pack an AffineDimExpr at the moment.
    if (!isa<AffineDimExpr>(map.getResult(maybeOperandDimensionToPack.value())))
180 | return failure(); |
181 | |
182 | // Add `newDim` to the results of the map. |
    map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim),
                           map.getNumResults());
    newMaps.push_back(map);
186 | |
    // Record that `operandIdx` is packed.
188 | packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack; |
189 | } |
190 | indexingMaps = newMaps; |
191 | |
192 | return packedDimPerIndexingMap; |
193 | } |
194 | |
195 | namespace { |
196 | |
197 | /// Helper struct to encode packing along one dimension of a LinalgOp. |
198 | struct PackedOperandsDim { |
199 | OpFoldResult packedSize; |
200 | SmallVector<std::optional<int64_t>> packedDimForEachOperand; |
201 | }; |
202 | |
203 | /// Helper struct to encode packing along all dimensions of a LinalgOp. |
204 | struct PackedOperandsDimList { |
205 | void pushBack(PackedOperandsDim &&packedOperandsDims) { |
    spec.emplace_back(packedOperandsDims);
207 | } |
208 | /// Return all the dims that have been packed for operand @ `operandPos`. |
209 | SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos); |
210 | /// Return all the pack sizes by which an operand @ `operandPos` is packed. |
211 | SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos); |
212 | |
213 | private: |
214 | SmallVector<PackedOperandsDim> spec; |
215 | }; |
216 | |
217 | } // namespace |
218 | |
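/// Rewrite a tensor.pack op into the sequence
///   tensor.pad -> tensor.expand_shape -> linalg.transpose
/// (or tensor.pad -> tensor.insert_slice when the pack acts as a plain pad).
/// Illustrative example (not from the original source): packing
/// tensor<128x256xf32> with inner_dims_pos = [0, 1] and inner_tiles = [8, 16]
/// pads (a no-op here since the tiles divide the shape), expands to
/// tensor<16x8x16x16xf32>, and transposes to the packed layout
/// tensor<16x16x8x16xf32>.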
219 | FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter, |
220 | tensor::PackOp packOp) { |
221 | // 1. Filter out NYI cases. |
222 | auto packedTensorType = |
223 | cast<RankedTensorType>(packOp->getResultTypes().front()); |
224 | if (llvm::any_of(packOp.getStaticInnerTiles(), |
225 | [](int64_t size) { return ShapedType::isDynamic(size); })) { |
226 | return rewriter.notifyMatchFailure( |
227 | packOp, |
228 | "non-static shape NYI, needs a more powerful tensor.expand_shape op" ); |
229 | } |
230 | |
231 | Location loc = packOp->getLoc(); |
232 | OpBuilder::InsertionGuard g(rewriter); |
233 | rewriter.setInsertionPoint(packOp); |
234 | |
235 | // 2. Compute the permutation vector to shuffle packed shape into the shape |
236 | // before any outer or inner permutations have been applied. |
237 | PackingMetadata packingMetadata = computePackingMetadata( |
238 | packedTensorType.getRank(), packOp.getInnerDimsPos()); |
239 | SmallVector<int64_t> packedToStripMinedShapePerm = |
240 | tensor::getPackInverseDestPerm(packOp); |
241 | |
242 | // 3. Compute the stripMinedShape: this is the packed shape before any outer |
243 | // or inner permutations have been applied. |
244 | SmallVector<int64_t> stripMinedShape(packedTensorType.getShape()); |
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
246 | |
247 | // 4. Pad the source of packOp to a shape we can expand into stripMinedShape. |
248 | SmallVector<OpFoldResult> lows(packOp.getSourceRank(), |
249 | rewriter.getIndexAttr(0)); |
250 | SmallVector<OpFoldResult> highs(packOp.getSourceRank(), |
251 | rewriter.getIndexAttr(0)); |
252 | for (auto [pos, innerSize] : |
253 | llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) { |
254 | int outerPos = |
255 | packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]]; |
256 | OpFoldResult origSize = |
257 | tensor::getMixedSize(rewriter, loc, packOp.getSource(), pos); |
258 | OpFoldResult outerSize = |
259 | tensor::getMixedSize(rewriter, loc, packOp.getDest(), outerPos); |
260 | AffineExpr s0, d0, d1; |
261 | bindDims(rewriter.getContext(), d0, d1); |
262 | bindSymbols(rewriter.getContext(), s0); |
263 | auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, d0 * s0 - d1); |
264 | highs[pos] = affine::makeComposedFoldedAffineApply( |
265 | rewriter, loc, map, {outerSize, origSize, innerSize}); |
266 | } |
267 | RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType( |
268 | RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape), |
269 | packingMetadata.reassociations); |
270 | Value paddingValue = packOp.getPaddingValue(); |
271 | if (!paddingValue) { |
272 | paddingValue = rewriter.create<arith::ConstantOp>( |
273 | loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed))); |
274 | } |
275 | auto padOp = |
276 | rewriter.create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows, |
277 | highs, paddingValue, /*nofold=*/false); |
278 | |
279 | LLVM_DEBUG( |
280 | DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions, |
281 | DBGS() << "insertPositions: " ); |
282 | DBGSNL(); llvm::interleaveComma(packingMetadata.outerPositions, |
283 | DBGS() << "outerPositions: " ); |
284 | DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(), |
285 | DBGS() << "packedShape: " ); |
286 | DBGSNL(); |
287 | llvm::interleaveComma(packedToStripMinedShapePerm, |
288 | DBGS() << "packedToStripMinedShapePerm: " ); |
289 | DBGSNL(); llvm::interleaveComma( |
290 | packingMetadata.reassociations, DBGS() << "reassociations: " , |
291 | [&](ReassociationIndices ri) { |
292 | llvm::interleaveComma(ri, llvm::dbgs() << "|" ); |
293 | }); |
294 | DBGSNL(); |
295 | llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: " ); |
296 | DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL();); |
297 | |
298 | if (packOp.isLikePad()) { |
299 | // Pack ops which operate as simple pads may not produce legal |
300 | // tensor.insert_slice operations when the packed type does not rank reduce |
301 | // to the padded type. |
302 | SliceVerificationResult rankReduces = |
303 | isRankReducedType(packedTensorType, padOp.getResultType()); |
304 | |
305 | if (rankReduces == SliceVerificationResult::Success) { |
306 | // This pack is just a plain pad. |
307 | // Just insert the pad in the higher ranked tensor. |
308 | auto emptyOp = |
309 | rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{}); |
310 | // Offsets. |
311 | SmallVector<OpFoldResult> zeros(packOp.getDestRank(), |
312 | rewriter.getIndexAttr(0)); |
313 | // Strides. |
314 | SmallVector<OpFoldResult> ones(packOp.getDestRank(), |
315 | rewriter.getIndexAttr(1)); |
316 | SmallVector<OpFoldResult> sizes = |
          tensor::getMixedSizes(rewriter, loc, packOp.getDest());
318 | |
319 | auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>( |
320 | loc, /*source=*/padOp, /*dest=*/emptyOp, |
321 | /*offsets=*/zeros, sizes, |
322 | /*strides=*/ones); |
323 | |
324 | LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL();); |
325 | |
326 | rewriter.replaceOp(packOp, insertSliceOp->getResults()); |
327 | |
328 | return LowerPackResult{padOp, /*reshapeOp=*/nullptr, |
329 | /*transposeOp=*/nullptr}; |
330 | } |
331 | } |
332 | // 5. Expand from the padded result to the stripMinedShape. |
333 | auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>( |
334 | loc, |
335 | RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape), |
336 | padOp.getResult(), packingMetadata.reassociations); |
337 | |
338 | // 6. Transpose stripMinedShape to packedShape. |
339 | SmallVector<int64_t> transpPerm = |
      invertPermutationVector(packedToStripMinedShapePerm);
341 | auto transposeOp = rewriter.create<linalg::TransposeOp>( |
342 | loc, reshapeOp.getResult(), packOp.getDest(), transpPerm); |
343 | |
344 | LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL(); |
345 | DBGS() << "reshape op: " << reshapeOp; DBGSNL(); |
346 | llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: " ); |
347 | DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL();); |
348 | |
349 | // 7. Replace packOp by transposeOp. |
350 | rewriter.replaceOp(packOp, transposeOp->getResults()); |
351 | |
352 | return LowerPackResult{padOp, reshapeOp, transposeOp}; |
353 | } |
354 | |
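/// Rewrite a tensor.unpack op into the sequence
///   linalg.transpose -> tensor.collapse_shape -> tensor.extract_slice
/// followed by a linalg.copy into the destination to preserve DPS semantics,
/// or into a single rank-reducing tensor.extract_slice when the unpack acts as
/// a plain unpad. This summarizes the steps implemented below.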
355 | FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter, |
356 | tensor::UnPackOp unPackOp) { |
357 | // 1. Filter out NYI cases. |
358 | if (!unPackOp.getOuterDimsPerm().empty() && |
359 | !isIdentityPermutation(unPackOp.getOuterDimsPerm())) { |
360 | return rewriter.notifyMatchFailure(unPackOp, |
361 | "non-identity outer dims perm NYI" ); |
362 | } |
363 | |
364 | Location loc = unPackOp->getLoc(); |
365 | OpBuilder::InsertionGuard g(rewriter); |
366 | rewriter.setInsertionPoint(unPackOp); |
367 | |
368 | RankedTensorType packedTensorType = unPackOp.getSourceType(); |
369 | int64_t packedRank = packedTensorType.getRank(); |
370 | |
371 | OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1); |
372 | auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType()); |
373 | if (unPackOp.isLikeUnPad()) { |
374 | // This unpack is just a plain unpad. |
375 | // Just extract the slice from the higher ranked tensor. |
376 | ArrayRef<int64_t> destShape = destTensorType.getShape(); |
377 | // The inner dimensions stay the same as the destination tensor, but the |
378 | // outer ones are additional 1s. |
379 | SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one); |
    sizes.append(tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()));
381 | |
    auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
383 | loc, destTensorType, unPackOp.getSource(), |
384 | SmallVector<OpFoldResult>(packedRank, zero), sizes, |
385 | SmallVector<OpFoldResult>(packedRank, one)); |
386 | |
387 | rewriter.replaceOp(unPackOp, extractSliceOp->getResults()); |
388 | |
389 | return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr, |
390 | /*reshapeOp=*/nullptr, extractSliceOp}; |
391 | } |
392 | // 2. Compute the permutation vector to move the last `numPackedDims` into |
393 | // the `innerPosDims` of a shape of rank `packedRank`. |
394 | int64_t numPackedDims = unPackOp.getInnerDimsPos().size(); |
  auto lastDims = llvm::to_vector(
      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
397 | PackingMetadata packingMetadata = |
398 | computePackingMetadata(packedRank, unPackOp.getInnerDimsPos()); |
399 | SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector( |
400 | packedRank, lastDims, packingMetadata.insertPositions); |
401 | |
402 | // 3. Compute the stripMinedShape: this is the packed shape without outer and |
403 | // inner permutations. |
404 | SmallVector<int64_t> stripMinedShape(packedTensorType.getShape()); |
  applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
406 | |
407 | // 4. Transpose packedShape to stripMinedShape. |
408 | RankedTensorType stripMinedTensorType = |
409 | RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape); |
410 | RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType( |
411 | stripMinedTensorType, packingMetadata.reassociations); |
412 | |
413 | // Get dynamic dims from input tensor based on lastDimsToInsertPositionsPerm |
414 | // permutation. |
415 | SmallVector<OpFoldResult, 4> dims = |
      tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
  applyPermutationToVector(dims, lastDimsToInsertPositionsPerm);
418 | auto emptyOp = rewriter.create<tensor::EmptyOp>( |
419 | loc, dims, stripMinedTensorType.getElementType()); |
420 | auto transposeOp = rewriter.create<linalg::TransposeOp>( |
421 | loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm); |
422 | |
423 | LLVM_DEBUG( |
424 | DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions, |
425 | DBGS() << "insertPositions: " ); |
426 | DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(), |
427 | DBGS() << "packedShape: " ); |
428 | DBGSNL(); |
429 | llvm::interleaveComma(lastDimsToInsertPositionsPerm, |
430 | DBGS() << "lastDimsToInsertPositionsPerm: " ); |
431 | DBGSNL(); llvm::interleaveComma( |
432 | packingMetadata.reassociations, DBGS() << "reassociations: " , |
433 | [&](ReassociationIndices ri) { |
434 | llvm::interleaveComma(ri, llvm::dbgs() << "|" ); |
435 | }); |
436 | DBGSNL(); |
437 | llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: " ); |
438 | DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL();); |
439 | |
440 | // 5. Collapse from the stripMinedShape to the padded result. |
441 | auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>( |
442 | loc, collapsedType, transposeOp->getResult(0), |
443 | packingMetadata.reassociations); |
444 | |
445 | // 6. ExtractSlice. |
446 | int64_t destRank = destTensorType.getRank(); |
  auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
448 | loc, destTensorType, reshapeOp->getResult(0), |
449 | SmallVector<OpFoldResult>(destRank, zero), |
      tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
451 | SmallVector<OpFoldResult>(destRank, one)); |
452 | |
453 | // 7. Inject a copy to preserve DPS. |
454 | auto copyOp = rewriter.create<linalg::CopyOp>( |
455 | loc, extractSliceOp->getResult(0), unPackOp.getDest()); |
456 | |
457 | // 8. Replace unPackOp by extractSliceOp. |
458 | rewriter.replaceOp(unPackOp, copyOp->getResults()); |
459 | |
460 | return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp}; |
461 | } |
462 | |
463 | SmallVector<int64_t> |
464 | PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) { |
465 | SmallVector<int64_t> res; |
466 | for (auto &i : spec) { |
467 | if (!i.packedDimForEachOperand[operandPos].has_value()) |
468 | continue; |
    res.push_back(i.packedDimForEachOperand[operandPos].value());
470 | } |
471 | return res; |
472 | } |
473 | |
474 | SmallVector<OpFoldResult> |
475 | PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) { |
476 | SmallVector<OpFoldResult> res; |
477 | for (auto &i : spec) { |
478 | if (!i.packedDimForEachOperand[operandPos].has_value()) |
479 | continue; |
    res.push_back(i.packedSize);
481 | } |
482 | return res; |
483 | } |
484 | |
/// Implement packing of a single LinalgOp by `packedSizes`.
/// There must be one packedSizes entry per `linalgOp` iterator.
/// Return the packed Linalg op on success, failure otherwise.
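///
/// Illustrative example (not from the original source): for a linalg.matmul
/// with loops (m, n, k) and packedSizes = [8, 16, 32], the LHS (indexed by
/// m, k) is packed by [8, 32], the RHS (k, n) by [32, 16] and the result
/// (m, n) by [8, 16]; this creates tensor.pack ops on the operands, a packed
/// linalg.generic with 6 loops, and tensor.unpack ops on the packed inits.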
488 | FailureOr<PackResult> linalg::pack(RewriterBase &rewriter, |
489 | linalg::LinalgOp linalgOp, |
490 | ArrayRef<OpFoldResult> packedSizes) { |
491 | if (packedSizes.size() != linalgOp.getNumLoops()) { |
492 | return rewriter.notifyMatchFailure(linalgOp, |
493 | "incorrect number of pack sizes" ); |
494 | } |
495 | |
496 | Location loc = linalgOp->getLoc(); |
497 | SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray(); |
498 | SmallVector<utils::IteratorType> iteratorTypes = |
499 | linalgOp.getIteratorTypesArray(); |
500 | LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n" ; |
501 | llvm::interleaveComma(indexingMaps, DBGS() << "maps: " ); DBGSNL(); |
502 | llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: " ); |
503 | DBGSNL();); |
504 | |
505 | SmallVector<tensor::PackOp> packOps; |
506 | SmallVector<tensor::UnPackOp> unPackOps; |
507 | // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i]. |
508 | PackedOperandsDimList listOfPackedOperandsDim; |
509 | for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) { |
    std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
511 | // Skip tile sizes explicitly set to 0. |
512 | if (maybeConstant.has_value() && maybeConstant.value() == 0) |
513 | continue; |
514 | |
515 | PackedOperandsDim packedOperandsDims; |
516 | packedOperandsDims.packedSize = packedSizes[i]; |
517 | FailureOr<SmallVector<std::optional<int64_t>>> |
518 | maybePackedDimForEachOperand = |
519 | packLinalgMetadataOnce(indexingMaps, iteratorTypes, i); |
    if (failed(maybePackedDimForEachOperand))
521 | return failure(); |
522 | packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand; |
    listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
524 | |
525 | LLVM_DEBUG( |
526 | DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i] |
527 | << "\n" ; |
528 | llvm::interleaveComma(indexingMaps, DBGS() << "maps: " ); DBGSNL(); |
529 | llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: " ); DBGSNL(); |
530 | llvm::interleaveComma(packedOperandsDims.packedDimForEachOperand, |
531 | DBGS() << "packedDimForEachOperand: " ); |
532 | DBGSNL();); |
533 | } |
534 | |
535 | // Step 2. Propagate packing to all LinalgOp operands. |
536 | SmallVector<Value> inputsAndInits, results; |
537 | SmallVector<OpOperand *> initOperands = llvm::to_vector(llvm::map_range( |
538 | linalgOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; })); |
539 | SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands(); |
540 | for (const auto &operandsList : {inputOperands, initOperands}) { |
541 | for (OpOperand *opOperand : operandsList) { |
542 | int64_t pos = opOperand->getOperandNumber(); |
543 | Value operand = opOperand->get(); |
544 | SmallVector<int64_t> innerPos = |
545 | listOfPackedOperandsDim.extractPackedDimsForOperand(pos); |
546 | SmallVector<OpFoldResult> innerPackSizes = |
547 | listOfPackedOperandsDim.extractPackSizesForOperand(pos); |
548 | LLVM_DEBUG( |
549 | DBGS() << "operand: " << operand << "\n" ; |
550 | llvm::interleaveComma(innerPos, DBGS() << "innerPos: " ); DBGSNL(); |
551 | llvm::interleaveComma(innerPackSizes, DBGS() << "innerPackSizes: " ); |
552 | DBGSNL();); |
553 | if (innerPackSizes.empty()) { |
554 | inputsAndInits.push_back(operand); |
555 | continue; |
556 | } |
557 | Value dest = tensor::PackOp::createDestinationTensor( |
558 | rewriter, loc, operand, innerPackSizes, innerPos, |
559 | /*outerDimsPerm=*/{}); |
560 | ShapedType operandType = cast<ShapedType>(operand.getType()); |
561 | bool areConstantTiles = |
562 | llvm::all_of(innerPackSizes, [](OpFoldResult tile) { |
563 | return getConstantIntValue(tile).has_value(); |
564 | }); |
565 | if (areConstantTiles && operandType.hasStaticShape() && |
566 | !tensor::PackOp::requirePaddingValue( |
567 | operandType.getShape(), innerPos, |
568 | cast<ShapedType>(dest.getType()).getShape(), {}, |
569 | innerPackSizes)) { |
570 | packOps.push_back(rewriter.create<tensor::PackOp>( |
571 | loc, operand, dest, innerPos, innerPackSizes)); |
572 | } else { |
573 | // TODO: value of the padding attribute should be determined by |
574 | // consumers. |
575 | auto zeroAttr = |
576 | rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType())); |
577 | Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr); |
578 | packOps.push_back(rewriter.create<tensor::PackOp>( |
579 | loc, operand, dest, innerPos, innerPackSizes, zero)); |
580 | } |
581 | inputsAndInits.push_back(packOps.back()); |
582 | } |
583 | } |
584 | |
585 | // Step 3. Build the packed op, use the type of `inits` as result types. |
586 | ValueRange inputs = |
      ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
588 | ValueRange inits = |
      ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
590 | auto packedLinalgOp = rewriter.create<linalg::GenericOp>( |
591 | linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps, |
592 | iteratorTypes); |
593 | packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0)); |
594 | |
595 | // Step 4. Propagate packing to all the op results. |
596 | for (OpResult result : packedLinalgOp->getResults()) { |
597 | int64_t resultNum = result.getResultNumber(); |
598 | tensor::PackOp maybePackedInit = |
599 | inits[resultNum].getDefiningOp<tensor::PackOp>(); |
600 | if (!maybePackedInit) { |
601 | results.push_back(result); |
602 | continue; |
603 | } |
604 | // Build the symmetrical UnPackOp to the existing PackOp. |
605 | unPackOps.push_back(rewriter.create<tensor::UnPackOp>( |
606 | packedLinalgOp->getLoc(), result, maybePackedInit.getSource(), |
607 | maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles())); |
608 | results.push_back(unPackOps.back()); |
609 | } |
610 | |
611 | // Step 5. Replace `linalgOp`. |
612 | rewriter.replaceOp(linalgOp, results); |
613 | |
614 | // Return packedLinalgOp. |
615 | return PackResult{packOps, |
616 | cast<linalg::LinalgOp>(packedLinalgOp.getOperation()), |
617 | unPackOps}; |
618 | } |
619 | |
620 | //===----------------------------------------------------------------------===// |
621 | // packTranspose transformation. |
622 | //===----------------------------------------------------------------------===// |
623 | |
624 | /// Return a copy of `tensorType` after permutation by `permutationVector`. |
// Note: This should be a new method on MemRef/RankedTensor/VectorType::Builder,
// but that would introduce a dependence on Dialect in IR.
627 | // TODO: Restructure. |
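// Illustrative example (assuming the applyPermutationToVector convention that
// result dimension i takes source dimension permutationVector[i]): permuting
// tensor<2x3x4xf32> by [1, 2, 0] yields tensor<3x4x2xf32>.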
628 | static RankedTensorType permuteShape(RankedTensorType tensorType, |
629 | ArrayRef<int64_t> permutationVector) { |
630 | SmallVector<int64_t> shape(tensorType.getShape()); |
  applyPermutationToVector(shape, permutationVector);
632 | return RankedTensorType::Builder(tensorType).setShape(shape); |
633 | } |
634 | |
635 | /// Return a new GenericOp obtained by transposing opOperand by the permutation |
636 | /// vector: |
637 | /// - the corresponding indexing map is transposed by `permutation` |
638 | /// - the corresponding operand value is replaced by `transposedValue` |
639 | /// `linalgOp` is replaced by the return op in the process. |
640 | /// Asserts that `transposedValue` is of the proper transposed ShapedType. |
641 | static LinalgOp transposeOneLinalgOperandAndReplace( |
642 | RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand, |
643 | ArrayRef<int64_t> permutation, Value transposedValue) { |
644 | // Sanity check the operand. |
645 | assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand" ); |
646 | |
647 | // Sanity check of the expected transposed tensor type. |
648 | auto tensorType = permuteShape( |
649 | cast<RankedTensorType>(opOperand.get().getType()), permutation); |
650 | (void)tensorType; |
  assert(tensorType == transposedValue.getType() &&
         "expected transposed value type to match the permuted operand type");
653 | |
654 | // Compute the transposed indexing map. |
655 | // Sigh unsigned pollution. |
  SmallVector<unsigned> tmpTransposition = llvm::to_vector(
      llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
  AffineMap permutationMap =
      AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
660 | AffineMap transposedMap = |
661 | permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand)); |
662 | |
663 | // Set the transposed indexing map in the proper position. |
664 | SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray(); |
665 | indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap; |
666 | // Set the transposedValue in the proper operand position. |
667 | SmallVector<Value> operands = linalgOp->getOperands(); |
668 | operands[opOperand.getOperandNumber()] = transposedValue; |
669 | |
670 | ValueRange operandsRef(operands); |
671 | auto transposedGenericOp = rewriter.create<linalg::GenericOp>( |
672 | /*location=*/linalgOp->getLoc(), |
673 | /*resultTensorTypes=*/ |
674 | operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(), |
675 | /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()), |
676 | /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()), |
677 | /*indexingMaps=*/indexingMaps, |
678 | /*iteratorTypes=*/linalgOp.getIteratorTypesArray()); |
679 | transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0)); |
680 | rewriter.replaceOp(linalgOp, transposedGenericOp->getResults()); |
681 | |
682 | return cast<linalg::LinalgOp>(transposedGenericOp.getOperation()); |
683 | } |
684 | |
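/// Transpose a tensor.pack op (and, when provided, the tensor.unpack op fed by
/// the same LinalgOp result) by `outerPerm` / `innerPerm`, transposing the
/// corresponding LinalgOp operand and its indexing map to match. This
/// summarizes the steps implemented below.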
685 | FailureOr<PackTransposeResult> |
686 | linalg::packTranspose(RewriterBase &rewriter, tensor::PackOp packOp, |
687 | linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp, |
688 | ArrayRef<int64_t> outerPerm, |
689 | ArrayRef<int64_t> innerPerm) { |
690 | Location loc = linalgOp.getLoc(); |
691 | |
692 | // Step 1. Transpose packOp. |
693 | rewriter.setInsertionPoint(packOp); |
694 | tensor::PackOp transposedPackOp = |
695 | packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm); |
696 | |
697 | if (!packOp.getResult().hasOneUse()) |
698 | return rewriter.notifyMatchFailure(linalgOp, "expect single pack use" ); |
699 | |
700 | OpOperand &packUse = *packOp->getUses().begin(); |
701 | if (packUse.getOwner() != linalgOp) { |
702 | return rewriter.notifyMatchFailure( |
703 | linalgOp, "not a single use by the LinalgOp target" ); |
704 | } |
705 | if (maybeUnPackOp && |
706 | (!linalgOp.isDpsInit(&packUse) || |
707 | maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) { |
708 | return rewriter.notifyMatchFailure(linalgOp, |
709 | "not produced by the LinalgOp target" ); |
710 | } |
711 | |
712 | // Step 2. Transpose linalgOp. |
713 | // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the |
714 | // identity. Don't rely on it. |
715 | int64_t numLeadingDims = packOp.getSourceRank(); |
716 | int64_t numTrailingDims = packOp.getInnerDimsPos().size(); |
717 | // Step 2.a. Compute the permutation on the whole operand. |
718 | // Leading part just reuse the outerPerm. |
719 | SmallVector<int64_t> permutation(outerPerm); |
720 | if (permutation.empty()) |
    llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
722 | // Trailing part needs to reindex positions by `numLeadingDims`. |
723 | if (innerPerm.empty()) { |
    llvm::append_range(
        permutation,
        llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
727 | } else { |
728 | llvm::append_range(permutation, |
729 | llvm::map_range(innerPerm, [&](int64_t pos) { |
730 | return numLeadingDims + pos; |
731 | })); |
732 | } |
  if (!isPermutationVector(permutation))
734 | return rewriter.notifyMatchFailure(linalgOp, "invalid permutation" ); |
735 | |
736 | // Step 2.b. Save the transposedPackUse operand number in case we need to |
737 | // get the tied OpResult after `linalgOp` has been replaced. |
738 | int64_t packUseOperandNumber = packUse.getOperandNumber(); |
739 | // Step 2.c. Actually perform the transposition. |
740 | rewriter.setInsertionPoint(linalgOp); |
741 | linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace( |
742 | rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult()); |
743 | |
744 | // Step 3. Maybe transpose unPackOp. |
745 | tensor::UnPackOp transposedUnPackOp; |
746 | if (maybeUnPackOp) { |
747 | OpOperand &opOperand = |
748 | transposedLinalgOp->getOpOperand(packUseOperandNumber); |
749 | OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand); |
750 | rewriter.setInsertionPoint(maybeUnPackOp); |
751 | transposedUnPackOp = maybeUnPackOp.createTransposedClone( |
752 | rewriter, loc, transposedResult, innerPerm, outerPerm); |
753 | |
754 | rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults()); |
755 | } |
756 | |
757 | // Step 4. Finally, replace packOp now that we don't need it anymore. |
758 | rewriter.replaceOp(packOp, transposedPackOp->getResults()); |
759 | |
760 | return PackTransposeResult{transposedPackOp, transposedLinalgOp, |
761 | transposedUnPackOp}; |
762 | } |
763 | |
764 | //===----------------------------------------------------------------------===// |
765 | // packMatmulGreedily transformation. |
766 | //===----------------------------------------------------------------------===// |
767 | |
768 | /// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m |
769 | /// and n are proper parallel dimensions and k is a proper reduction |
770 | /// dimension. Packing occurs by rewriting the op as a linalg.generic and |
771 | /// calling linalg::pack by `mnkPackedSizes`. The order of the packed |
772 | /// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2} |
/// to reorder {m, n, k} into one of the 6 possible orderings. The outer
774 | /// dimensions of the operands are not permuted at this time, this is left for |
775 | /// future work. |
776 | FailureOr<PackResult> |
777 | linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, |
778 | ArrayRef<OpFoldResult> mnkPackedSizes, |
779 | ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf, |
780 | ArrayRef<int64_t> mnkOrder) { |
781 | assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes" ); |
782 | assert((mnkPaddedSizesNextMultipleOf.empty() || |
783 | mnkPaddedSizesNextMultipleOf.size() == 3) && |
784 | "num of packing sizes next multiple should be empty or of size 3" ); |
785 | assert(mnkOrder.size() == 3 && "unexpected mnkOrder size" ); |
786 | assert(isPermutationVector(mnkOrder) && "expected a permutation" ); |
787 | |
788 | int64_t numLoops = linalgOp.getNumLoops(); |
789 | if (numLoops <= 2) { |
790 | LLVM_DEBUG(DBGS() << "need 3+ loops to find a matmul to pack, got " |
791 | << numLoops << "\nin: " << linalgOp << "\n" ); |
792 | return rewriter.notifyMatchFailure( |
793 | linalgOp, "need 3+ loops to find a matmul to pack" ); |
794 | } |
795 | |
796 | // Locally adjust the desired iterator position of mnk and packing sizes. |
797 | int64_t numPackedDims = mnkPackedSizes.size(); |
798 | SmallVector<int64_t> mmnnkkPos(numPackedDims); |
799 | for (int64_t i = 0, e = numPackedDims; i < e; ++i) |
800 | mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i]; |
801 | SmallVector<OpFoldResult> packedSizes(numPackedDims); |
802 | for (int64_t i = 0, e = numPackedDims; i < e; ++i) |
803 | packedSizes[mnkOrder[i]] = mnkPackedSizes[i]; |
804 | SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims); |
805 | for (int64_t i = 0, e = numPackedDims; i < e; ++i) { |
806 | paddedSizesNextMultipleOf[mnkOrder[i]] = |
807 | mnkPaddedSizesNextMultipleOf.empty() ? 0 |
808 | : mnkPaddedSizesNextMultipleOf[i]; |
809 | } |
810 | |
811 | // 1. Infer dims that are important for matmul. |
812 | FailureOr<ContractionDimensions> maybeDimensions = |
813 | inferContractionDims(linalgOp); |
  if (failed(maybeDimensions)) {
815 | LLVM_DEBUG(DBGS() << "couldn't infer matmul iterators in: " << linalgOp |
816 | << "\n" ); |
817 | return rewriter.notifyMatchFailure(linalgOp, |
818 | "couldn't infer matmul iterators" ); |
819 | } |
820 | |
  // 2. Normalize linalgOp to a kmn-matmul-like form with [red, par, par] as
  // the most minor iterators. In cases with multiple options for m, n, k, bias
  // towards the most minor embedding.
  // If we wanted a different normalization order, this is where a heuristic
  // would have to be plugged in.
826 | int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(), |
827 | kPos = maybeDimensions->k.back(); |
828 | LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL(); |
829 | DBGS() << "Start packing generic op greedily with (m@" << mPos |
830 | << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp |
831 | << "\n" ;); |
832 | |
833 | // 2.a. Rewrite as a generic. |
834 | auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation()); |
835 | if (!genericOp) { |
836 | FailureOr<GenericOp> generalizeResult = |
837 | generalizeNamedOp(rewriter, linalgOp); |
838 | assert(succeeded(generalizeResult) && "unexpected failure generalizing op" ); |
839 | genericOp = *generalizeResult; |
840 | } |
841 | |
842 | // 2.b. Interchange to move the dimensions (k, m, n) as most-minor |
  // iterators. Note that this only normalizes the iteration order and does
844 | // not change the indexings of any operand. |
845 | SmallVector<int64_t> permutation = |
      computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
847 | LLVM_DEBUG(llvm::interleaveComma(permutation, DBGS() << "perm: " ); DBGSNL();); |
848 | // Sign .. unsigned pollution. |
849 | SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end()); |
850 | FailureOr<GenericOp> interchangeResult = |
851 | interchangeGenericOp(rewriter, genericOp, unsignedPerm); |
852 | assert(succeeded(interchangeResult) && "unexpected failure interchanging op" ); |
853 | genericOp = *interchangeResult; |
854 | LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n" ;); |
855 | |
856 | // At this point, the op iterators are normalized to {leading, k, m, n}. |
857 | // The layouts induced by packing will always be: |
858 | // - LHS{leading_lhs, kk, mm} |
859 | // - RHS{leading_rhs, kk, nn} |
860 | // - RES{leading_res, mm, nn} |
861 | // If we wanted to change the packed order, we would reorder (k, m, n) to |
862 | // something else above. |
863 | // |
864 | // Additional permutations of the outer dims of the operands (i.e. |
865 | // leading_lhs, leading_rhs and leading_res) could follow by computing the |
866 | // desired outerPerm for each operand. |
867 | // This is left for future work. |
868 | |
869 | // TODO: this creates too much IR, go use reifyResultShapes. |
870 | SmallVector<Range, 4> loopRanges = |
871 | cast<LinalgOp>(genericOp.getOperation()) |
872 | .createLoopRanges(rewriter, genericOp.getLoc()); |
873 | |
  // Add leading zeros to match numLoops; we only pack the last 3 dimensions
  // post interchange.
876 | LLVM_DEBUG(llvm::interleaveComma(paddedSizesNextMultipleOf, |
877 | DBGS() << "paddedSizesNextMultipleOf: " ); |
878 | DBGSNL();); |
879 | LLVM_DEBUG(llvm::interleaveComma(loopRanges, DBGS() << "loopRanges: " , |
880 | [](Range r) { llvm::dbgs() << r.size; }); |
881 | DBGSNL();); |
882 | SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(), |
883 | rewriter.getIndexAttr(0)); |
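  // When a next-multiple-of value is requested, the packed size becomes the
  // loop range rounded up to that multiple, e.g. (illustrative) a loop range
  // of 100 with paddedSizesNextMultipleOf = 64 gives ceil(100 / 64) * 64 = 128.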
884 | for (int64_t i = 0, e = numPackedDims; i < e; ++i) { |
885 | if (paddedSizesNextMultipleOf[i] == 0) { |
      adjustedPackedSizes.push_back(packedSizes[i]);
887 | continue; |
888 | } |
889 | AffineExpr d0, s0; |
    bindDims(rewriter.getContext(), d0);
    bindSymbols(rewriter.getContext(), s0);
    adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
        rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
894 | {loopRanges[adjustedPackedSizes.size()].size, |
895 | rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])})); |
896 | } |
897 | LLVM_DEBUG(llvm::interleaveComma(adjustedPackedSizes, |
898 | DBGS() << "adjustedPackedSizes: " ); |
899 | DBGSNL();); |
900 | |
901 | // TODO: If we wanted to give the genericOp a name after packing, after |
902 | // calling `pack` would be a good time. One would still need to check that |
903 | // `containsMostMinorMatmul(packingRes->packedLinalgOp)` is true, since we |
904 | // also allow degenerate matmul cases (i.e. matvec, dot). |
905 | return pack(rewriter, genericOp, adjustedPackedSizes); |
906 | } |
907 | |
908 | //===----------------------------------------------------------------------===// |
909 | // Transformations exposed as rewrite patterns. |
910 | //===----------------------------------------------------------------------===// |
911 | |
912 | LinalgTilingOptions & |
913 | mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { |
914 | assert(!tileSizeComputationFunction && "tile sizes already set" ); |
915 | SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end()); |
916 | tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { |
917 | OpBuilder::InsertionGuard guard(b); |
918 | b.setInsertionPointToStart( |
919 | &op->getParentOfType<func::FuncOp>().getBody().front()); |
920 | return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { |
921 | Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s); |
922 | return v; |
923 | })); |
924 | }; |
925 | return *this; |
926 | } |
927 | |
928 | LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite( |
929 | memref::CopyOp copyOp, PatternRewriter &rewriter) const { |
930 | return vectorizeCopy(rewriter, copyOp); |
931 | } |
932 | |
/// Fill `dest` using a FillOp with the constant padding value if possible.
/// Otherwise, generate a tensor::GenerateOp.
935 | Value GeneralizePadOpPattern::createFillOrGenerateOp( |
936 | RewriterBase &rewriter, tensor::PadOp padOp, Value dest, |
937 | const SmallVector<Value> &dynSizes) const { |
938 | auto padValue = padOp.getConstantPaddingValue(); |
939 | if (padValue) |
940 | return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result(); |
941 | |
942 | // Fill could not be optimized: Lower to tensor::GenerateOp with region. |
943 | auto generateOp = rewriter.create<tensor::GenerateOp>( |
944 | padOp.getLoc(), padOp.getResultType(), dynSizes); |
945 | // Copy region to new op. |
946 | IRMapping bvm; |
947 | padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm); |
948 | return generateOp; |
949 | } |
950 | |
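/// Lower tensor.pad into tensor.empty + linalg.fill (or tensor.generate when
/// the padding value is not a constant) followed by a tensor.insert_slice of
/// the source into the filled tensor. This summarizes the rewrite implemented
/// below.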
951 | LogicalResult |
952 | GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp, |
953 | PatternRewriter &rewriter) const { |
954 | // Given an OpFoldResult, return an index-typed value. |
955 | auto getIdxValue = [&](OpFoldResult ofr) { |
    if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
957 | return val; |
958 | return rewriter |
959 | .create<arith::ConstantIndexOp>( |
960 | padOp.getLoc(), cast<IntegerAttr>(ofr.get<Attribute>()).getInt()) |
961 | .getResult(); |
962 | }; |
963 | |
964 | auto resultType = padOp.getResultType(); |
965 | // Compute size of EmptyOp. Any combination of static/dynamic is supported. |
966 | SmallVector<Value> dynSizes; |
967 | SmallVector<int64_t> staticSizes; |
968 | for (unsigned dim = 0; dim < resultType.getRank(); ++dim) { |
969 | if (resultType.isDynamicDim(dim)) { |
      auto srcSize = getIdxValue(tensor::getMixedSize(rewriter, padOp.getLoc(),
                                                      padOp.getSource(), dim));
972 | // Add low and high padding value. |
973 | auto plusLow = rewriter.createOrFold<arith::AddIOp>( |
974 | padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim])); |
975 | auto plusHigh = rewriter.createOrFold<arith::AddIOp>( |
976 | padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim])); |
      dynSizes.push_back(plusHigh);
978 | } |
    staticSizes.push_back(resultType.getDimSize(dim));
980 | } |
981 | |
982 | // Init tensor and fill it with padding. |
983 | Value emptyTensor = rewriter.create<tensor::EmptyOp>( |
984 | padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes); |
985 | Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes); |
986 | |
  // Try to optimize the copy of the source.
988 | if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded()) |
989 | return success(); |
990 | |
  // The copy of the PadOp source could not be optimized. Generate an
  // InsertSliceOp instead to copy the PadOp source.
993 | auto sourceType = padOp.getSourceType(); |
994 | // Compute size of source of tensor::PadOp. |
995 | SmallVector<OpFoldResult> srcSizes = |
      tensor::getMixedSizes(rewriter, padOp.getLoc(), padOp.getSource());
997 | // Strides of InsertSliceOp are all 1. |
998 | SmallVector<OpFoldResult> strides(sourceType.getRank(), |
999 | rewriter.getIndexAttr(1)); |
1000 | rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>( |
1001 | padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes, |
1002 | strides); |
1003 | |
1004 | return success(); |
1005 | } |
1006 | |
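/// Swap a tensor.extract_slice of a tensor.pad with the pad itself, i.e.
/// rewrite extract_slice(pad(x)) into pad(extract_slice(x)), delegating the
/// actual computation to tensor::bubbleUpPadSlice.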
1007 | LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite( |
1008 | tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const { |
1009 | if (!sliceOp.hasUnitStride()) |
1010 | return failure(); |
1011 | |
1012 | auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>(); |
1013 | if (!padOp) |
1014 | return failure(); |
1015 | |
1016 | bool zeroSliceGuard = true; |
1017 | if (controlFn) { |
1018 | if (std::optional<bool> control = controlFn(sliceOp)) |
1019 | zeroSliceGuard = *control; |
1020 | else |
1021 | return failure(); |
1022 | } |
1023 | |
1024 | FailureOr<TilingResult> tilingResult = |
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  if (failed(tilingResult))
1028 | return failure(); |
1029 | // All shapes are static and the data source is actually used. Rewrite into |
1030 | // pad(extract_slice(x)). |
1031 | rewriter.replaceOp(sliceOp, tilingResult->tiledValues); |
1032 | return success(); |
1033 | } |
1034 | |
1035 | /// Returns a tensor.pad op if padding value is set. Otherwise, returns the |
1036 | /// source directly. The method assumes that the `packOp` has static shapes. |
1037 | static Value getPackOpSourceOrPaddedSource(OpBuilder &builder, |
1038 | tensor::PackOp packOp) { |
1039 | Value input = packOp.getSource(); |
1040 | if (!packOp.getPaddingValue()) { |
1041 | return input; |
1042 | } |
1043 | |
1044 | Location loc = packOp.getLoc(); |
1045 | ShapedType inputType = packOp.getSourceType(); |
1046 | int64_t inputRank = inputType.getRank(); |
1047 | assert(llvm::all_of(packOp.getDestType().getShape().take_front(inputRank), |
1048 | [](int64_t val) { return val == 1; })); |
1049 | |
1050 | SmallVector<int64_t> paddedShape; |
1051 | DenseMap<int64_t, OpFoldResult> tileAndPosMapping = |
1052 | packOp.getDimAndTileMapping(); |
1053 | for (int64_t dim = 0; dim < inputRank; ++dim) { |
1054 | int64_t size = inputType.getDimSize(dim); |
    if (!tileAndPosMapping.count(dim)) {
      paddedShape.push_back(size);
1057 | continue; |
1058 | } |
1059 | |
1060 | // The size is less than or equal to tileSize because outer dims are all 1s. |
    std::optional<int64_t> tileSize =
        getConstantIntValue(tileAndPosMapping.lookup(dim));
    assert(tileSize.has_value() && "dynamic inner tile size is not supported");
    paddedShape.push_back(tileSize.value());
1065 | } |
1066 | auto resultType = |
1067 | RankedTensorType::get(paddedShape, inputType.getElementType()); |
  return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
                                 /*nofold=*/false, loc, builder);
1070 | } |
1071 | |
1072 | // Normalizes a permutation on a higher rank space to its actual size, e.g. |
1073 | // perm = [1, 4, 2] |
1074 | // becomes |
1075 | // norm = [0, 2, 1] |
1076 | static SmallVector<int64_t> |
1077 | getPackUnpackNormalizedPerm(int rank, ArrayRef<int64_t> perm) { |
1078 | constexpr int64_t kNonTiledMarker = -1; |
1079 | SmallVector<int64_t> vec(rank, kNonTiledMarker); |
  for (auto [index, value] : llvm::enumerate(perm))
    vec[value] = index;
  SmallVector<int64_t> normalizedPerm = llvm::to_vector(llvm::make_filter_range(
      vec, [&](int64_t v) { return v != kNonTiledMarker; }));
  // This inverts the permutation in addition to normalizing so invert back.
  return invertPermutationVector(normalizedPerm);
1086 | } |
1087 | |
1088 | // Gets the normalized permutation implied by innerDimsPos and outerDimsPerm |
1089 | // assuming rank reduction of unit outer dims. |
1090 | static SmallVector<int64_t> |
1091 | getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape, |
1092 | ArrayRef<int64_t> innerDimsPos, |
1093 | ArrayRef<int64_t> outerDimsPerm) { |
1094 | SmallVector<int64_t> rankReducedOuterDimsPerm; |
1095 | SmallVector<int64_t> outerDims; |
1096 | SmallVector<int64_t> innerDims; |
1097 | int64_t dim = 0; |
1098 | int64_t unpackedRank = shape.size(); |
  for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
    if (llvm::is_contained(innerDimsPos, i)) {
      innerDims.push_back(dim++);
      continue;
    }
    if (shape[i] == 1)
      continue;
    outerDims.push_back(dim++);
    if (!outerDimsPerm.empty())
      rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
1109 | } |
1110 | |
1111 | // Get the position of the inner dims after permutation. |
1112 | SmallVector<int64_t> innerPerm = |
      getPackUnpackNormalizedPerm(unpackedRank, innerDimsPos);
  applyPermutationToVector<int64_t>(innerDims, innerPerm);
1115 | |
1116 | // Ditto for the outer dims. |
1117 | SmallVector<int64_t> perm = outerDims; |
1118 | |
  rankReducedOuterDimsPerm =
      getPackUnpackNormalizedPerm(unpackedRank, rankReducedOuterDimsPerm);
  if (!rankReducedOuterDimsPerm.empty())
    applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);
1123 | |
1124 | // The tile always ends up as the inner most dims after packing. |
  perm.append(innerDims);
1126 | |
1127 | return perm; |
1128 | } |
1129 | |
1130 | LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite( |
1131 | tensor::PackOp packOp, PatternRewriter &rewriter) const { |
1132 | if (llvm::any_of(packOp.getMixedTiles(), |
1133 | [](OpFoldResult tile) { return tile.is<Value>(); })) { |
1134 | return rewriter.notifyMatchFailure(packOp, |
1135 | "require inner tile sizes being static" ); |
1136 | } |
1137 | |
1138 | // TODO: support the case that outer dimensions are not all 1s. A |
1139 | // tensor.expand_shape will be generated in this case. |
1140 | auto innerDimsPos = packOp.getInnerDimsPos(); |
1141 | int64_t srcRank = packOp.getSourceRank(); |
1142 | auto destShape = packOp.getDestType().getShape(); |
1143 | if (llvm::any_of(innerDimsPos, [destShape](int64_t index) { |
1144 | return destShape[index] != 1; |
1145 | })) { |
1146 | return rewriter.notifyMatchFailure( |
1147 | packOp, "require the tiled outer dimensions of the result are all 1s" ); |
1148 | } |
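  // Illustrative example (not from the original source): a tensor.pack of
  // tensor<8x16xf32> into tensor<1x1x8x16xf32> with inner_dims_pos = [0, 1]
  // and inner_tiles = [8, 16] becomes a rank-reduced tensor.extract_slice of
  // the (possibly padded) source, a linalg.transpose of the tile, and a
  // tensor.insert_slice into the 1x1x8x16 destination.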
1149 | |
1150 | // 1. Use rank-reduced tensor.extract_slice op to extract the tile and untiled |
1151 | // outer dims. |
1152 | Location loc = packOp.getLoc(); |
1153 | Value input = getPackOpSourceOrPaddedSource(rewriter, packOp); |
1154 | auto inputShape = packOp.getSourceType().getShape(); |
1155 | DenseMap<int64_t, OpFoldResult> dimAndTileMapping = |
1156 | packOp.getDimAndTileMapping(); |
1157 | Attribute zeroIdxAttr = rewriter.getIndexAttr(0); |
1158 | Attribute oneIdxAttr = rewriter.getIndexAttr(1); |
1159 | SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr); |
1160 | SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr); |
1161 | SmallVector<OpFoldResult> readSizes; |
1162 | SmallVector<int64_t> readShape; |
  for (auto i : llvm::seq<unsigned>(0, srcRank)) {
    if (dimAndTileMapping.count(i)) {
      readShape.push_back(getConstantIntValue(dimAndTileMapping[i])
                              .value_or(ShapedType::kDynamic));
      readSizes.push_back(dimAndTileMapping[i]);
      continue;
    }
    if (ShapedType::isDynamic(inputShape[i])) {
      readSizes.push_back(
          rewriter.create<tensor::DimOp>(loc, input, i).getResult());
    } else {
      readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
    }
    if (inputShape[i] != 1)
      readShape.push_back(inputShape[i]);
  }

  Type elemType = packOp.getSourceType().getElementType();
  auto readType = RankedTensorType::get(readShape, elemType);

  Value tile = rewriter.create<tensor::ExtractSliceOp>(
      loc, readType, input, readOffsets, readSizes, readStrides);

  // 2. Transpose the tile to match the inner tile order.

  SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
      inputShape, innerDimsPos, packOp.getOuterDimsPerm());

  LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
             llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););

  SmallVector<int64_t> transpShape = readShape;
  applyPermutationToVector<int64_t>(transpShape, perm);

  Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
  auto transposedOp =
      rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);

  // 3. Insert the inner tile into the destination.
  int64_t destRank = packOp.getDestRank();
  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
  SmallVector<OpFoldResult> writeSizes =
      tensor::getMixedSizes(rewriter, loc, packOp.getDest());

  auto insert = rewriter.create<tensor::InsertSliceOp>(
      loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
      writeSizes, writeStrides);
  rewriter.replaceOp(packOp, insert.getResult());

  return success();
}

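/// Rewrites tensor.unpack ops whose tiled outer dims are all unit dims. A
/// rough, illustrative sketch (hypothetical shapes):
///
///   %u = tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [8, 16]
///       into %dest : tensor<1x1x8x16xf32> -> tensor<8x16xf32>
///
/// becomes, roughly, a rank-reducing extract_slice of the inner tile, a
/// transpose with the inverted pack permutation, an extract_slice that drops
/// the padded tail of incomplete tiles, and an insert_slice into the
/// destination:
///
///   %tile = tensor.extract_slice %src[0, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1]
///       : tensor<1x1x8x16xf32> to tensor<8x16xf32>
///   %init = tensor.empty() : tensor<8x16xf32>
///   %t = linalg.transpose ins(%tile : tensor<8x16xf32>)
///       outs(%init : tensor<8x16xf32>) permutation = [0, 1]
///   %partial = tensor.extract_slice %t[0, 0] [8, 16] [1, 1]
///       : tensor<8x16xf32> to tensor<8x16xf32>
///   %u = tensor.insert_slice %partial into %dest[0, 0] [8, 16] [1, 1]
///       : tensor<8x16xf32> into tensor<8x16xf32>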
LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
    tensor::UnPackOp unpackOp, PatternRewriter &rewriter) const {
  int64_t srcRank = unpackOp.getSourceRank();
  int64_t destRank = unpackOp.getDestRank();
  ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
  if (llvm::any_of(innerDimsPos, [srcShape](int64_t index) {
        return srcShape[index] != 1;
      })) {
    return rewriter.notifyMatchFailure(
        unpackOp,
        "require the tiled outer dimensions of the result are all 1s");
  }

  // 1. Use rank-reduced tensor.extract_slice op to extract the tile.
  Location loc = unpackOp.getLoc();
  Value source = unpackOp.getSource();
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
  SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
  SmallVector<OpFoldResult> readSizes;
  SmallVector<int64_t> readShape;
  SmallVector<Value> dynamicDims;
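  // The read covers the outer dims of the packed source: tiled outer dims
  // (all 1s here) read a single element, untiled dims read their full extent.
  // The inner tile sizes are appended to the read sizes below.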
  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    if (dimAndTileMapping.count(i)) {
      readSizes.push_back(oneIdxAttr);
      continue;
    }

    if (ShapedType::isDynamic(srcShape[i])) {
      Value dynamicDim =
          rewriter.create<tensor::DimOp>(loc, source, i).getResult();
      readSizes.push_back(dynamicDim);
      dynamicDims.push_back(dynamicDim);
    } else {
      readSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
    }
    if (srcShape[i] != 1)
      readShape.push_back(srcShape[i]);
  }
  auto mixedTiles = unpackOp.getMixedTiles();
  readSizes.append(mixedTiles.begin(), mixedTiles.end());

  // Explicitly create the type for extract_slice op because the inner tile
  // size could be 1. We want to represent the whole inner tile in this case.
  auto tileShape = srcShape.drop_front(destRank);
  // Append the inner tile shape to the permuted and rank-reduced outer shape.
  readShape.append(tileShape.begin(), tileShape.end());
  Type elemType = unpackOp.getSourceType().getElementType();
  auto readType = RankedTensorType::get(readShape, elemType);
  Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
      loc, readType, unpackOp.getSource(), readOffsets, readSizes, readStrides);

  // 2. Transpose the tile to match the order of the corresponding outer dims.
  SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
      srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
  // Unpack is a transition out of packed space, so we invert the permutation.
  perm = invertPermutationVector(perm);
  SmallVector<int64_t> transpShape(readShape);
  applyPermutationToVector<int64_t>(transpShape, perm);

  Value empty =
      rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType, dynamicDims);
  auto transposedOp =
      rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);

  // 3. Handle incomplete tiles if needed. This truncates trailing data from
  // the transposed tile.
  int numLoops = transpShape.size();
  SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
  SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
  SmallVector<OpFoldResult> tileSizes;
  ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      tileSizes.push_back(
          tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i));
  }

  auto partialTile = rewriter.create<tensor::ExtractSliceOp>(
      loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides);

  // 4. Insert the result into the destination tensor.
  SmallVector<OpFoldResult> writeSizes;
  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
  for (int i = 0, idx = 0; i < destRank; ++i) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      writeSizes.push_back(tileSizes[idx++]);
    else
      writeSizes.push_back(oneIdxAttr);
  }
  auto insert = rewriter.create<tensor::InsertSliceOp>(
      loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes,
      writeStrides);
  rewriter.replaceOp(unpackOp, insert.getResult());

  return success();
}

// The following are patterns for downscaling convolution ops with size-1
// window dimensions.
//
// Note that we'd eventually want to write such transformations in a generic
// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.
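//
// For instance (illustrative, hypothetical shapes), a linalg.conv_2d_nhwc_hwcf
// whose kernel and output both have extent 1 along H:
//
//   linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
//                             strides = dense<1> : tensor<2xi64>}
//     ins(%input, %filter : tensor<?x1x?x?xf32>, tensor<1x?x?x?xf32>)
//     outs(%out : tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
//
// is rewritten into rank-reducing tensor.extract_slice ops on the operands, a
// linalg.conv_1d_nwc_wcf on the rank-reduced values, and a tensor.insert_slice
// of the 1-D result back into the original output tensor.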

template <typename Conv2DOp, typename Conv1DOp>
FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
    returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
  if (convOp.hasPureBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Get domain indices based on conv2D layout.
  auto [khIndex, kwIndex, ohIndex, owIndex] =
      TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>(
          convOp)
          .Case([&](linalg::Conv2DNhwcHwcfOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::Conv2DNchwFchwOp op) {
            return std::make_tuple(2, 3, 2, 3);
          })
          .Case([&](linalg::PoolingNhwcSumOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNchwSumOp op) {
            return std::make_tuple(0, 1, 2, 3);
          })
          .Case([&](linalg::PoolingNhwcMaxOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMinOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNchwMaxOp op) {
            return std::make_tuple(0, 1, 2, 3);
          })
          .Default([&](Operation *op) {
            llvm_unreachable("unexpected conv2d/pool2d operation.");
            return std::make_tuple(0, 0, 0, 0);
          });

  // Only handle the case where at least one of the window dimensions is
  // of size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
  int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too.
  // TODO: dropDim 1-liner helper.
  auto strides =
      llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations =
      llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = rewriter.create<Conv1DOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  return conv1DOp;
}

template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
                                                              Conv1DNwcWcfOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
                                                              Conv1DNcwFcwOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp,
                                                              PoolingNwcSumOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp,
                                                              PoolingNcwSumOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp,
                                                              PoolingNwcMaxOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<
    PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp,
                                                              PoolingNwcMinOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<
    PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
                                                              PoolingNcwMaxOp>;

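/// Depthwise variant: the kernel is laid out as HWC and the output as NHWC, so
/// the window dims are kernel dims {0, 1} and output dims {1, 2}. The rewrite
/// mirrors the generic pattern above, producing a
/// linalg.depthwise_conv_1d_nwc_wc on rank-reduced operands.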
FailureOr<DepthwiseConv1DNwcWcOp>
DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
    DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
  if (convOp.hasPureBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Only handle the case where at least one of the window dimensions is
  // of size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[1], owSize = outputShape[2];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too.
  // TODO: dropDim 1-liner helper.
  auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations =
      llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  return conv1DOp;
}

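/// Plain linalg.conv_2d variant: the op carries no strides or dilations
/// attributes, so the rewrite only rank-reduces the operands (window dims are
/// kernel dims {0, 1} and output dims {0, 1}) and emits a linalg.conv_1d.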
FailureOr<Conv1DOp>
DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
                                            PatternRewriter &rewriter) const {
  if (convOp.hasPureBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Only handle the case where at least one of the window dimensions is
  // of size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[0], owSize = outputShape[1];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 0 : 1);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  auto conv1DOp = rewriter.create<Conv1DOp>(loc, newOutputType,
                                            ValueRange{newInput, newKernel},
                                            ValueRange{newOutput});

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  return conv1DOp;
}

void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                                  PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
                                                     Conv1DNwcWcfOp>,
               DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
                                                     Conv1DNcwFcwOp>,
               DownscaleDepthwiseConv2DNhwcHwcOp, DownscaleConv2DOp>(
      patterns.getContext(), benefit);
  patterns.add<
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, PoolingNwcMaxOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp,
                                            PoolingNwcMaxUnsignedOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, PoolingNwcMinOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp,
                                            PoolingNwcMinUnsignedOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
      patterns.getContext(), benefit);
}
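
// A minimal usage sketch (hypothetical caller, not part of this file): the
// patterns registered above are intended to be driven greedily, e.g. from a
// pass body:
//
//   RewritePatternSet patterns(&getContext());
//   linalg::populateDecomposeConvolutionPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                           std::move(patterns))))
//     signalPassFailure();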