1 | //===- Transforms.cpp - Linalg transformations as patterns ----------------===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements logic and helpers to expose Linalg transforms as rewrite |
10 | // patterns. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
15 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
17 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
18 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
19 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
20 | #include "mlir/Dialect/SCF/Transforms/Transforms.h" |
21 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
22 | #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h" |
23 | #include "mlir/Dialect/Tensor/Utils/Utils.h" |
24 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
25 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
26 | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
27 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
28 | #include "mlir/IR/AffineExpr.h" |
29 | #include "mlir/IR/Matchers.h" |
30 | #include "mlir/Pass/Pass.h" |
31 | #include "mlir/Support/LLVM.h" |
32 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
33 | #include "llvm/ADT/ScopeExit.h" |
34 | #include "llvm/ADT/TypeSwitch.h" |
35 | #include "llvm/Support/Debug.h" |
36 | #include "llvm/Support/InterleavedRange.h" |
37 | #include "llvm/Support/raw_ostream.h" |
38 | #include <type_traits> |
39 | #include <utility> |
40 | |
41 | #define DEBUG_TYPE "linalg-transforms" |
42 | |
43 | using namespace mlir; |
44 | using namespace mlir::linalg; |
45 | |
46 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") |
47 | #define DBGSNL() (llvm::dbgs() << "\n") |
48 | |
49 | //===----------------------------------------------------------------------===// |
50 | // Transformations exposed as functional-style API calls. |
51 | //===----------------------------------------------------------------------===// |
52 | |
53 | //===----------------------------------------------------------------------===// |
54 | // peelLoop transformation. |
55 | //===----------------------------------------------------------------------===// |
56 | |
57 | /// Try to peel and canonicalize loop `op` and return the new result. |
58 | /// Also applies affine_min/max bounds simplification on the fly where relevant. |
59 | // TODO: Add support for scf.parallel and affine.for loops. |
60 | SmallVector<Value> mlir::linalg::peelLoop(RewriterBase &rewriter, |
61 | Operation *op) { |
62 | return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op) |
63 | .Case<scf::ForOp>([&](scf::ForOp forOp) {
64 | scf::ForOp partialIteration; |
65 | if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp, |
66 | partialIteration))) |
67 | return partialIteration->getResults(); |
68 | assert(!partialIteration && "expected that loop was not peeled"); |
69 | return forOp->getResults(); |
70 | }) |
71 | .Default([&](Operation *op) { return op->getResults(); });
72 | } |
73 | |
74 | /// Peel 'loops' and apply affine_min/max bounds simplification on the fly
75 | /// where relevant. |
76 | void mlir::linalg::peelLoops(RewriterBase &rewriter, |
77 | ArrayRef<scf::ForOp> loops) { |
78 | for (auto loopOp : loops) |
79 | peelLoop(rewriter, loopOp); |
80 | } |
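// Illustrative usage (a sketch, not taken from a test in this file): callers
// typically collect the scf.for loops produced by tiling and peel them all:
//
//   SmallVector<scf::ForOp> tiledLoops = ...; // e.g. loops produced by tiling
//   linalg::peelLoops(rewriter, tiledLoops);
//
// Each scf.for is split into a main loop covering only full steps and a
// partial iteration for the remainder (when the step does not evenly divide
// the trip count), with affine_min/max bounds simplified along the way.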
81 | |
82 | //===----------------------------------------------------------------------===// |
83 | // pack transformation. |
84 | //===----------------------------------------------------------------------===// |
85 | |
86 | #ifndef NDEBUG |
87 | /// Return true if at most one result of `map` is a function of AffineDimExpr(dim).
88 | static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) { |
89 | bool found = false; |
90 | for (AffineExpr e : map.getResults()) { |
91 | if (!e.isFunctionOfDim(dim))
92 | continue; |
93 | if (found) |
94 | return false; |
95 | found = true; |
96 | } |
97 | return true; |
98 | } |
99 | |
100 | static std::string stringifyReassocIndices(ReassociationIndicesRef ri) { |
101 | return llvm::interleaved(ri, ", ", /*Prefix=*/"|", /*Suffix=*/"");
102 | } |
103 | #endif // NDEBUG |
104 | |
105 | /// Return the index of the first result of `map` that is a function of |
106 | /// AffineDimExpr(dim), std::nullopt otherwise. |
107 | static std::optional<int64_t> getFirstResultIndexFunctionOf(AffineMap map, |
108 | int64_t dim) { |
109 | for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { |
110 | AffineExpr expr = map.getResult(i);
111 | if (!expr.isFunctionOfDim(dim))
112 | continue; |
113 | return i; |
114 | } |
115 | return std::nullopt; |
116 | } |
117 | |
118 | /// Perform one step of packing of a LinalgOp's metadata along `dim` into the |
119 | /// `newDim` at `iteratorTypes.size()` by: |
120 | /// 1. Appending `iteratorTypes[newDim]`, equal to `iteratorTypes[dim]`. |
121 | /// 2. Appending a `newDim` to the domain of every indexing map. |
122 | /// 3. For each operand (i.e. for each map in `indexingMaps`), perform packing |
123 | /// by potentially adding a `newDim` result to `map`. |
124 | /// The preserved invariant is that `iteratorTypes.size()` is always equal to |
125 | /// `map.getNumDims()` for every map in `indexingMaps`. |
126 | /// |
127 | /// Update `indexingMaps` and `iteratorTypes` inplace as one step of the update. |
128 | /// Return a vector that records the optional packing for each operand. |
129 | /// Return failure if the packed indexing cannot be represented with a LinalgOp. |
130 | /// |
131 | /// Further details: |
132 | /// ================ |
133 | /// The current implementation of packing (i.e. data tiling) consists of |
134 | /// rewriting a linearized strip-mined form into a higher-dimensional access. |
135 | /// e.g. consider an access `A[I][f(j, k, l)]` and packing by 4; we rewrite |
136 | /// `I` into `4 * i + ii`, where `0 <= ii < 4`. |
137 | /// The access is further rewritten as `A[i][f(j, k, l)][ii]`. |
138 | /// |
139 | /// This rewrite into a higher-dimensional access is not possible for general
140 | /// AffineExpr in Linalg at the moment; it is restricted to an AffineDimExpr:
141 | /// e.g. consider an access `A[I + J][f(j, k, l)]` and packing by 4; we |
142 | /// rewrite `I + J` into `4 * i + ii + J`, where `0 <= ii < 4`. |
143 | /// The rewrite of the access would be a form not representable in Linalg: |
144 | /// `A[i + (ii + J) / 4][f(j, k, l)][(ii + J) % 4]`. |
145 | /// Note however that as `J` and `ii` iterate, the accesses do not have a |
146 | /// particular alignment, so packing does not achieve alignment in this case.
147 | /// |
148 | /// In the future, we may want to consider a mixed-form that allows some |
149 | /// alignment in the presence of multiple accesses: |
150 | /// `A[I][f(j, k, l)]` and `B[I + J][f(j, k, l)]` |
151 | /// And would rewrite accesses as: |
152 | /// `A[i][f(j, k, l)][ii]` and `B[4 * i + ii + J][f(j, k, l)]` |
153 | static FailureOr<SmallVector<std::optional<int64_t>>> |
154 | packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps, |
155 | SmallVectorImpl<utils::IteratorType> &iteratorTypes, |
156 | int64_t dim) { |
157 | int64_t newDim = iteratorTypes.size(); |
158 | iteratorTypes.push_back(iteratorTypes[dim]); |
159 | |
160 | SmallVector<std::optional<int64_t>> packedDimPerIndexingMap( |
161 | indexingMaps.size(), std::nullopt); |
162 | SmallVector<AffineMap> newMaps; |
163 | for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e; |
164 | ++operandIdx) { |
165 | AffineMap map = indexingMaps[operandIdx]; |
166 | |
167 | // Add the `newDim` to map whatever the case. |
168 | assert(map.getNumDims() == newDim && "num dims invariant violation"); |
169 | map = map.shiftDims(1, newDim);
170 | |
171 | // Get the at-most-1 index of the result that is a function of `dim`. |
172 | // If we can find one, we insert `AffineDimExpr(newDim)` to the map, which |
173 | // logically chunks dimension `dim` into `K * dim + newDim`, where the |
174 | // packing factor `K` is specified separately. |
175 | assert(hasAtMostOneResultFunctionOfDim(map, dim) && |
176 | "num results invariant violation"); |
177 | auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim); |
178 | if (!maybeOperandDimensionToPack.has_value()) { |
179 | newMaps.push_back(map);
180 | continue; |
181 | } |
182 | |
183 | // We can only pack AffineDimExpr at the moment.
184 | if (!isa<AffineDimExpr>(map.getResult(maybeOperandDimensionToPack.value())))
185 | return failure(); |
186 | |
187 | // Add `newDim` to the results of the map. |
188 | map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim),
189 | map.getNumResults());
190 | newMaps.push_back(map);
191 | |
192 | // Record that `operandIdx` is packed.
193 | packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack; |
194 | } |
195 | indexingMaps = newMaps; |
196 | |
197 | return packedDimPerIndexingMap; |
198 | } |
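// For intuition, a small worked example (a sketch, not a test from this file):
// packing the reduction dimension k (dim = 2) of a matmul whose maps are
//   (m, n, k) -> (m, k), (m, n, k) -> (k, n), (m, n, k) -> (m, n)
// appends a new dim kk to every domain and a new trailing result to the maps
// that use k, yielding
//   (m, n, k, kk) -> (m, k, kk), (m, n, k, kk) -> (k, n, kk),
//   (m, n, k, kk) -> (m, n)
// together with an extra "reduction" iterator and the per-operand packed
// result positions {1, 0, std::nullopt}.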
199 | |
200 | namespace { |
201 | |
202 | /// Helper struct to encode packing along one dimension of a LinalgOp. |
203 | struct PackedOperandsDim { |
204 | OpFoldResult packedSize; |
205 | SmallVector<std::optional<int64_t>> packedDimForEachOperand; |
206 | }; |
207 | |
208 | /// Helper struct to encode packing along all dimensions of a LinalgOp. |
209 | struct PackedOperandsDimList { |
210 | void pushBack(PackedOperandsDim &&packedOperandsDims) { |
211 | spec.emplace_back(packedOperandsDims);
212 | } |
213 | /// Return all the dims that have been packed for operand @ `operandPos`. |
214 | SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos); |
215 | /// Return all the pack sizes by which an operand @ `operandPos` is packed. |
216 | SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos); |
217 | |
218 | private: |
219 | SmallVector<PackedOperandsDim> spec; |
220 | }; |
221 | |
222 | } // namespace |
223 | |
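// Rough shape of the lowering (an illustrative sketch with static sizes; the
// values and types below are made up, not taken from a test):
//
//   %pack = linalg.pack %src padding_value(%cst : f32)
//       inner_dims_pos = [0, 1] inner_tiles = [8, 16] into %dest
//       : tensor<13x15xf32> -> tensor<2x1x8x16xf32>
//
// becomes, approximately:
//
//   %padded = tensor.pad %src low[0, 0] high[3, 1] { ... yield %cst ... }
//       : tensor<13x15xf32> to tensor<16x16xf32>
//   %expanded = tensor.expand_shape %padded [[0, 1], [2, 3]]
//       : tensor<16x16xf32> into tensor<2x8x1x16xf32>
//   %transposed = linalg.transpose ins(%expanded : tensor<2x8x1x16xf32>)
//       outs(%dest : tensor<2x1x8x16xf32>) permutation = [0, 2, 1, 3]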
224 | FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter, |
225 | linalg::PackOp packOp, |
226 | bool lowerPadLikeWithInsertSlice) { |
227 | // 1. Filter out NYI cases. |
228 | auto packedTensorType = |
229 | cast<RankedTensorType>(packOp->getResultTypes().front()); |
230 | if (llvm::any_of(packOp.getStaticInnerTiles(), ShapedType::isDynamic)) { |
231 | return rewriter.notifyMatchFailure( |
232 | packOp, |
233 | "non-static shape NYI, needs a more powerful tensor.expand_shape op"); |
234 | } |
235 | |
236 | Location loc = packOp->getLoc(); |
237 | OpBuilder::InsertionGuard g(rewriter); |
238 | rewriter.setInsertionPoint(packOp); |
239 | |
240 | // 2. Compute the permutation vector to shuffle packed shape into the shape |
241 | // before any outer or inner permutations have been applied. |
242 | PackingMetadata packingMetadata = computePackingMetadata( |
243 | packedTensorType.getRank(), packOp.getInnerDimsPos()); |
244 | SmallVector<int64_t> packedToStripMinedShapePerm = |
245 | getPackInverseDestPerm(packOp); |
246 | |
247 | // 3. Compute the stripMinedShape: this is the packed shape before any outer |
248 | // or inner permutations have been applied. |
249 | SmallVector<int64_t> stripMinedShape(packedTensorType.getShape()); |
250 | applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
251 | |
252 | // 4. Pad the source of packOp to a shape we can expand into stripMinedShape. |
253 | SmallVector<OpFoldResult> lows(packOp.getSourceRank(), |
254 | rewriter.getIndexAttr(0)); |
255 | SmallVector<OpFoldResult> highs(packOp.getSourceRank(), |
256 | rewriter.getIndexAttr(0)); |
257 | for (auto [pos, innerSize] : |
258 | llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) { |
259 | int outerPos = |
260 | packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]]; |
261 | OpFoldResult origSize = |
262 | tensor::getMixedSize(rewriter, loc, packOp.getSource(), pos); |
263 | OpFoldResult outerSize = |
264 | tensor::getMixedSize(rewriter, loc, packOp.getDest(), outerPos); |
265 | AffineExpr s0, d0, d1; |
266 | bindDims(rewriter.getContext(), d0, d1); |
267 | bindSymbols(rewriter.getContext(), s0); |
268 | auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, d0 * s0 - d1); |
269 | highs[pos] = affine::makeComposedFoldedAffineApply( |
270 | rewriter, loc, map, {outerSize, origSize, innerSize}); |
271 | } |
272 | RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType( |
273 | RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape), |
274 | packingMetadata.reassociations); |
275 | Value paddingValue = packOp.getPaddingValue(); |
276 | if (!paddingValue) { |
277 | paddingValue = rewriter.create<arith::ConstantOp>( |
278 | loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed))); |
279 | } |
280 | auto padOp = |
281 | rewriter.create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows, |
282 | highs, paddingValue, /*nofold=*/false); |
283 | |
284 | LLVM_DEBUG( |
285 | DBGSNL(); DBGSNL(); |
286 | DBGS() << "insertPositions: " |
287 | << llvm::interleaved(packingMetadata.insertPositions); |
288 | DBGSNL(); DBGS() << "outerPositions: " |
289 | << llvm::interleaved(packingMetadata.outerPositions); |
290 | DBGSNL(); DBGS() << "packedShape: " |
291 | << llvm::interleaved(packedTensorType.getShape()); |
292 | DBGSNL(); DBGS() << "packedToStripMinedShapePerm: " |
293 | << llvm::interleaved(packedToStripMinedShapePerm); |
294 | DBGSNL(); |
295 | DBGS() << "reassociations: " |
296 | << llvm::interleaved(llvm::map_range( |
297 | packingMetadata.reassociations, stringifyReassocIndices)); |
298 | DBGSNL(); |
299 | DBGS() << "stripMinedShape: "<< llvm::interleaved(stripMinedShape); |
300 | DBGSNL(); DBGS() << "collapsed type: "<< collapsed; DBGSNL();); |
301 | |
302 | if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) { |
303 | // Pack ops which operate as simple pads may not produce legal |
304 | // tensor.insert_slice operations when the packed type does not rank reduce |
305 | // to the padded type. |
306 | SliceVerificationResult rankReduces = |
307 | isRankReducedType(packedTensorType, padOp.getResultType()); |
308 | |
309 | if (rankReduces == SliceVerificationResult::Success) { |
310 | // This pack is just a plain pad. |
311 | // Just insert the pad in the higher ranked tensor. |
312 | // Offsets. |
313 | SmallVector<OpFoldResult> zeros(packOp.getDestRank(), |
314 | rewriter.getIndexAttr(0)); |
315 | // Strides. |
316 | SmallVector<OpFoldResult> ones(packOp.getDestRank(), |
317 | rewriter.getIndexAttr(1)); |
318 | SmallVector<OpFoldResult> sizes = |
319 | tensor::getMixedSizes(rewriter, loc, packOp.getDest());
320 | |
321 | auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>( |
322 | loc, /*source=*/padOp, /*dest=*/packOp.getDest(), |
323 | /*offsets=*/zeros, sizes, /*strides=*/ones); |
324 | |
325 | LLVM_DEBUG(DBGS() << "insert_slice op: "<< insertSliceOp; DBGSNL();); |
326 | |
327 | rewriter.replaceOp(packOp, insertSliceOp->getResults()); |
328 | |
329 | return LowerPackResult{padOp, /*reshapeOp=*/nullptr, |
330 | /*transposeOp=*/nullptr}; |
331 | } |
332 | } |
333 | |
334 | // 5. Expand from the padded result to the stripMinedShape. |
335 | auto expandShapeResultType = |
336 | RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape); |
337 | auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>( |
338 | loc, expandShapeResultType, padOp.getResult(), |
339 | packingMetadata.reassociations); |
340 | |
341 | // 6. Transpose stripMinedShape to packedShape. |
342 | SmallVector<int64_t> transpPerm = |
343 | invertPermutationVector(packedToStripMinedShapePerm);
344 | auto transposeOp = rewriter.create<linalg::TransposeOp>( |
345 | loc, reshapeOp.getResult(), packOp.getDest(), transpPerm); |
346 | |
347 | LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL(); |
348 | DBGS() << "reshape op: "<< reshapeOp; DBGSNL(); |
349 | DBGS() << "transpPerm: "<< llvm::interleaved(transpPerm); |
350 | DBGSNL(); DBGS() << "transpose op: "<< transposeOp; DBGSNL();); |
351 | |
352 | // 7. Replace packOp by transposeOp. |
353 | rewriter.replaceOp(packOp, transposeOp->getResults()); |
354 | |
355 | return LowerPackResult{padOp, reshapeOp, transposeOp}; |
356 | } |
357 | |
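// Rough shape of the lowering (an illustrative sketch; values and types are
// made up, not taken from a test):
//
//   %unpack = linalg.unpack %src inner_dims_pos = [0, 1] inner_tiles = [8, 16]
//       into %dest : tensor<2x1x8x16xf32> -> tensor<13x15xf32>
//
// becomes, approximately:
//
//   %empty = tensor.empty() : tensor<2x8x1x16xf32>
//   %t = linalg.transpose ins(%src) outs(%empty) permutation = [0, 2, 1, 3]
//   %c = tensor.collapse_shape %t [[0, 1], [2, 3]] : ... into tensor<16x16xf32>
//   %s = tensor.extract_slice %c[0, 0] [13, 15] [1, 1] : ... to tensor<13x15xf32>
//   linalg.copy ins(%s) outs(%dest)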
358 | FailureOr<LowerUnPackOpResult> |
359 | linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, |
360 | bool lowerUnpadLikeWithExtractSlice) { |
361 | Location loc = unPackOp->getLoc(); |
362 | OpBuilder::InsertionGuard g(rewriter); |
363 | rewriter.setInsertionPoint(unPackOp); |
364 | |
365 | RankedTensorType packedTensorType = unPackOp.getSourceType(); |
366 | int64_t packedRank = packedTensorType.getRank(); |
367 | |
368 | OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1); |
369 | auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType()); |
370 | if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) { |
371 | // This unpack is just a plain unpad. |
372 | // Just extract the slice from the higher ranked tensor. |
373 | ArrayRef<int64_t> destShape = destTensorType.getShape(); |
374 | // The inner dimensions stay the same as the destination tensor, but the |
375 | // outer ones are additional 1s. |
376 | SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one); |
377 | sizes.append(tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()));
378 | |
379 | auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>( |
380 | loc, destTensorType, unPackOp.getSource(), |
381 | SmallVector<OpFoldResult>(packedRank, zero), sizes, |
382 | SmallVector<OpFoldResult>(packedRank, one)); |
383 | |
384 | rewriter.replaceOp(unPackOp, extractSliceOp->getResults()); |
385 | |
386 | return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr, |
387 | /*reshapeOp=*/nullptr, extractSliceOp}; |
388 | } |
389 | |
390 | // 1. Compute the permutation vector to shuffle packed shape into the shape |
391 | // before any outer or inner permutations have been applied. |
392 | PackingMetadata packingMetadata; |
393 | SmallVector<int64_t> packedToStripMinedShapePerm = |
394 | getUnPackInverseSrcPerm(unPackOp, packingMetadata); |
395 | |
396 | // 2. Compute the stripMinedShape: this is the packed shape without outer and |
397 | // inner permutations. |
398 | SmallVector<int64_t> stripMinedShape(packedTensorType.getShape()); |
399 | applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
400 | |
401 | // 3. Transpose packedShape to stripMinedShape. |
402 | RankedTensorType stripMinedTensorType = |
403 | RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape); |
404 | RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType( |
405 | stripMinedTensorType, packingMetadata.reassociations); |
406 | |
407 | // Get dynamic dims from input tensor based on packedToStripMinedShapePerm |
408 | // permutation. |
409 | SmallVector<OpFoldResult, 4> dims = |
410 | tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
411 | applyPermutationToVector(dims, packedToStripMinedShapePerm);
412 | auto emptyOp = rewriter.create<tensor::EmptyOp>( |
413 | loc, dims, stripMinedTensorType.getElementType()); |
414 | auto transposeOp = rewriter.create<linalg::TransposeOp>( |
415 | loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm); |
416 | |
417 | LLVM_DEBUG( |
418 | DBGSNL(); DBGSNL(); |
419 | DBGS() << "insertPositions: " |
420 | << llvm::interleaved(packingMetadata.insertPositions); |
421 | DBGSNL(); DBGS() << "packedShape: " |
422 | << llvm::interleaved(packedTensorType.getShape()); |
423 | DBGSNL(); DBGS() << "packedToStripMinedShapePerm: " |
424 | << llvm::interleaved(packedToStripMinedShapePerm); |
425 | DBGSNL(); |
426 | DBGS() << "reassociations: " |
427 | << llvm::interleaved(llvm::map_range( |
428 | packingMetadata.reassociations, stringifyReassocIndices)); |
429 | DBGSNL(); |
430 | DBGS() << "stripMinedShape: "<< llvm::interleaved(stripMinedShape); |
431 | DBGSNL(); DBGS() << "collapsed type: "<< collapsedType; DBGSNL();); |
432 | |
433 | // 4. Collapse from the stripMinedShape to the padded result. |
434 | auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>( |
435 | loc, collapsedType, transposeOp->getResult(0), |
436 | packingMetadata.reassociations); |
437 | |
438 | // 5. ExtractSlice. |
439 | int64_t destRank = destTensorType.getRank(); |
440 | auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>( |
441 | loc, destTensorType, reshapeOp->getResult(0), |
442 | SmallVector<OpFoldResult>(destRank, zero), |
443 | tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
444 | SmallVector<OpFoldResult>(destRank, one)); |
445 | |
446 | // 6. Inject a copy to preserve DPS. |
447 | auto copyOp = rewriter.create<linalg::CopyOp>( |
448 | loc, extractSliceOp->getResult(0), unPackOp.getDest()); |
449 | |
450 | // 7. Replace unPackOp by copyOp. |
451 | rewriter.replaceOp(unPackOp, copyOp->getResults()); |
452 | |
453 | return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp}; |
454 | } |
455 | |
456 | SmallVector<int64_t> |
457 | PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) { |
458 | SmallVector<int64_t> res; |
459 | for (auto &i : spec) { |
460 | if (!i.packedDimForEachOperand[operandPos].has_value()) |
461 | continue; |
462 | res.push_back(i.packedDimForEachOperand[operandPos].value());
463 | } |
464 | return res; |
465 | } |
466 | |
467 | SmallVector<OpFoldResult> |
468 | PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) { |
469 | SmallVector<OpFoldResult> res; |
470 | for (auto &i : spec) { |
471 | if (!i.packedDimForEachOperand[operandPos].has_value()) |
472 | continue; |
473 | res.push_back(i.packedSize);
474 | } |
475 | return res; |
476 | } |
477 | |
478 | /// Implement packing of a single LinalgOp by performing packing by |
479 | /// `packedSizes`. There must be one packedSizes entry per `linalgOp` iterator. |
480 | /// Return the packed Linalg op on success, failure otherwise. |
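///
/// For intuition, a rough sketch (the shapes are made up, not from a test):
/// for a 128x512x256 matmul with iterators (m, n, k) and `packedSizes` =
/// [8, 16, 32], the operands get packed as
///   LHS tensor<128x256xf32> -> linalg.pack -> tensor<16x8x8x32xf32>
///   RHS tensor<256x512xf32> -> linalg.pack -> tensor<8x32x16x32xf32>
///   OUT tensor<128x512xf32> -> linalg.pack -> tensor<16x32x8x16xf32>
/// the computation is rewritten as a linalg.generic over the packed operands,
/// and a linalg.unpack of the packed init restores the original
/// tensor<128x512xf32> result.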
481 | FailureOr<PackResult> linalg::pack(RewriterBase &rewriter, |
482 | linalg::LinalgOp linalgOp, |
483 | ArrayRef<OpFoldResult> packedSizes) { |
484 | if (packedSizes.size() != linalgOp.getNumLoops()) { |
485 | return rewriter.notifyMatchFailure(linalgOp, |
486 | "incorrect number of pack sizes"); |
487 | } |
488 | |
489 | Location loc = linalgOp->getLoc(); |
490 | SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray(); |
491 | SmallVector<utils::IteratorType> iteratorTypes = |
492 | linalgOp.getIteratorTypesArray(); |
493 | LLVM_DEBUG(DBGS() << "Start packing: "<< linalgOp << "\n" |
494 | << "maps: "<< llvm::interleaved(indexingMaps) << "\n" |
495 | << "iterators: "<< llvm::interleaved(iteratorTypes) |
496 | << "\n"); |
497 | |
498 | SmallVector<linalg::PackOp> packOps; |
499 | SmallVector<linalg::UnPackOp> unPackOps; |
500 | // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i]. |
501 | PackedOperandsDimList listOfPackedOperandsDim; |
502 | for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) { |
503 | std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
504 | // Skip tile sizes explicitly set to 0. |
505 | if (maybeConstant.has_value() && maybeConstant.value() == 0) |
506 | continue; |
507 | |
508 | PackedOperandsDim packedOperandsDims; |
509 | packedOperandsDims.packedSize = packedSizes[i]; |
510 | FailureOr<SmallVector<std::optional<int64_t>>> |
511 | maybePackedDimForEachOperand = |
512 | packLinalgMetadataOnce(indexingMaps, iteratorTypes, i); |
513 | if (failed(maybePackedDimForEachOperand))
514 | return failure(); |
515 | packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand; |
516 | listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
517 | |
518 | LLVM_DEBUG( |
519 | DBGS() << "++++ After pack size #"<< i << ": "<< packedSizes[i] |
520 | << "\n" |
521 | << "maps: "<< llvm::interleaved(indexingMaps) << "\n" |
522 | << "iterators: "<< llvm::interleaved(iteratorTypes) << "\n" |
523 | << "packedDimForEachOperand: " |
524 | << llvm::interleaved(packedOperandsDims.packedDimForEachOperand) |
525 | << "\n"); |
526 | } |
527 | |
528 | // Step 2. Propagate packing to all LinalgOp operands. |
529 | SmallVector<Value> inputsAndInits, results; |
530 | SmallVector<OpOperand *> initOperands = |
531 | llvm::to_vector(llvm::make_pointer_range(linalgOp.getDpsInitsMutable())); |
532 | SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands(); |
533 | for (const auto &operandsList : {inputOperands, initOperands}) { |
534 | for (OpOperand *opOperand : operandsList) { |
535 | int64_t pos = opOperand->getOperandNumber(); |
536 | Value operand = opOperand->get(); |
537 | SmallVector<int64_t> innerPos = |
538 | listOfPackedOperandsDim.extractPackedDimsForOperand(pos); |
539 | SmallVector<OpFoldResult> innerPackSizes = |
540 | listOfPackedOperandsDim.extractPackSizesForOperand(pos); |
541 | LLVM_DEBUG(DBGS() << "operand: "<< operand << "\n" |
542 | << "innerPos: "<< llvm::interleaved(innerPos) << "\n" |
543 | << "innerPackSizes: " |
544 | << llvm::interleaved(innerPackSizes) << "\n"); |
545 | if (innerPackSizes.empty()) { |
546 | inputsAndInits.push_back(operand); |
547 | continue; |
548 | } |
549 | Value dest = linalg::PackOp::createDestinationTensor( |
550 | rewriter, loc, operand, innerPackSizes, innerPos, |
551 | /*outerDimsPerm=*/{}); |
552 | ShapedType operandType = cast<ShapedType>(operand.getType()); |
553 | bool areConstantTiles = |
554 | llvm::all_of(innerPackSizes, [](OpFoldResult tile) { |
555 | return getConstantIntValue(tile).has_value(); |
556 | }); |
557 | if (areConstantTiles && operandType.hasStaticShape() && |
558 | !linalg::PackOp::requirePaddingValue( |
559 | operandType.getShape(), innerPos, |
560 | cast<ShapedType>(dest.getType()).getShape(), {}, |
561 | innerPackSizes)) { |
562 | packOps.push_back(rewriter.create<linalg::PackOp>( |
563 | loc, operand, dest, innerPos, innerPackSizes)); |
564 | } else { |
565 | // TODO: value of the padding attribute should be determined by |
566 | // consumers. |
567 | auto zeroAttr = |
568 | rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType())); |
569 | Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr); |
570 | packOps.push_back(rewriter.create<linalg::PackOp>( |
571 | loc, operand, dest, innerPos, innerPackSizes, zero)); |
572 | } |
573 | inputsAndInits.push_back(packOps.back()); |
574 | } |
575 | } |
576 | |
577 | // Step 3. Build the packed op, use the type of `inits` as result types. |
578 | ValueRange inputs = |
579 | ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
580 | ValueRange inits =
581 | ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
582 | auto packedLinalgOp = rewriter.create<linalg::GenericOp>( |
583 | linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps, |
584 | iteratorTypes); |
585 | packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0)); |
586 | |
587 | // Step 4. Propagate packing to all the op results. |
588 | for (OpResult result : packedLinalgOp->getResults()) { |
589 | int64_t resultNum = result.getResultNumber(); |
590 | linalg::PackOp maybePackedInit = |
591 | inits[resultNum].getDefiningOp<linalg::PackOp>(); |
592 | if (!maybePackedInit) { |
593 | results.push_back(result); |
594 | continue; |
595 | } |
596 | // Build the symmetrical UnPackOp to the existing PackOp. |
597 | unPackOps.push_back(rewriter.create<linalg::UnPackOp>( |
598 | packedLinalgOp->getLoc(), result, maybePackedInit.getSource(), |
599 | maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles())); |
600 | results.push_back(unPackOps.back()); |
601 | } |
602 | |
603 | // Step 5. Replace `linalgOp`. |
604 | rewriter.replaceOp(linalgOp, results); |
605 | |
606 | // Return packedLinalgOp. |
607 | return PackResult{packOps, |
608 | cast<linalg::LinalgOp>(packedLinalgOp.getOperation()), |
609 | unPackOps}; |
610 | } |
611 | |
612 | //===----------------------------------------------------------------------===// |
613 | // packTranspose transformation. |
614 | //===----------------------------------------------------------------------===// |
615 | |
616 | /// Return a copy of `tensorType` after permutation by `permutationVector`. |
617 | // Note: Should be a new method of MemRef/RankedTensor/VectorType::Builder
618 | // but this would introduce a dependence on Dialect in IR. |
619 | // TODO: Restructure. |
620 | static RankedTensorType permuteShape(RankedTensorType tensorType, |
621 | ArrayRef<int64_t> permutationVector) { |
622 | SmallVector<int64_t> shape(tensorType.getShape()); |
623 | applyPermutationToVector(shape, permutationVector);
624 | return RankedTensorType::Builder(tensorType).setShape(shape); |
625 | } |
626 | |
627 | /// Return a new GenericOp obtained by transposing opOperand by the permutation |
628 | /// vector: |
629 | /// - the corresponding indexing map is transposed by `permutation` |
630 | /// - the corresponding operand value is replaced by `transposedValue` |
631 | /// `linalgOp` is replaced by the return op in the process. |
632 | /// Asserts that `transposedValue` is of the proper transposed ShapedType. |
633 | static LinalgOp transposeOneLinalgOperandAndReplace( |
634 | RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand, |
635 | ArrayRef<int64_t> permutation, Value transposedValue) { |
636 | // Sanity check the operand. |
637 | assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand"); |
638 | |
639 | // Sanity check of the expected transposed tensor type. |
640 | auto tensorType = permuteShape( |
641 | cast<RankedTensorType>(opOperand.get().getType()), permutation); |
642 | (void)tensorType; |
643 | assert(tensorType == transposedValue.getType() && |
644 | "expected tensor type mismatch"); |
645 | |
646 | // Compute the transposed indexing map. |
647 | // Sigh unsigned pollution. |
648 | SmallVector<unsigned> tmpTransposition = llvm::to_vector( |
649 | llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
650 | AffineMap permutationMap =
651 | AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
652 | AffineMap transposedMap = |
653 | permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand)); |
654 | |
655 | // Set the transposed indexing map in the proper position. |
656 | SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray(); |
657 | indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap; |
658 | // Set the transposedValue in the proper operand position. |
659 | SmallVector<Value> operands = linalgOp->getOperands(); |
660 | operands[opOperand.getOperandNumber()] = transposedValue; |
661 | |
662 | ValueRange operandsRef(operands); |
663 | auto transposedGenericOp = rewriter.create<linalg::GenericOp>( |
664 | /*location=*/linalgOp->getLoc(), |
665 | /*resultTensorTypes=*/ |
666 | operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
667 | /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
668 | /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
669 | /*indexingMaps=*/indexingMaps, |
670 | /*iteratorTypes=*/linalgOp.getIteratorTypesArray()); |
671 | transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0)); |
672 | rewriter.replaceOp(linalgOp, transposedGenericOp->getResults()); |
673 | |
674 | return cast<linalg::LinalgOp>(transposedGenericOp.getOperation()); |
675 | } |
676 | |
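// For intuition (a hedged sketch; the values are made up): given a pack with
// inner_dims_pos = [0, 1] and inner_tiles = [8, 32], a call such as
//
//   packTranspose(rewriter, packOp, linalgOp,
//                 /*maybeUnPackOp=*/linalg::UnPackOp(),
//                 /*outerPerm=*/{}, /*innerPerm=*/{1, 0});
//
// swaps the two innermost tile dimensions of the packed operand (the tiles
// become [32, 8]), rewrites the consuming LinalgOp's indexing map to match,
// and, when an unpack of the tied result is provided, transposes it
// symmetrically.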
677 | FailureOr<PackTransposeResult> |
678 | linalg::packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, |
679 | linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, |
680 | ArrayRef<int64_t> outerPerm, |
681 | ArrayRef<int64_t> innerPerm) { |
682 | Location loc = linalgOp.getLoc(); |
683 | |
684 | // Step 1. Transpose packOp. |
685 | rewriter.setInsertionPoint(packOp); |
686 | linalg::PackOp transposedPackOp = |
687 | packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm); |
688 | |
689 | if (!packOp.getResult().hasOneUse()) |
690 | return rewriter.notifyMatchFailure(linalgOp, "expect single pack use"); |
691 | |
692 | OpOperand &packUse = *packOp->getUses().begin(); |
693 | if (packUse.getOwner() != linalgOp) { |
694 | return rewriter.notifyMatchFailure( |
695 | linalgOp, "not a single use by the LinalgOp target"); |
696 | } |
697 | if (maybeUnPackOp && |
698 | (!linalgOp.isDpsInit(&packUse) || |
699 | maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) { |
700 | return rewriter.notifyMatchFailure(linalgOp, |
701 | "not produced by the LinalgOp target"); |
702 | } |
703 | |
704 | // Step 2. Transpose linalgOp. |
705 | // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the |
706 | // identity. Don't rely on it. |
707 | int64_t numLeadingDims = packOp.getSourceRank(); |
708 | int64_t numTrailingDims = packOp.getInnerDimsPos().size(); |
709 | // Step 2.a. Compute the permutation on the whole operand. |
710 | // The leading part just reuses the outerPerm.
711 | SmallVector<int64_t> permutation(outerPerm); |
712 | if (permutation.empty()) |
713 | llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
714 | // Trailing part needs to reindex positions by `numLeadingDims`. |
715 | if (innerPerm.empty()) { |
716 | llvm::append_range( |
717 | permutation,
718 | llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
719 | } else { |
720 | llvm::append_range(permutation, |
721 | llvm::map_range(innerPerm, [&](int64_t pos) { |
722 | return numLeadingDims + pos; |
723 | })); |
724 | } |
725 | if (!isPermutationVector(permutation))
726 | return rewriter.notifyMatchFailure(linalgOp, "invalid permutation"); |
727 | |
728 | // Step 2.b. Save the transposedPackUse operand number in case we need to |
729 | // get the tied OpResult after `linalgOp` has been replaced. |
730 | int64_t packUseOperandNumber = packUse.getOperandNumber(); |
731 | // Step 2.c. Actually perform the transposition. |
732 | rewriter.setInsertionPoint(linalgOp); |
733 | linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace( |
734 | rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult()); |
735 | |
736 | // Step 3. Maybe transpose unPackOp. |
737 | linalg::UnPackOp transposedUnPackOp; |
738 | if (maybeUnPackOp) { |
739 | OpOperand &opOperand = |
740 | transposedLinalgOp->getOpOperand(packUseOperandNumber); |
741 | OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand); |
742 | rewriter.setInsertionPoint(maybeUnPackOp); |
743 | transposedUnPackOp = maybeUnPackOp.createTransposedClone( |
744 | rewriter, loc, transposedResult, innerPerm, outerPerm); |
745 | |
746 | rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults()); |
747 | } |
748 | |
749 | // Step 4. Finally, replace packOp now that we don't need it anymore. |
750 | rewriter.replaceOp(packOp, transposedPackOp->getResults()); |
751 | |
752 | return PackTransposeResult{transposedPackOp, transposedLinalgOp, |
753 | transposedUnPackOp}; |
754 | } |
755 | |
756 | //===----------------------------------------------------------------------===// |
757 | // packMatmulGreedily transformation. |
758 | //===----------------------------------------------------------------------===// |
759 | |
760 | /// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m |
761 | /// and n are proper parallel dimensions and k is a proper reduction |
762 | /// dimension. Packing occurs by rewriting the op as a linalg.generic and |
763 | /// calling linalg::pack by `mnkPackedSizes`. The order of the packed |
764 | /// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2} |
765 | /// to reorder {m, n, k} into one of the 8 possible forms. The outer |
766 | /// dimensions of the operands are not permuted at this time, this is left for |
767 | /// future work. |
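///
/// Illustrative call (a sketch; the names and tile sizes are made up):
///
///   // Pack m by 16, n by 32 and k by 64, keeping the canonical (m, n, k)
///   // order and without padding to a larger multiple.
///   SmallVector<OpFoldResult> mnkTiles = {rewriter.getIndexAttr(16),
///                                         rewriter.getIndexAttr(32),
///                                         rewriter.getIndexAttr(64)};
///   FailureOr<PackResult> res = linalg::packMatmulGreedily(
///       rewriter, linalgOp, mnkTiles,
///       /*mnkPaddedSizesNextMultipleOf=*/{}, /*mnkOrder=*/{0, 1, 2});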
768 | FailureOr<PackResult> |
769 | linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, |
770 | ArrayRef<OpFoldResult> mnkPackedSizes, |
771 | ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf, |
772 | ArrayRef<int64_t> mnkOrder) { |
773 | assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes"); |
774 | assert((mnkPaddedSizesNextMultipleOf.empty() || |
775 | mnkPaddedSizesNextMultipleOf.size() == 3) && |
776 | "num of packing sizes next multiple should be empty or of size 3"); |
777 | assert(mnkOrder.size() == 3 && "unexpected mnkOrder size"); |
778 | assert(isPermutationVector(mnkOrder) && "expected a permutation"); |
779 | |
780 | int64_t numLoops = linalgOp.getNumLoops(); |
781 | if (numLoops <= 2) { |
782 | LLVM_DEBUG(DBGS() << "need 3+ loops to find a matmul to pack, got " |
783 | << numLoops << "\nin: "<< linalgOp << "\n"); |
784 | return rewriter.notifyMatchFailure( |
785 | linalgOp, "need 3+ loops to find a matmul to pack"); |
786 | } |
787 | |
788 | // Locally adjust the desired iterator position of mnk and packing sizes. |
789 | int64_t numPackedDims = mnkPackedSizes.size(); |
790 | SmallVector<int64_t> mmnnkkPos(numPackedDims); |
791 | for (int64_t i = 0, e = numPackedDims; i < e; ++i) |
792 | mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i]; |
793 | SmallVector<OpFoldResult> packedSizes(numPackedDims); |
794 | for (int64_t i = 0, e = numPackedDims; i < e; ++i) |
795 | packedSizes[mnkOrder[i]] = mnkPackedSizes[i]; |
796 | SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims); |
797 | for (int64_t i = 0, e = numPackedDims; i < e; ++i) { |
798 | paddedSizesNextMultipleOf[mnkOrder[i]] = |
799 | mnkPaddedSizesNextMultipleOf.empty() ? 0 |
800 | : mnkPaddedSizesNextMultipleOf[i]; |
801 | } |
802 | |
803 | // 1. Infer dims that are important for matmul. |
804 | FailureOr<ContractionDimensions> maybeDimensions = |
805 | inferContractionDims(linalgOp); |
806 | if (failed(maybeDimensions)) {
807 | LLVM_DEBUG(DBGS() << "couldn't infer matmul iterators in: "<< linalgOp |
808 | << "\n"); |
809 | return rewriter.notifyMatchFailure(linalgOp, |
810 | "couldn't infer matmul iterators"); |
811 | } |
812 | |
813 | // 2. Normalize linalgOp to a kmn-matmul-like op with [red, par, par] as the
814 | // most minor iterators. In cases with multiple options for m, n, k, bias
815 | // towards the most minor embedding.
816 | // If we wanted a different normalization order, this is where a heuristic
817 | // would have to be plugged in.
818 | int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(), |
819 | kPos = maybeDimensions->k.back(); |
820 | LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL(); |
821 | DBGS() << "Start packing generic op greedily with (m@"<< mPos |
822 | << ", n@"<< nPos << ", k@"<< kPos << "): "<< linalgOp |
823 | << "\n";); |
824 | |
825 | // 2.a. Rewrite as a generic. |
826 | auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation()); |
827 | if (!genericOp) { |
828 | FailureOr<GenericOp> generalizeResult = |
829 | generalizeNamedOp(rewriter, linalgOp); |
830 | assert(succeeded(generalizeResult) && "unexpected failure generalizing op"); |
831 | genericOp = *generalizeResult; |
832 | } |
833 | |
834 | // 2.b. Interchange to move the dimensions (k, m, n) as most-minor |
835 | // iterators. Note that this only normalizes the iteration order and does
836 | // not change the indexings of any operand. |
837 | SmallVector<int64_t> permutation = |
838 | computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
839 | LLVM_DEBUG(DBGS() << "perm: "<< llvm::interleaved(permutation) << "\n"); |
840 | // Sign .. unsigned pollution. |
841 | SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end()); |
842 | FailureOr<GenericOp> interchangeResult = |
843 | interchangeGenericOp(rewriter, genericOp, unsignedPerm); |
844 | assert(succeeded(interchangeResult) && "unexpected failure interchanging op"); |
845 | genericOp = *interchangeResult; |
846 | LLVM_DEBUG(DBGS() << "Generalized Op to pack: "<< genericOp << "\n";); |
847 | |
848 | // At this point, the op iterators are normalized to {leading, k, m, n}. |
849 | // The layouts induced by packing will always be: |
850 | // - LHS{leading_lhs, kk, mm} |
851 | // - RHS{leading_rhs, kk, nn} |
852 | // - RES{leading_res, mm, nn} |
853 | // If we wanted to change the packed order, we would reorder (k, m, n) to |
854 | // something else above. |
855 | // |
856 | // Additional permutations of the outer dims of the operands (i.e. |
857 | // leading_lhs, leading_rhs and leading_res) could follow by computing the |
858 | // desired outerPerm for each operand. |
859 | // This is left for future work. |
860 | |
861 | // TODO: this creates too much IR, go use reifyResultShapes. |
862 | SmallVector<Range, 4> loopRanges = |
863 | cast<LinalgOp>(genericOp.getOperation()) |
864 | .createLoopRanges(rewriter, genericOp.getLoc()); |
865 | |
866 | // Add leading zeros to match numLoops, we only pack the last 3 dimensions |
867 | // post interchange. |
868 | LLVM_DEBUG(DBGS() << "paddedSizesNextMultipleOf: " |
869 | << llvm::interleaved(paddedSizesNextMultipleOf) << "\n" |
870 | << "loopRanges: " |
871 | << llvm::interleaved(llvm::map_range( |
872 | loopRanges, [](Range r) { return r.size; })) |
873 | << "\n"); |
874 | SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(), |
875 | rewriter.getIndexAttr(0)); |
876 | for (int64_t i = 0, e = numPackedDims; i < e; ++i) { |
877 | if (paddedSizesNextMultipleOf[i] == 0) { |
878 | adjustedPackedSizes.push_back(packedSizes[i]);
879 | continue; |
880 | } |
881 | AffineExpr d0, s0; |
882 | bindDims(rewriter.getContext(), d0);
883 | bindSymbols(rewriter.getContext(), s0);
884 | adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
885 | rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
886 | {loopRanges[adjustedPackedSizes.size()].size, |
887 | rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])})); |
888 | } |
889 | LLVM_DEBUG(DBGS() << "adjustedPackedSizes: " |
890 | << llvm::interleaved(adjustedPackedSizes) << "\n"); |
891 | |
892 | // TODO: If we wanted to give the genericOp a name after packing, after |
893 | // calling `pack` would be a good time. One would still need to check that |
894 | // `containsMostMinorMatmul(packingRes->packedLinalgOp)` is true, since we |
895 | // also allow degenerate matmul cases (i.e. matvec, dot). |
896 | return pack(rewriter, genericOp, adjustedPackedSizes); |
897 | } |
898 | |
899 | //===----------------------------------------------------------------------===// |
900 | // Transformations exposed as rewrite patterns. |
901 | //===----------------------------------------------------------------------===// |
902 | |
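// Illustrative usage (a sketch): tiling options are typically built fluently,
// e.g.
//
//   auto options = linalg::LinalgTilingOptions().setTileSizes({8, 16, 0});
//
// where a tile size of 0 means the corresponding loop is left untiled.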
903 | LinalgTilingOptions & |
904 | mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { |
905 | assert(!tileSizeComputationFunction && "tile sizes already set"); |
906 | SmallVector<int64_t, 4> tileSizes(ts); |
907 | tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { |
908 | OpBuilder::InsertionGuard guard(b); |
909 | b.setInsertionPointToStart( |
910 | &op->getParentOfType<func::FuncOp>().getBody().front()); |
911 | return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { |
912 | Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s); |
913 | return v; |
914 | })); |
915 | }; |
916 | return *this; |
917 | } |
918 | |
919 | LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite( |
920 | memref::CopyOp copyOp, PatternRewriter &rewriter) const { |
921 | return vectorizeCopy(rewriter, copyOp); |
922 | } |
923 | |
924 | /// Fill `dest` using a FillOp with the constant padding value if possible.
925 | /// Otherwise, generate a tensor::GenerateOp. |
926 | Value DecomposePadOpPattern::createFillOrGenerateOp( |
927 | RewriterBase &rewriter, tensor::PadOp padOp, Value dest, |
928 | const SmallVector<Value> &dynSizes) const { |
929 | auto padValue = padOp.getConstantPaddingValue(); |
930 | if (padValue) { |
931 | // Move the padding value defined inside the PadOp block to outside. |
932 | if (padValue.getParentBlock() == &padOp.getRegion().front()) |
933 | rewriter.moveOpBefore(padValue.getDefiningOp(), padOp); |
934 | return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result(); |
935 | } |
936 | |
937 | // Fill could not be optimized: Lower to tensor::GenerateOp with region. |
938 | auto generateOp = rewriter.create<tensor::GenerateOp>( |
939 | padOp.getLoc(), padOp.getResultType(), dynSizes); |
940 | // Copy region to new op. |
941 | IRMapping bvm; |
942 | padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm); |
943 | return generateOp; |
944 | } |
945 | |
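// Rough shape of the overall decomposition (an illustrative sketch; values and
// types are made up, not taken from a test):
//
//   %0 = tensor.pad %src low[0, 1] high[2, 3] { ... yield %cst ... }
//       : tensor<4x5xf32> to tensor<6x9xf32>
//
// becomes, approximately:
//
//   %empty = tensor.empty() : tensor<6x9xf32>
//   %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<6x9xf32>)
//   %res = tensor.insert_slice %src into %fill[0, 1] [4, 5] [1, 1]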
946 | LogicalResult |
947 | DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp, |
948 | PatternRewriter &rewriter) const { |
949 | // Given an OpFoldResult, return an index-typed value. |
950 | auto getIdxValue = [&](OpFoldResult ofr) { |
951 | if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
952 | return val; |
953 | return rewriter |
954 | .create<arith::ConstantIndexOp>( |
955 | padOp.getLoc(), cast<IntegerAttr>(cast<Attribute>(ofr)).getInt()) |
956 | .getResult(); |
957 | }; |
958 | |
959 | auto resultType = padOp.getResultType(); |
960 | // Compute size of EmptyOp. Any combination of static/dynamic is supported. |
961 | SmallVector<Value> dynSizes; |
962 | SmallVector<int64_t> staticSizes; |
963 | for (unsigned dim = 0; dim < resultType.getRank(); ++dim) { |
964 | if (resultType.isDynamicDim(dim)) { |
965 | auto srcSize = getIdxValue(tensor::getMixedSize(rewriter, padOp.getLoc(),
966 | padOp.getSource(), dim));
967 | // Add low and high padding value. |
968 | auto plusLow = rewriter.createOrFold<arith::AddIOp>( |
969 | padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim])); |
970 | auto plusHigh = rewriter.createOrFold<arith::AddIOp>( |
971 | padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim])); |
972 | dynSizes.push_back(plusHigh);
973 | }
974 | staticSizes.push_back(resultType.getDimSize(dim));
975 | } |
976 | |
977 | // Init tensor and fill it with padding. |
978 | Value emptyTensor = rewriter.create<tensor::EmptyOp>( |
979 | padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes); |
980 | Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes); |
981 | |
982 | // Generate a InsertSliceOp for copying the PadOp source. |
983 | auto sourceType = padOp.getSourceType(); |
984 | // Compute size of source of tensor::PadOp. |
985 | SmallVector<OpFoldResult> srcSizes = |
986 | tensor::getMixedSizes(rewriter, padOp.getLoc(), padOp.getSource());
987 | // Strides of InsertSliceOp are all 1. |
988 | SmallVector<OpFoldResult> strides(sourceType.getRank(), |
989 | rewriter.getIndexAttr(1)); |
990 | rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>( |
991 | padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes, |
992 | strides); |
993 | |
994 | return success(); |
995 | } |
996 | |
997 | LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite( |
998 | tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const { |
999 | if (!sliceOp.hasUnitStride()) |
1000 | return failure(); |
1001 | |
1002 | auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>(); |
1003 | if (!padOp) |
1004 | return failure(); |
1005 | |
1006 | bool zeroSliceGuard = true; |
1007 | if (controlFn) { |
1008 | if (std::optional<bool> control = controlFn(sliceOp)) |
1009 | zeroSliceGuard = *control; |
1010 | else |
1011 | return failure(); |
1012 | } |
1013 | |
1014 | FailureOr<TilingResult> tilingResult = |
1015 | tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
1016 | sliceOp.getMixedSizes(), zeroSliceGuard);
1017 | if (failed(tilingResult))
1018 | return failure(); |
1019 | |
1020 | RankedTensorType sourceType = sliceOp.getSourceType(); |
1021 | RankedTensorType resultType = sliceOp.getResultType(); |
1022 | |
1023 | // If the extract_slice is not rank-reduced, all shapes are static and the |
1024 | // data source is actually used. Rewrite into pad(extract_slice(x)). |
1025 | if (sourceType.getRank() == resultType.getRank()) { |
1026 | rewriter.replaceOp(sliceOp, tilingResult->tiledValues); |
1027 | return success(); |
1028 | } |
1029 | |
1030 | // Handle rank-reduced slice by creating another extract_slice op. |
1031 | Value rankReduced = tensor::createCanonicalRankReducingExtractSliceOp( |
1032 | rewriter, sliceOp.getLoc(), tilingResult->tiledValues[0], resultType);
1033 | |
1034 | rewriter.replaceOp(sliceOp, rankReduced); |
1035 | return success(); |
1036 | } |
1037 | |
1038 | /// If padding value is set, returns a tensor.pad Op for the source tensor, |
1039 | /// with the output shape matching the output of `packOp`. Otherwise, returns |
1040 | /// the source directly. |
1041 | /// |
1042 | /// This method assumes that all outer dims for this pack Op are 1. |
1043 | static Value getPackOpSourceOrPaddedSource(OpBuilder &builder, |
1044 | linalg::PackOp packOp) { |
1045 | Value input = packOp.getSource(); |
1046 | if (!packOp.getPaddingValue()) { |
1047 | return input; |
1048 | } |
1049 | |
1050 | assert(llvm::all_of(packOp.getAllOuterDims(), |
1051 | [](int64_t val) { return val == 1; }) && |
1052 | "some outer dims are != 1"); |
1053 | |
1054 | Location loc = packOp.getLoc(); |
1055 | ShapedType inputType = packOp.getSourceType(); |
1056 | int64_t inputRank = inputType.getRank(); |
1057 | |
1058 | DenseMap<int64_t, OpFoldResult> tileAndPosMapping = |
1059 | packOp.getDimAndTileMapping(); |
1060 | |
1061 | // The sizes of dynamic tiles |
1062 | SmallVector<Value> dynamicTileSizes; |
1063 | |
1064 | // Collect dims for the padded shape. |
1065 | SmallVector<int64_t> paddedShape; |
1066 | for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) { |
1067 | // 1. Non-tiled outer dims. |
1068 | // These dims should be 1 and we simply preserve them. |
1069 | if (!tileAndPosMapping.count(dimIdx)) {
1070 | int64_t inputDimSize = inputType.getDimSize(dimIdx); |
1071 | assert(inputDimSize == 1 && |
1072 | "with all outer dims == 1, this non-tiled input dim should be 1!"); |
1073 | paddedShape.push_back(inputDimSize);
1074 | continue; |
1075 | } |
1076 | |
1077 | // 2. Tiled outer dims |
1078 | // As all outer dims == 1, it is safe to use the tile size for the padded |
1079 | // shape. |
1080 | OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
1081 | |
1082 | // 2.1 Static tile sizes |
1083 | std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
1084 | if (cstTileSize.has_value()) { |
1085 | paddedShape.push_back(cstTileSize.value());
1086 | continue; |
1087 | } |
1088 | |
1089 | // 2.2 Dynamic tile sizes |
1090 | paddedShape.push_back(ShapedType::kDynamic); |
1091 | |
1092 | // Get the value that holds the dynamic size. |
1093 | dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
1094 | } |
1095 | auto resultType = |
1096 | RankedTensorType::get(paddedShape, inputType.getElementType()); |
1097 | return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
1098 | /*nofold=*/false, loc, builder,
1099 | dynamicTileSizes);
1100 | } |
1101 | |
1102 | // Normalizes a permutation on a higher rank space to its actual size, e.g. |
1103 | // perm = [1, 4, 2] |
1104 | // becomes |
1105 | // norm = [0, 2, 1] |
1106 | static SmallVector<int64_t> |
1107 | getPackUnpackNormalizedPerm(int rank, ArrayRef<int64_t> perm) { |
1108 | constexpr int64_t kNonTiledMarker = -1; |
1109 | SmallVector<int64_t> vec(rank, kNonTiledMarker); |
1110 | for (auto [index, value] : llvm::enumerate(perm))
1111 | vec[value] = index;
1112 | SmallVector<int64_t> normalizedPerm = llvm::filter_to_vector(
1113 | vec, [&](int64_t v) { return v != kNonTiledMarker; });
1114 | // This inverts the permutation in addition to normalizing so invert back.
1115 | return invertPermutationVector(normalizedPerm);
1116 | } |
1117 | |
1118 | // Gets the normalized permutation implied by innerDimsPos and outerDimsPerm |
1119 | // assuming rank reduction of unit outer dims. |
1120 | static SmallVector<int64_t> |
1121 | getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape, |
1122 | ArrayRef<int64_t> innerDimsPos, |
1123 | ArrayRef<int64_t> outerDimsPerm) { |
1124 | SmallVector<int64_t> rankReducedOuterDimsPerm; |
1125 | SmallVector<int64_t> outerDims; |
1126 | SmallVector<int64_t> innerDims; |
1127 | int64_t dim = 0; |
1128 | int64_t unpackedRank = shape.size(); |
1129 | for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
1130 | if (llvm::is_contained(innerDimsPos, i)) {
1131 | innerDims.push_back(dim++);
1132 | continue;
1133 | }
1134 | if (shape[i] == 1)
1135 | continue;
1136 | outerDims.push_back(dim++);
1137 | if (!outerDimsPerm.empty())
1138 | rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
1139 | } |
1140 | |
1141 | // Get the position of the inner dims after permutation. |
1142 | SmallVector<int64_t> innerPerm = |
1143 | getPackUnpackNormalizedPerm(unpackedRank, innerDimsPos);
1144 | applyPermutationToVector<int64_t>(innerDims, innerPerm);
1145 | |
1146 | // Ditto for the outer dims. |
1147 | SmallVector<int64_t> perm = outerDims; |
1148 | |
1149 | rankReducedOuterDimsPerm = |
1150 | getPackUnpackNormalizedPerm(unpackedRank, rankReducedOuterDimsPerm);
1151 | if (!rankReducedOuterDimsPerm.empty())
1152 | applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);
1153 | |
1154 | // The tile always ends up as the inner most dims after packing. |
1155 | perm.append(innerDims);
1156 | |
1157 | return perm; |
1158 | } |
1159 | |
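// Rough shape of the rewrite (an illustrative sketch assuming all outer dims
// are 1; values and types are made up, not taken from a test):
//
//   %pack = linalg.pack %src inner_dims_pos = [1, 0] inner_tiles = [5, 3]
//       into %dest : tensor<3x5xf32> -> tensor<1x1x5x3xf32>
//
// becomes, approximately:
//
//   %empty = tensor.empty() : tensor<5x3xf32>
//   %t = linalg.transpose ins(%src) outs(%empty) permutation = [1, 0]
//   %res = tensor.insert_slice %t into %dest[0, 0, 0, 0] [1, 1, 5, 3]
//       [1, 1, 1, 1]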
1160 | LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( |
1161 | linalg::PackOp packOp, PatternRewriter &rewriter) const { |
1162 | // TODO: support the case that outer dimensions are not all 1s. A |
1163 | // tensor.expand_shape will be generated in this case. |
1164 | if (llvm::any_of(packOp.getAllOuterDims(), |
1165 | [](int64_t dim) { return dim != 1; })) { |
1166 | return rewriter.notifyMatchFailure( |
1167 | packOp, "not all outer dimensions of the result are 1s"); |
1168 | } |
1169 | |
1170 | Attribute zeroIdxAttr = rewriter.getIndexAttr(0); |
1171 | Attribute oneIdxAttr = rewriter.getIndexAttr(1); |
1172 | Location loc = packOp.getLoc(); |
1173 | |
1174 | Value input = getPackOpSourceOrPaddedSource(rewriter, packOp); |
1175 | DenseMap<int64_t, OpFoldResult> dimAndTileMapping = |
1176 | packOp.getDimAndTileMapping(); |
1177 | int64_t srcRank = packOp.getSourceRank(); |
1178 | int64_t destRank = packOp.getDestRank(); |
1179 | int64_t numTiles = destRank - srcRank; |
1180 | |
1181 | if (!llvm::all_of(packOp.getInnerDimsPos(), |
1182 | [&srcRank, &numTiles](int64_t dimPos) { |
1183 | return dimPos >= (srcRank - numTiles - 1); |
1184 | })) |
1185 | return rewriter.notifyMatchFailure( |
1186 | packOp, "Attempting to tile non-trailing source dims!"); |
1187 | |
1188 | // 1. Extract the inner tile sizes. |
1189 | // Where possible, values are replaced with constant attributes (to match the |
1190 | // behaviour of `getPackOpSourceOrPaddedSource`). |
1191 | SmallVector<OpFoldResult> tileSizes; |
1192 | for (auto i : llvm::seq<unsigned>(0, srcRank)) { |
1193 | if (dimAndTileMapping.count(i)) { |
1194 | // Rather than taking the tile size as is, extract the actual constant
1195 | // value Attribute where possible, e.g.: |
1196 | // [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8] |
1197 | auto [_, tileSize] = |
1198 | getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter); |
1199 | tileSizes.push_back(tileSize); |
1200 | } |
1201 | } |
1202 | |
1203 | // 2. Transpose the input to match the inner tile order: |
1204 | // %init = tensor.empty() |
1205 | // %transposed_tile = linalg.transpose ins(%source_or_padded_source), |
1206 | // outs(%init) |
1207 | // Two assumptions are made: |
1208 | // 1. All outer dims are 1 - the corresponding transposition doesn't matter. |
1209 | // 2. The inner dim positions correspond to the trailing `numTiles` dims.
1210 | SmallVector<int64_t> tilesPermNormalized = |
1211 | getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos()); |
1212 | SmallVector<int64_t> srcPermForTranspose; |
1213 | for (int64_t i = 0; i < (srcRank - numTiles); i++) |
1214 | srcPermForTranspose.push_back(i);
1215 | |
1216 | srcPermForTranspose.append(SmallVector<int64_t>(packOp.getInnerDimsPos()));
1217 | |
1218 | LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n"
1219 | << "perm: " << llvm::interleaved(srcPermForTranspose)
1220 | << "\n"); |
1221 | |
1222 | // 2.1 Create tensor.empty (init value for TransposeOp) |
1223 | SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles, |
1224 | oneIdxAttr); |
1225 | transShapeForEmptyOp.append(tileSizes);
1226 | |
1227 | applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
1228 | srcPermForTranspose);
1229 | Value empty = rewriter.create<tensor::EmptyOp>( |
1230 | loc, transShapeForEmptyOp, packOp.getSourceType().getElementType()); |
1231 | |
1232 | // 2.2 Create linalg.transpose |
1233 | auto transposedOp = rewriter.create<linalg::TransposeOp>(loc, input, empty, |
1234 | srcPermForTranspose); |
1235 | |
1236 | // 3. Insert the inner tile to the destination: |
1237 | // %inserted_tile = tensor.insert_slice(%transposed_tile) |
1238 | SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr); |
1239 | SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr); |
1240 | // Outer dims are all 1s! |
1241 | SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(), |
1242 | oneIdxAttr); |
1243 | SmallVector<int64_t> writeShape; |
1244 | |
1245 | for (auto tileSize : packOp.getMixedTiles()) { |
1246 | auto [tileSizeStatic, tileSizeOfr] = |
1247 | getSimplifiedOfrAndStaticSizePair(tileSize, rewriter); |
1248 | writeSizes.push_back(tileSizeOfr); |
1249 | writeShape.push_back(tileSizeStatic); |
1250 | } |
1251 | |
1252 | // 4. Replace the linalg.pack op with the tensor.insert_slice created above.
1253 | auto insert = rewriter.create<tensor::InsertSliceOp>( |
1254 | loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets, |
1255 | writeSizes, writeStrides); |
1256 | rewriter.replaceOp(packOp, insert.getResult()); |
1257 | |
1258 | return success(); |
1259 | } |
1260 | |
1261 | LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( |
1262 | linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const { |
1263 | int64_t srcRank = unpackOp.getSourceRank(); |
1264 | int64_t destRank = unpackOp.getDestRank(); |
1265 | ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape(); |
1266 | ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos(); |
1267 | if (llvm::any_of(unpackOp.getTiledOuterDims(), |
1268 | [](int64_t dim) { return dim != 1; })) { |
1269 | return rewriter.notifyMatchFailure( |
1270 | unpackOp, |
1271 | "require all tiled outer dimensions of the result to be 1s");
1272 | } |
1273 | |
1274 | // 1. Use rank-reduced tensor.extract_slice op to extract the tile: |
1275 | // %extracted_tile = tensor.extract_slice(%unpack_op_input) |
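  // For illustration only (hypothetical shapes, mirroring the pack example in
  // the pattern above): an unpack whose tiled outer dims are all 1, e.g.
  //   %unpacked = linalg.unpack %src inner_dims_pos = [1, 0]
  //       inner_tiles = [16, 8] into %dest
  //       : tensor<1x1x16x8xf32> -> tensor<8x16xf32>
  // decomposes into a rank-reduced tensor.extract_slice (yielding the
  // tensor<16x8xf32> tile), a linalg.transpose back to tensor<8x16xf32>, a
  // tensor.extract_slice that would drop any padded data (a no-op here), and
  // a tensor.insert_slice into %dest.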
1276 | Location loc = unpackOp.getLoc(); |
1277 | Value source = unpackOp.getSource(); |
1278 | DenseMap<int64_t, OpFoldResult> dimAndTileMapping = |
1279 | unpackOp.getDimAndTileMapping(); |
1280 | Attribute zeroIdxAttr = rewriter.getIndexAttr(0); |
1281 | Attribute oneIdxAttr = rewriter.getIndexAttr(1); |
1282 | |
1283 | // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of |
1284 | // dims: |
1285 | // [ outer-untiled-dims, outer-tiled-dims, tile-sizes ] |
1286 | SmallVector<int64_t> readShapeForExtractSlice; |
1287 | // The sizes attribute for ExtractSliceOp. Due to rank-reducing (and |
1288 | // outer-tiled-dims being all 1), this will be |
1289 | // [ outer-untiled-dims, tile-sizes ] |
1290 | SmallVector<OpFoldResult> extractSliceSizes; |
1291 | // The offset and strides attributes for ExtractSliceOp. |
1292 | SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr); |
1293 | SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr); |
1294 | |
1295 | // Shape for EmptyOp that's used as the init value for TransposeOp below. |
1296 | // This should be: |
1297 | // [ outer-untiled-dims, tile-sizes ] |
1298 | // However, skip unit dims - TransposeOp (below) applies rank-reduced |
1299 | // permutation. |
1300 | SmallVector<OpFoldResult> shapeForEmptyOp; |
1301 | |
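// Continuing the illustrative example above (tensor<1x1x16x8xf32> unpacked
// into tensor<8x16xf32>), the loop below would leave extractSliceSizes as
// [1, 1] and the other two vectors empty; appending the tile sizes afterwards
// yields extractSliceSizes = [1, 1, 16, 8], readShapeForExtractSlice = [16, 8]
// and shapeForEmptyOp = [16, 8].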
1302 | for (auto i : llvm::seq<unsigned>(0, destRank)) { |
1303 | // Compute sizes attribute for ExtractSliceOp - outer-tiled-dims. |
1304 | // |
1305 | // As all outer tiled dims are 1, the corresponding
1306 | // slice size to read will also be 1. As this will be a rank-reducing "extract
1307 | // slice" (i.e. the unit dims will be "collapsed"), there's no need to
1308 | // update: |
1309 | // * the output shape for ExtractSliceOp, nor |
1310 | // * the shape for EmptyOp. |
1311 | if (dimAndTileMapping.count(i)) { |
1312 | extractSliceSizes.push_back(oneIdxAttr); |
1313 | continue; |
1314 | } |
1315 | |
1316 | // Compute sizes attribute for ExtractSliceOp + EmptyOp - |
1317 | // outer-untiled-dims |
1318 | if (ShapedType::isDynamic(srcShape[i])) { |
1319 | OpFoldResult dynamicDim = |
1320 | rewriter.create<tensor::DimOp>(loc, source, i).getResult(); |
1321 | extractSliceSizes.push_back(dynamicDim); |
1322 | shapeForEmptyOp.push_back(dynamicDim); |
1323 | } else { |
1324 | extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i])); |
1325 | if (srcShape[i] != 1) |
1326 | shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i])); |
1327 | } |
1328 | // Compute the output shape for ExtractSliceOp - outer-untiled-dims
1329 | // (taking rank reduction into account).
1330 | if (srcShape[i] != 1) { |
1331 | readShapeForExtractSlice.push_back(srcShape[i]); |
1332 | } |
1333 | } |
1334 | // Append the tile sizes to "sizes attribute" for ExtractSliceOp and the |
1335 | // shape for EmptyOp. |
1336 | auto mixedTiles = unpackOp.getMixedTiles(); |
1337 | extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end()); |
1338 | shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end()); |
1339 | |
1340 | // Explicitly create the type for extract_slice op because the inner tile |
1341 | // size could be 1. We want to represent the whole inner tile in this case. |
1342 | auto tileShape = srcShape.drop_front(destRank);
1343 | // Append the inner tile shape to the permuted and rank-reduced outer shape. |
1344 | readShapeForExtractSlice.append(tileShape.begin(), tileShape.end()); |
1345 | Type elemType = unpackOp.getSourceType().getElementType(); |
1346 | auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType); |
1347 | Value innerTile = rewriter.create<tensor::ExtractSliceOp>( |
1348 | loc, readType, unpackOp.getSource(), extractSliceOffsets, |
1349 | extractSliceSizes, extractSliceStrides); |
1350 | |
1351 | // 2. Transpose the tile to match the corresponding outer tile order.
1352 | SmallVector<int64_t> perm = getPackUnpackRankReducedPerm( |
1353 | srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
1354 | // Unpack is a transition out of packed space so we invert the permutation.
1355 | perm = invertPermutationVector(perm);
1356 | applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);
1357 | |
1358 | Value empty = |
1359 | rewriter.create<tensor::EmptyOp>(loc, shapeForEmptyOp, elemType); |
1360 | auto transposedOp = |
1361 | rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm); |
1362 | |
1363 | // 3. Handle incomplete tiles if needed by truncating trailing data from the
1364 | // transposed tile.
1365 | int numLoops = shapeForEmptyOp.size(); |
1366 | SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr); |
1367 | SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr); |
1368 | SmallVector<OpFoldResult> tileSizes; |
1369 | ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape(); |
1370 | for (auto i : llvm::seq<unsigned>(0, destRank)) { |
1371 | if (dimAndTileMapping.count(i) || destShape[i] != 1) |
1372 | tileSizes.push_back( |
1373 | tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i)); |
1374 | } |
1375 | |
1376 | auto partialTile = rewriter.create<tensor::ExtractSliceOp>( |
1377 | loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides); |
1378 | |
1379 | // 4. Insert the result to the destination tensor. |
1380 | SmallVector<OpFoldResult> writeSizes; |
1381 | SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr); |
1382 | SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr); |
1383 | for (int i = 0, idx = 0; i < destRank; ++i) { |
1384 | if (dimAndTileMapping.count(i) || destShape[i] != 1)
1385 | writeSizes.push_back(tileSizes[idx++]);
1386 | else
1387 | writeSizes.push_back(oneIdxAttr);
1388 | } |
1389 | auto insert = rewriter.create<tensor::InsertSliceOp>( |
1390 | loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes, |
1391 | writeStrides); |
1392 | rewriter.replaceOp(unpackOp, insert.getResult()); |
1393 | |
1394 | return success(); |
1395 | } |
1396 | |
1397 | // The following are patterns for downscaling convolution ops with size-1 |
1398 | // window dimensions. |
1399 | // |
1400 | // Note that we'd eventually want to write such transformations in a generic |
1401 | // way, e.g., converting to linalg.generic, removing the size-1 dimensions, |
1402 | // and then turning back to named ops. But for now it's fine to have a few |
1403 | // patterns matching special ops to get started. |
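//
// For illustration only (hypothetical shapes; strides, dilations and other
// attributes elided): a convolution whose kernel and output are both of size
// 1 along H, e.g.
//   linalg.conv_2d_nhwc_hwcf
//     ins(%in, %ker : tensor<1x1x10x3xf32>, tensor<1x3x3x8xf32>)
//     outs(%out : tensor<1x1x8x8xf32>)
// is rewritten into a linalg.conv_1d_nwc_wcf on rank-reduced operands
// (tensor<1x10x3xf32>, tensor<3x3x8xf32> and tensor<1x8x8xf32>), with the
// 1-D result re-inserted into the original output tensor.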
1404 | |
1405 | template <typename Conv2DOp, typename Conv1DOp> |
1406 | FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>:: |
1407 | returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const { |
1408 | if (convOp.hasPureBufferSemantics()) |
1409 | return failure(); // To be implemented. |
1410 | |
1411 | Value input = convOp.getInputs().front(); |
1412 | Value kernel = convOp.getInputs().back(); |
1413 | Value output = convOp.getOutputs().front(); |
1414 | |
1415 | auto inputType = dyn_cast<RankedTensorType>(input.getType()); |
1416 | auto kernelType = dyn_cast<RankedTensorType>(kernel.getType()); |
1417 | auto outputType = dyn_cast<RankedTensorType>(output.getType()); |
1418 | |
1419 | auto kernelShape = kernelType.getShape(); |
1420 | auto outputShape = outputType.getShape(); |
1421 | |
1422 | // Get domain indices based on conv2D layout. |
1423 | auto [khIndex, kwIndex, ohIndex, owIndex] = |
1424 | TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>( |
1425 | convOp) |
1426 | .Case([&](linalg::Conv2DNhwcHwcfOp op) {
1427 | return std::make_tuple(0, 1, 1, 2);
1428 | })
1429 | .Case([&](linalg::Conv2DNchwFchwOp op) {
1430 | return std::make_tuple(2, 3, 2, 3);
1431 | })
1432 | .Case([&](linalg::PoolingNhwcSumOp op) {
1433 | return std::make_tuple(0, 1, 1, 2);
1434 | })
1435 | .Case([&](linalg::PoolingNchwSumOp op) {
1436 | return std::make_tuple(0, 1, 2, 3);
1437 | })
1438 | .Case([&](linalg::PoolingNhwcMaxOp op) {
1439 | return std::make_tuple(0, 1, 1, 2);
1440 | })
1441 | .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
1442 | return std::make_tuple(0, 1, 1, 2);
1443 | })
1444 | .Case([&](linalg::PoolingNhwcMinOp op) {
1445 | return std::make_tuple(0, 1, 1, 2);
1446 | })
1447 | .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
1448 | return std::make_tuple(0, 1, 1, 2);
1449 | })
1450 | .Case([&](linalg::PoolingNchwMaxOp op) {
1451 | return std::make_tuple(0, 1, 2, 3);
1452 | })
1453 | .Default([&](Operation *op) {
1454 | llvm_unreachable("unexpected conv2d/pool2d operation.");
1455 | return std::make_tuple(0, 0, 0, 0);
1456 | });
1457 | |
1458 | // Only handle the case where at least one of the window dimensions is |
1459 | // of size 1. Other cases can rely on tiling to reduce to such cases. |
1460 | int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex]; |
1461 | int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex]; |
1462 | bool removeH = (khSize == 1 && ohSize == 1); |
1463 | bool removeW = (kwSize == 1 && owSize == 1); |
1464 | if (!removeH && !removeW) |
1465 | return failure(); |
1466 | |
1467 | // Get new shapes and types for all operands by removing the size-1 |
1468 | // dimension. |
1469 | using RTTBuilder = RankedTensorType::Builder; |
1470 | RankedTensorType newInputType = |
1471 | RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex)); |
1472 | RankedTensorType newKernelType = |
1473 | RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex)); |
1474 | RankedTensorType newOutputType = |
1475 | RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex)); |
1476 | |
1477 | // Rank-reduce operands. |
1478 | Location loc = convOp.getLoc(); |
1479 | Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1480 | rewriter, loc, input, newInputType);
1481 | Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1482 | rewriter, loc, kernel, newKernelType);
1483 | Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1484 | rewriter, loc, output, newOutputType);
1485 | |
1486 | // Rank-reduce strides and dilations too. |
1487 | // TODO: dropDim 1-liner helper. |
1488 | auto strides = |
1489 | llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>()); |
1490 | strides.erase(strides.begin() + (removeH ? 0 : 1)); |
1491 | auto stridesAttr = rewriter.getI64VectorAttr(strides);
1492 | |
1493 | auto dilations = |
1494 | llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>()); |
1495 | dilations.erase(dilations.begin() + (removeH ? 0 : 1)); |
1496 | auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
1497 | |
1498 | auto conv1DOp = rewriter.create<Conv1DOp>( |
1499 | loc, newOutputType, ValueRange{newInput, newKernel}, |
1500 | ValueRange{newOutput}, stridesAttr, dilationsAttr); |
1501 | |
1502 | // Insert back. |
1503 | Value inserted = tensor::createCanonicalRankReducingInsertSliceOp( |
1504 | rewriter, loc, conv1DOp.getResult(0), output);
1505 | rewriter.replaceOp(convOp, inserted); |
1506 | |
1507 | return conv1DOp; |
1508 | } |
1509 | |
1510 | template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp, |
1511 | Conv1DNwcWcfOp>; |
1512 | template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp, |
1513 | Conv1DNcwFcwOp>; |
1514 | template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, |
1515 | PoolingNwcSumOp>; |
1516 | template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, |
1517 | PoolingNcwSumOp>; |
1518 | template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, |
1519 | PoolingNwcMaxOp>; |
1520 | template struct linalg::DownscaleSizeOneWindowed2DConvolution< |
1521 | PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>; |
1522 | template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, |
1523 | PoolingNwcMinOp>; |
1524 | template struct linalg::DownscaleSizeOneWindowed2DConvolution< |
1525 | PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>; |
1526 | template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, |
1527 | PoolingNcwMaxOp>; |
1528 | |
1529 | FailureOr<DepthwiseConv1DNwcWcOp> |
1530 | DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite( |
1531 | DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const { |
1532 | if (convOp.hasPureBufferSemantics()) |
1533 | return failure(); // To be implemented. |
1534 | |
1535 | Value input = convOp.getInputs().front(); |
1536 | Value kernel = convOp.getInputs().back(); |
1537 | Value output = convOp.getOutputs().front(); |
1538 | |
1539 | auto inputType = dyn_cast<RankedTensorType>(input.getType()); |
1540 | auto kernelType = dyn_cast<RankedTensorType>(kernel.getType()); |
1541 | auto outputType = dyn_cast<RankedTensorType>(output.getType()); |
1542 | |
1543 | auto kernelShape = kernelType.getShape(); |
1544 | auto outputShape = outputType.getShape(); |
1545 | |
1546 | // Only handle the case where at least one of the window dimensions is |
1547 | // of size 1. Other cases can rely on tiling to reduce to such cases. |
1548 | int64_t khSize = kernelShape[0], kwSize = kernelShape[1]; |
1549 | int64_t ohSize = outputShape[1], owSize = outputShape[2]; |
1550 | bool removeH = (khSize == 1 && ohSize == 1); |
1551 | bool removeW = (kwSize == 1 && owSize == 1); |
1552 | if (!removeH && !removeW) |
1553 | return failure(); |
1554 | |
1555 | // Get new shapes and types for all operands by removing the size-1 |
1556 | // dimension. |
1557 | using RTTBuilder = RankedTensorType::Builder; |
1558 | RankedTensorType newInputType = |
1559 | RTTBuilder(inputType).dropDim((removeH ? 1 : 2)); |
1560 | RankedTensorType newKernelType = |
1561 | RTTBuilder(kernelType).dropDim((removeH ? 0 : 1)); |
1562 | RankedTensorType newOutputType = |
1563 | RTTBuilder(outputType).dropDim(removeH ? 1 : 2); |
1564 | |
1565 | // Rank-reduce operands. |
1566 | Location loc = convOp.getLoc(); |
1567 | Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1568 | rewriter, loc, input, newInputType);
1569 | Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1570 | rewriter, loc, kernel, newKernelType);
1571 | Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1572 | rewriter, loc, output, newOutputType);
1573 | |
1574 | // Rank-reduce strides and dilations too. |
1575 | // TODO: dropDim 1-liner helper. |
1576 | auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>()); |
1577 | strides.erase(strides.begin() + (removeH ? 0 : 1)); |
1578 | auto stridesAttr = rewriter.getI64VectorAttr(strides);
1579 | |
1580 | auto dilations = |
1581 | llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>()); |
1582 | dilations.erase(dilations.begin() + (removeH ? 0 : 1)); |
1583 | auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
1584 | |
1585 | auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>( |
1586 | loc, newOutputType, ValueRange{newInput, newKernel}, |
1587 | ValueRange{newOutput}, stridesAttr, dilationsAttr); |
1588 | |
1589 | // Insert back. |
1590 | Value inserted = tensor::createCanonicalRankReducingInsertSliceOp( |
1591 | rewriter, loc, conv1DOp.getResult(0), output);
1592 | rewriter.replaceOp(convOp, inserted); |
1593 | |
1594 | return conv1DOp; |
1595 | } |
1596 | |
1597 | FailureOr<Conv1DOp> |
1598 | DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp, |
1599 | PatternRewriter &rewriter) const { |
1600 | if (convOp.hasPureBufferSemantics()) |
1601 | return failure(); // To be implemented. |
1602 | |
1603 | Value input = convOp.getInputs().front(); |
1604 | Value kernel = convOp.getInputs().back(); |
1605 | Value output = convOp.getOutputs().front(); |
1606 | |
1607 | auto inputType = dyn_cast<RankedTensorType>(input.getType()); |
1608 | auto kernelType = dyn_cast<RankedTensorType>(kernel.getType()); |
1609 | auto outputType = dyn_cast<RankedTensorType>(output.getType()); |
1610 | |
1611 | auto kernelShape = kernelType.getShape(); |
1612 | auto outputShape = outputType.getShape(); |
1613 | |
1614 | // Only handle the case where at least one of the window dimensions is |
1615 | // of size 1. Other cases can rely on tiling to reduce to such cases. |
1616 | int64_t khSize = kernelShape[0], kwSize = kernelShape[1]; |
1617 | int64_t ohSize = outputShape[0], owSize = outputShape[1]; |
1618 | bool removeH = (khSize == 1 && ohSize == 1); |
1619 | bool removeW = (kwSize == 1 && owSize == 1); |
1620 | if (!removeH && !removeW) |
1621 | return failure(); |
1622 | |
1623 | // Get new shapes and types for all operands by removing the size-1 |
1624 | // dimension. |
1625 | using RTTBuilder = RankedTensorType::Builder; |
1626 | RankedTensorType newInputType = |
1627 | RTTBuilder(inputType).dropDim((removeH ? 0 : 1)); |
1628 | RankedTensorType newKernelType = |
1629 | RTTBuilder(kernelType).dropDim((removeH ? 0 : 1)); |
1630 | RankedTensorType newOutputType = |
1631 | RTTBuilder(outputType).dropDim(removeH ? 0 : 1); |
1632 | |
1633 | // Rank-reduce operands. |
1634 | Location loc = convOp.getLoc(); |
1635 | Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1636 | rewriter, loc, input, newInputType);
1637 | Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1638 | rewriter, loc, kernel, newKernelType);
1639 | Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1640 | rewriter, loc, output, newOutputType);
1641 | |
1642 | auto conv1DOp = rewriter.create<Conv1DOp>(loc, newOutputType, |
1643 | ValueRange{newInput, newKernel}, |
1644 | ValueRange{newOutput}); |
1645 | |
1646 | // Insert back. |
1647 | Value inserted = tensor::createCanonicalRankReducingInsertSliceOp( |
1648 | rewriter, loc, conv1DOp.getResult(0), output);
1649 | rewriter.replaceOp(convOp, inserted); |
1650 | |
1651 | return conv1DOp; |
1652 | } |
1653 | |
1654 | void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, |
1655 | PatternBenefit benefit) { |
1656 | patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp, |
1657 | Conv1DNwcWcfOp>, |
1658 | DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp, |
1659 | Conv1DNcwFcwOp>, |
1660 | DownscaleDepthwiseConv2DNhwcHwcOp, DownscaleConv2DOp>( |
1661 | patterns.getContext(), benefit); |
1662 | patterns.add< |
1663 | DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>, |
1664 | DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>, |
1665 | DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, PoolingNwcMaxOp>, |
1666 | DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp, |
1667 | PoolingNwcMaxUnsignedOp>, |
1668 | DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, PoolingNwcMinOp>, |
1669 | DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp, |
1670 | PoolingNwcMinUnsignedOp>, |
1671 | DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>( |
1672 | patterns.getContext(), benefit); |
1673 | } |
1674 | |
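// Note: illustrative usage sketch only; `ctx` and `op` below are assumptions
// about the caller, not part of this file. A pass would typically feed these
// entry points to the greedy rewrite driver, e.g.:
//   RewritePatternSet patterns(ctx);
//   linalg::populateDecomposePackUnpackPatterns(patterns);
//   (void)applyPatternsGreedily(op, std::move(patterns));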
1675 | void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) { |
1676 | patterns.add<DecomposeOuterUnitDimsPackOpPattern>(patterns.getContext());
1677 | patterns.add<DecomposeOuterUnitDimsUnPackOpPattern>(patterns.getContext());
1678 | } |
1679 | |
1680 | void linalg::populateDecomposePadPatterns(RewritePatternSet &patterns) { |
1681 | patterns.add<DecomposePadOpPattern>(patterns.getContext());
1682 | } |
1683 |