1 | //===- TestTensorTransforms.cpp - Test Tensor transformation patterns -----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements logic for testing Tensor transformations. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Arith/IR/Arith.h" |
14 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
15 | #include "mlir/Dialect/SCF/IR/SCF.h" |
16 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
17 | #include "mlir/Dialect/Tensor/Transforms/TransformUtils.h" |
18 | #include "mlir/Dialect/Tensor/Transforms/Transforms.h" |
19 | #include "mlir/Dialect/Transform/IR/TransformOps.h" |
20 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
21 | #include "mlir/Pass/Pass.h" |
22 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
23 | |
24 | using namespace mlir; |
25 | |
26 | namespace { |
27 | struct TestTensorTransforms |
28 | : public PassWrapper<TestTensorTransforms, OperationPass<>> { |
29 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTensorTransforms) |
30 | |
31 | TestTensorTransforms() = default; |
32 | TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {} |
33 | |
34 | void getDependentDialects(DialectRegistry ®istry) const override { |
35 | registry.insert<arith::ArithDialect, scf::SCFDialect, linalg::LinalgDialect, |
36 | transform::TransformDialect>(); |
37 | } |
38 | |
39 | StringRef getArgument() const final { |
40 | return "test-tensor-transform-patterns" ; |
41 | } |
42 | StringRef getDescription() const final { |
43 | return "Test Tensor transformation patterns by applying them greedily." ; |
44 | } |
45 | |
46 | void runOnOperation() override; |
47 | |
48 | Option<bool> { |
49 | *this, "test-fold-constant-extract-slice" , |
50 | llvm::cl::desc("Test folding arith.constant and tensor.extract_slice" ), |
51 | llvm::cl::init(Val: false)}; |
52 | |
53 | Option<bool> { |
54 | *this, "test-fold-consecutive-insert-extract-slice" , |
55 | llvm::cl::desc( |
56 | "Test folding consecutive tensor.insert_slice/tensor.extract_slice" ), |
57 | llvm::cl::init(Val: false)}; |
58 | |
59 | Option<bool> { |
60 | *this, "test-rewrite-extract-slice-from-collapse-shape" , |
61 | llvm::cl::desc("Test swapping tensor.extract_slice of a collapse_shape " |
62 | "with loop nest" ), |
63 | llvm::cl::init(Val: false)}; |
64 | |
65 | Option<bool> testDropRedundantInsertSliceRankExpansion{ |
66 | *this, "test-drop-redundant-insert-slice-rank-expansion" , |
67 | llvm::cl::desc("Test dropping redundant insert_slice rank expansions" ), |
68 | llvm::cl::init(Val: false)}; |
69 | |
70 | Option<bool> testReassociativeReshapeFolding{ |
71 | *this, "test-reassociative-reshape-folding" , |
72 | llvm::cl::desc("Test folding of expand_shape/collapse_shape" ), |
73 | llvm::cl::init(Val: false)}; |
74 | |
75 | Option<bool> testFoldIntoPackAndUnpack{ |
76 | *this, "test-fold-into-pack-and-unpack" , |
77 | llvm::cl::desc("Test folding ops into tensor.pack and tensor.unpack" ), |
78 | llvm::cl::init(Val: false)}; |
79 | |
80 | Option<bool> useForeach{ |
81 | *this, "use-foreach" , |
82 | llvm::cl::desc( |
83 | "Use the scf.forall operation when generating loop nests for " |
84 | "the extract_slice of collapse_shape pattern" ), |
85 | llvm::cl::init(Val: false)}; |
86 | |
87 | Option<bool> testSimplifyPackUnpackPatterns{ |
88 | *this, "test-simplify-pack-unpack-patterns" , |
89 | llvm::cl::desc("Test patterns to simplify tensor.pack and tensor.unpack" ), |
90 | llvm::cl::init(Val: false)}; |
91 | |
92 | Option<bool> testTrackingListener{ |
93 | *this, "test-tracking-listener" , |
94 | llvm::cl::desc("Test tensor TrackingListener for the transform dialect" ), |
95 | llvm::cl::init(Val: false)}; |
96 | }; |
97 | } // namespace |
98 | |
99 | static void applyReassociativeReshapeFoldingPatterns(Operation *rootOp) { |
100 | RewritePatternSet patterns(rootOp->getContext()); |
101 | tensor::populateReassociativeReshapeFoldingPatterns(patterns); |
102 | (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); |
103 | } |
104 | |
105 | static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) { |
106 | RewritePatternSet patterns(rootOp->getContext()); |
107 | tensor::populateFoldIntoPackAndUnpackPatterns(patterns); |
108 | (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); |
109 | } |
110 | |
111 | static void (Operation *rootOp) { |
112 | RewritePatternSet patterns(rootOp->getContext()); |
113 | tensor::ControlConstantExtractSliceFusionFn controlFn = |
114 | [](tensor::ExtractSliceOp op) { |
115 | if (!op.getSource().hasOneUse()) |
116 | return false; |
117 | |
118 | auto resultType = cast<ShapedType>(op.getResult().getType()); |
119 | constexpr int64_t kConstantFoldingMaxNumElements = 1024; |
120 | return resultType.getNumElements() <= kConstantFoldingMaxNumElements; |
121 | }; |
122 | |
123 | tensor::populateFoldConstantExtractSlicePatterns(patterns, controlFn); |
124 | (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); |
125 | } |
126 | |
127 | static void (Operation *rootOp) { |
128 | RewritePatternSet patterns(rootOp->getContext()); |
129 | tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); |
130 | (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); |
131 | } |
132 | |
133 | static void |
134 | applyDropRedundantInsertSliceRankExpansionPatterns(Operation *rootOp) { |
135 | RewritePatternSet patterns(rootOp->getContext()); |
136 | tensor::populateDropRedundantInsertSliceRankExpansionPatterns(patterns); |
137 | (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); |
138 | } |
139 | |
140 | static void applySimplifyPackUnpackPatterns(Operation *rootOp) { |
141 | RewritePatternSet patterns(rootOp->getContext()); |
142 | tensor::populateSimplifyPackAndUnpackPatterns(patterns); |
143 | (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); |
144 | } |
145 | |
146 | namespace { |
147 | /// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`. |
148 | /// The `tensor.extract_slice` is replaced by a loop or gather operation that |
149 | /// stitches together the desired tile from slices of the source of the collapse |
150 | /// shape op. |
151 | struct |
152 | : public OpRewritePattern<tensor::ExtractSliceOp> { |
153 | (MLIRContext *context) |
154 | : mlir::OpRewritePattern<tensor::ExtractSliceOp>(context) {} |
155 | |
156 | /// Emit a loop or gather operation that uses `helper` to take each point in |
157 | /// the parallel iteration space bounds, extract a slice from the source |
158 | /// tensor and insert it into `dest`. For examples, see below for `scf.for` |
159 | /// and `scf.foreach`. |
160 | virtual LogicalResult |
161 | (tensor::ExtractSliceOp op, Value dest, |
162 | tensor::ExtractSliceFromCollapseHelper &helper, |
163 | PatternRewriter &rewriter) const = 0; |
164 | |
165 | LogicalResult matchAndRewrite(tensor::ExtractSliceOp op, |
166 | PatternRewriter &rewriter) const override { |
167 | auto collapseOp = op.getSource().getDefiningOp<tensor::CollapseShapeOp>(); |
168 | if (!collapseOp) |
169 | return rewriter.notifyMatchFailure( |
170 | op, "producer is not a tensor.collapse_shape op" ); |
171 | |
172 | // Try to simplify the collapse shape using a rank-reducing slice, if |
173 | // possible. |
174 | FailureOr<Operation *> simplifiedCollapseShapeResult = |
175 | tensor::simplifyCollapseShapeWithRankReducingExtractSlice(collapseOp, |
176 | rewriter); |
177 | if (succeeded(result: simplifiedCollapseShapeResult)) { |
178 | auto newCollapseOp = |
179 | dyn_cast<tensor::CollapseShapeOp>(*simplifiedCollapseShapeResult); |
180 | // The collapse shape op might have been simplified away, so we can just |
181 | // return. |
182 | if (!newCollapseOp) |
183 | return success(); |
184 | collapseOp = newCollapseOp; |
185 | } |
186 | |
187 | // Materialize the output shape values of the slice operation. |
188 | ReifiedRankedShapedTypeDims reifiedShapes; |
189 | if (failed(reifyResultShapes(rewriter, op, reifiedShapes))) |
190 | return rewriter.notifyMatchFailure(op, "failed to reify result shapes" ); |
191 | |
192 | // Create the destination tensor using the above values. |
193 | Type elementType = op.getSourceType().getElementType(); |
194 | SmallVector<OpFoldResult> outputShape = reifiedShapes[0]; |
195 | Value dest = rewriter.create<tensor::EmptyOp>(op->getLoc(), outputShape, |
196 | elementType); |
197 | |
198 | // Calculate the parameters for the tile loop nest. |
199 | FailureOr<tensor::ExtractSliceFromCollapseHelper> params = |
200 | tensor::ExtractSliceFromCollapseHelper::create(rewriter, collapseOp, |
201 | op); |
202 | if (failed(params)) |
203 | return rewriter.notifyMatchFailure( |
204 | op, "could not calculate tiling parameters" ); |
205 | return emitReplacement(op, dest, *params, rewriter); |
206 | } |
207 | }; |
208 | |
209 | struct |
210 | : public RewriteExtractSliceFromCollapseShapeBase { |
211 | (MLIRContext *context) |
212 | : RewriteExtractSliceFromCollapseShapeBase(context) {} |
213 | LogicalResult (tensor::ExtractSliceOp op, Value dest, |
214 | tensor::ExtractSliceFromCollapseHelper &helper, |
215 | PatternRewriter &rewriter) const override { |
216 | Location loc = op.getLoc(); |
217 | const unsigned numTiledDims = helper.getIterationSpaceSizes().size(); |
218 | auto zero = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0); |
219 | auto one = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1); |
220 | SmallVector<Value> lbs(numTiledDims, zero); |
221 | SmallVector<Value> steps(numTiledDims, one); |
222 | |
223 | scf::LoopNest nest = scf::buildLoopNest( |
224 | builder&: rewriter, loc, lbs, ubs: helper.getIterationSpaceSizes(), steps, iterArgs: dest, |
225 | bodyBuilder: [&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs, |
226 | ValueRange iterArgs) -> scf::ValueVector { |
227 | auto [tile, insertParams] = |
228 | helper.emitLoopNestBody(builder&: nestedBuilder, loc, tileInductionVars: outputIvs); |
229 | |
230 | // Insert the slice into the destination. |
231 | return {nestedBuilder.create<tensor::InsertSliceOp>( |
232 | loc, tile, iterArgs[0], insertParams)}; |
233 | }); |
234 | rewriter.replaceOp(op, nest.results); |
235 | |
236 | return success(); |
237 | } |
238 | }; |
239 | |
240 | struct |
241 | : public RewriteExtractSliceFromCollapseShapeBase { |
242 | (MLIRContext *context) |
243 | : RewriteExtractSliceFromCollapseShapeBase(context) {} |
244 | LogicalResult (tensor::ExtractSliceOp op, Value dest, |
245 | tensor::ExtractSliceFromCollapseHelper &helper, |
246 | PatternRewriter &rewriter) const override { |
247 | Location loc = op.getLoc(); |
248 | auto forallOp = rewriter.create<scf::ForallOp>( |
249 | loc, /*numThreads=*/getAsOpFoldResult(helper.getIterationSpaceSizes()), |
250 | /*outputs=*/dest, |
251 | /*mapping=*/std::nullopt, |
252 | [&](OpBuilder &nestedBuilder, Location loc, ValueRange regionArgs) { |
253 | unsigned numThreadIdRegionArgs = |
254 | helper.getIterationSpaceSizes().size(); |
255 | unsigned numOutputRegionArgs = |
256 | regionArgs.size() - numThreadIdRegionArgs; |
257 | ValueRange outputIvs = regionArgs.take_front(numThreadIdRegionArgs); |
258 | ValueRange outputArgs = regionArgs.take_back(numOutputRegionArgs); |
259 | assert(outputArgs.size() == 1 && |
260 | "there should only be one output region argument" ); |
261 | auto [tile, insertParams] = |
262 | helper.emitLoopNestBody(nestedBuilder, loc, outputIvs); |
263 | // Insert the slice into the destination. |
264 | auto term = nestedBuilder.create<scf::InParallelOp>(loc); |
265 | nestedBuilder.setInsertionPointToStart(term.getBody()); |
266 | nestedBuilder.create<tensor::ParallelInsertSliceOp>( |
267 | loc, tile, outputArgs[0], insertParams); |
268 | }); |
269 | rewriter.replaceOp(op, forallOp->getResult(0)); |
270 | return success(); |
271 | } |
272 | }; |
273 | } // namespace |
274 | |
275 | static LogicalResult |
276 | (Operation *rootOp, |
277 | bool useForeach) { |
278 | RewritePatternSet patterns(rootOp->getContext()); |
279 | if (useForeach) |
280 | patterns.add<RewriteExtractSliceFromCollapseShapeUsingScfForeach>( |
281 | arg: rootOp->getContext()); |
282 | else |
283 | patterns.add<RewriteExtractSliceFromCollapseShapeUsingScfFor>( |
284 | arg: rootOp->getContext()); |
285 | return applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); |
286 | } |
287 | |
288 | namespace { |
289 | class DummyTrackingListener : public transform::TrackingListener { |
290 | public: |
291 | using transform::TrackingListener::TrackingListener; |
292 | |
293 | // Expose `findReplacementOp` as a public function, so that it can be tested. |
294 | Operation *getReplacementOp(Operation *op, ValueRange newValues) const { |
295 | Operation *replacementOp; |
296 | if (!findReplacementOp(replacementOp, op, newValues).succeeded()) |
297 | return nullptr; |
298 | return replacementOp; |
299 | } |
300 | }; |
301 | } // namespace |
302 | |
303 | static LogicalResult testTrackingListenerReplacements(Operation *rootOp) { |
304 | // Find replaced op. |
305 | Operation *replaced = nullptr; |
306 | WalkResult status = rootOp->walk(callback: [&](Operation *op) { |
307 | if (op->hasAttr(name: "replaced" )) { |
308 | if (replaced) { |
309 | op->emitError(message: "only one 'replaced' op is allowed per test case" ); |
310 | replaced->emitRemark(message: "other 'replaced' op" ); |
311 | return WalkResult::interrupt(); |
312 | } |
313 | replaced = op; |
314 | } |
315 | return WalkResult::advance(); |
316 | }); |
317 | if (status.wasInterrupted()) |
318 | return failure(); |
319 | if (!replaced) { |
320 | rootOp->emitError(message: "could not find 'replaced' op" ); |
321 | return failure(); |
322 | } |
323 | |
324 | // Find replacements. |
325 | SmallVector<Value> replacements(replaced->getNumResults(), Value()); |
326 | status = rootOp->walk(callback: [&](Operation *op) { |
327 | for (int64_t i = 0; i < replaced->getNumResults(); ++i) { |
328 | if (auto attr = op->getAttrOfType<IntegerAttr>("replacement_" + |
329 | std::to_string(i))) { |
330 | if (replacements[i]) { |
331 | op->emitError(message: "only one 'replacement_" + std::to_string(val: i) + |
332 | "' is allowed per test case" ); |
333 | replacements[i].getDefiningOp()->emitRemark(message: "other 'replacement_" + |
334 | std::to_string(val: i) + "'" ); |
335 | return WalkResult::interrupt(); |
336 | } |
337 | replacements[i] = op->getResult(idx: attr.getInt()); |
338 | } |
339 | } |
340 | return WalkResult::advance(); |
341 | }); |
342 | if (status.wasInterrupted()) |
343 | return failure(); |
344 | |
345 | if (!llvm::all_of(Range&: replacements, |
346 | P: [](Value v) { return static_cast<bool>(v); })) { |
347 | replaced->emitError(message: "insufficient replacement values" ); |
348 | return failure(); |
349 | } |
350 | |
351 | // Find the replacement op (if any) and emit a remark/error. |
352 | transform::TransformState transformState = |
353 | transform::detail::makeTransformStateForTesting(/*region=*/nullptr, |
354 | /*payloadRoot=*/nullptr); |
355 | MLIRContext *context = rootOp->getContext(); |
356 | OpBuilder builder(context); |
357 | OwningOpRef<transform::NamedSequenceOp> transformOp = |
358 | builder.create<transform::NamedSequenceOp>( |
359 | rootOp->getLoc(), |
360 | /*sym_name=*/"test_sequence" , |
361 | /*function_type=*/ |
362 | TypeAttr::get(FunctionType::get(context, TypeRange{}, TypeRange{})), |
363 | /*sym_visibility*/ StringAttr::get(context, "public" ), |
364 | /*arg_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()), |
365 | /*res_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>())); |
366 | DummyTrackingListener listener(transformState, transformOp.get()); |
367 | Operation *replacement = listener.getReplacementOp(op: replaced, newValues: replacements); |
368 | if (!replacement) { |
369 | replaced->emitError(message: "listener could not find replacement op" ); |
370 | return failure(); |
371 | } |
372 | |
373 | replacement->emitRemark(message: "replacement found" ); |
374 | return success(); |
375 | } |
376 | |
377 | void TestTensorTransforms::runOnOperation() { |
378 | Operation *rootOp = getOperation(); |
379 | if (testSimplifyPackUnpackPatterns) |
380 | applySimplifyPackUnpackPatterns(rootOp); |
381 | if (testFoldConstantExtractSlice) |
382 | applyFoldConstantExtractSlicePatterns(rootOp); |
383 | if (testFoldConsecutiveInsertExtractSlice) |
384 | applyFoldConsecutiveInsertExtractSlicePatterns(rootOp); |
385 | if (testDropRedundantInsertSliceRankExpansion) |
386 | applyDropRedundantInsertSliceRankExpansionPatterns(rootOp); |
387 | if (testReassociativeReshapeFolding) |
388 | applyReassociativeReshapeFoldingPatterns(rootOp); |
389 | if (testFoldIntoPackAndUnpack) |
390 | applyFoldIntoPackAndUnpackPatterns(rootOp); |
391 | if (testRewriteExtractSliceWithTiledCollapseShape) { |
392 | if (failed( |
393 | result: applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach))) |
394 | return signalPassFailure(); |
395 | } |
396 | if (testTrackingListener) |
397 | if (failed(result: testTrackingListenerReplacements(rootOp))) |
398 | return signalPassFailure(); |
399 | } |
400 | |
401 | namespace mlir { |
402 | namespace test { |
403 | void registerTestTensorTransforms() { |
404 | PassRegistration<TestTensorTransforms>(); |
405 | } |
406 | } // namespace test |
407 | } // namespace mlir |
408 | |