//===- TestTensorTransforms.cpp - Test Tensor transformation patterns -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic for testing Tensor transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace {
struct TestTensorTransforms
    : public PassWrapper<TestTensorTransforms, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTensorTransforms)

  TestTensorTransforms() = default;
  TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<arith::ArithDialect, scf::SCFDialect, linalg::LinalgDialect,
                    transform::TransformDialect>();
  }

  StringRef getArgument() const final {
    return "test-tensor-transform-patterns";
  }
  StringRef getDescription() const final {
    return "Test Tensor transformation patterns by applying them greedily.";
  }
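
  // Typical invocation from a lit test (illustrative only; the authoritative
  // RUN lines live in the .mlir tests that exercise this pass):
  //   mlir-opt %s -test-tensor-transform-patterns=test-fold-constant-extract-slice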

  void runOnOperation() override;

  Option<bool> testFoldConstantExtractSlice{
      *this, "test-fold-constant-extract-slice",
      llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"),
      llvm::cl::init(false)};

  Option<bool> testFoldConsecutiveInsertExtractSlice{
      *this, "test-fold-consecutive-insert-extract-slice",
      llvm::cl::desc(
          "Test folding consecutive tensor.insert_slice/tensor.extract_slice"),
      llvm::cl::init(false)};

  Option<bool> testRewriteExtractSliceWithTiledCollapseShape{
      *this, "test-rewrite-extract-slice-from-collapse-shape",
      llvm::cl::desc("Test swapping tensor.extract_slice of a collapse_shape "
                     "with loop nest"),
      llvm::cl::init(false)};

  Option<bool> testDropRedundantInsertSliceRankExpansion{
      *this, "test-drop-redundant-insert-slice-rank-expansion",
      llvm::cl::desc("Test dropping redundant insert_slice rank expansions"),
      llvm::cl::init(false)};

  Option<bool> testReassociativeReshapeFolding{
      *this, "test-reassociative-reshape-folding",
      llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
      llvm::cl::init(false)};

  Option<bool> testFoldIntoPackAndUnpack{
      *this, "test-fold-into-pack-and-unpack",
      llvm::cl::desc("Test folding ops into tensor.pack and tensor.unpack"),
      llvm::cl::init(false)};

  Option<bool> useForeach{
      *this, "use-foreach",
      llvm::cl::desc(
          "Use the scf.forall operation when generating loop nests for "
          "the extract_slice of collapse_shape pattern"),
      llvm::cl::init(false)};

  Option<bool> testSimplifyPackUnpackPatterns{
      *this, "test-simplify-pack-unpack-patterns",
      llvm::cl::desc("Test patterns to simplify tensor.pack and tensor.unpack"),
      llvm::cl::init(false)};

  Option<bool> testTrackingListener{
      *this, "test-tracking-listener",
      llvm::cl::desc("Test tensor TrackingListener for the transform dialect"),
      llvm::cl::init(false)};
};
} // namespace

static void applyReassociativeReshapeFoldingPatterns(Operation *rootOp) {
  RewritePatternSet patterns(rootOp->getContext());
  tensor::populateReassociativeReshapeFoldingPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}

static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) {
  RewritePatternSet patterns(rootOp->getContext());
  tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}

static void applyFoldConstantExtractSlicePatterns(Operation *rootOp) {
  RewritePatternSet patterns(rootOp->getContext());
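  // Only fold when the constant feeding the extract_slice has a single use
  // and the extracted slice is small (at most 1024 elements), so folding does
  // not materialize a large new constant.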
  tensor::ControlConstantExtractSliceFusionFn controlFn =
      [](tensor::ExtractSliceOp op) {
        if (!op.getSource().hasOneUse())
          return false;

        auto resultType = cast<ShapedType>(op.getResult().getType());
        constexpr int64_t kConstantFoldingMaxNumElements = 1024;
        return resultType.getNumElements() <= kConstantFoldingMaxNumElements;
      };

  tensor::populateFoldConstantExtractSlicePatterns(patterns, controlFn);
  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}

static void applyFoldConsecutiveInsertExtractSlicePatterns(Operation *rootOp) {
  RewritePatternSet patterns(rootOp->getContext());
  tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}

static void
applyDropRedundantInsertSliceRankExpansionPatterns(Operation *rootOp) {
  RewritePatternSet patterns(rootOp->getContext());
  tensor::populateDropRedundantInsertSliceRankExpansionPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}

static void applySimplifyPackUnpackPatterns(Operation *rootOp) {
  RewritePatternSet patterns(rootOp->getContext());
  tensor::populateSimplifyPackAndUnpackPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}

namespace {
/// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
/// The `tensor.extract_slice` is replaced by a loop or gather operation that
/// stitches together the desired tile from slices of the source of the
/// collapse_shape op.
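///
/// Illustrative input IR (shapes chosen for exposition only; see the pattern
/// tests for the cases actually covered):
///
///   %0 = tensor.collapse_shape %src [[0, 1], [2]]
///       : tensor<4x5x6xf32> into tensor<20x6xf32>
///   %1 = tensor.extract_slice %0[1, 0] [10, 6] [1, 1]
///       : tensor<20x6xf32> to tensor<10x6xf32>
///
/// The tile %1 is rebuilt by iterating over the sliced, collapsed dimensions,
/// extracting the corresponding sub-slices directly from %src, and inserting
/// them into a destination tensor of the result shape.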
struct RewriteExtractSliceFromCollapseShapeBase
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  RewriteExtractSliceFromCollapseShapeBase(MLIRContext *context)
      : mlir::OpRewritePattern<tensor::ExtractSliceOp>(context) {}

  /// Emit a loop or gather operation that uses `helper` to take each point in
  /// the parallel iteration space bounds, extract a slice from the source
  /// tensor and insert it into `dest`. See the `scf.for` and `scf.forall`
  /// implementations below for examples.
  virtual LogicalResult
  emitReplacement(tensor::ExtractSliceOp op, Value dest,
                  tensor::ExtractSliceFromCollapseHelper &helper,
                  PatternRewriter &rewriter) const = 0;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto collapseOp = op.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return rewriter.notifyMatchFailure(
          op, "producer is not a tensor.collapse_shape op");

    // Try to simplify the collapse shape using a rank-reducing slice, if
    // possible.
    FailureOr<Operation *> simplifiedCollapseShapeResult =
        tensor::simplifyCollapseShapeWithRankReducingExtractSlice(collapseOp,
                                                                  rewriter);
    if (succeeded(simplifiedCollapseShapeResult)) {
      auto newCollapseOp =
          dyn_cast<tensor::CollapseShapeOp>(*simplifiedCollapseShapeResult);
      // The collapse shape op might have been simplified away, so we can just
      // return.
      if (!newCollapseOp)
        return success();
      collapseOp = newCollapseOp;
    }

    // Materialize the output shape values of the slice operation.
    ReifiedRankedShapedTypeDims reifiedShapes;
    if (failed(reifyResultShapes(rewriter, op, reifiedShapes)))
      return rewriter.notifyMatchFailure(op, "failed to reify result shapes");

    // Create the destination tensor using the above values.
    Type elementType = op.getSourceType().getElementType();
    SmallVector<OpFoldResult> outputShape = reifiedShapes[0];
    Value dest = rewriter.create<tensor::EmptyOp>(op->getLoc(), outputShape,
                                                  elementType);

    // Calculate the parameters for the tile loop nest.
    FailureOr<tensor::ExtractSliceFromCollapseHelper> params =
        tensor::ExtractSliceFromCollapseHelper::create(rewriter, collapseOp,
                                                       op);
    if (failed(params))
      return rewriter.notifyMatchFailure(
          op, "could not calculate tiling parameters");
    return emitReplacement(op, dest, *params, rewriter);
  }
};
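
/// Emits the replacement as an `scf.for` loop nest that carries the
/// destination tensor as an iteration argument and inserts each tile with
/// `tensor.insert_slice`.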
struct RewriteExtractSliceFromCollapseShapeUsingScfFor
    : public RewriteExtractSliceFromCollapseShapeBase {
  RewriteExtractSliceFromCollapseShapeUsingScfFor(MLIRContext *context)
      : RewriteExtractSliceFromCollapseShapeBase(context) {}
  LogicalResult emitReplacement(tensor::ExtractSliceOp op, Value dest,
                                tensor::ExtractSliceFromCollapseHelper &helper,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const unsigned numTiledDims = helper.getIterationSpaceSizes().size();
    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value> lbs(numTiledDims, zero);
    SmallVector<Value> steps(numTiledDims, one);

    scf::LoopNest nest = scf::buildLoopNest(
        rewriter, loc, lbs, helper.getIterationSpaceSizes(), steps, dest,
        [&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs,
            ValueRange iterArgs) -> scf::ValueVector {
          auto [tile, insertParams] =
              helper.emitLoopNestBody(nestedBuilder, loc, outputIvs);

          // Insert the slice into the destination.
          return {nestedBuilder.create<tensor::InsertSliceOp>(
              loc, tile, iterArgs[0], insertParams)};
        });
    rewriter.replaceOp(op, nest.results);

    return success();
  }
};
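
/// Emits the replacement as an `scf.forall` op: each point of the iteration
/// space extracts its tile and inserts it into the shared destination with
/// `tensor.parallel_insert_slice` inside the `scf.in_parallel` terminator.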
struct RewriteExtractSliceFromCollapseShapeUsingScfForeach
    : public RewriteExtractSliceFromCollapseShapeBase {
  RewriteExtractSliceFromCollapseShapeUsingScfForeach(MLIRContext *context)
      : RewriteExtractSliceFromCollapseShapeBase(context) {}
  LogicalResult emitReplacement(tensor::ExtractSliceOp op, Value dest,
                                tensor::ExtractSliceFromCollapseHelper &helper,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto forallOp = rewriter.create<scf::ForallOp>(
        loc, /*numThreads=*/getAsOpFoldResult(helper.getIterationSpaceSizes()),
        /*outputs=*/dest,
        /*mapping=*/std::nullopt,
        [&](OpBuilder &nestedBuilder, Location loc, ValueRange regionArgs) {
          unsigned numThreadIdRegionArgs =
              helper.getIterationSpaceSizes().size();
          unsigned numOutputRegionArgs =
              regionArgs.size() - numThreadIdRegionArgs;
          ValueRange outputIvs = regionArgs.take_front(numThreadIdRegionArgs);
          ValueRange outputArgs = regionArgs.take_back(numOutputRegionArgs);
          assert(outputArgs.size() == 1 &&
                 "there should only be one output region argument");
          auto [tile, insertParams] =
              helper.emitLoopNestBody(nestedBuilder, loc, outputIvs);
          // Insert the slice into the destination.
          auto term = nestedBuilder.create<scf::InParallelOp>(loc);
          nestedBuilder.setInsertionPointToStart(term.getBody());
          nestedBuilder.create<tensor::ParallelInsertSliceOp>(
              loc, tile, outputArgs[0], insertParams);
        });
    rewriter.replaceOp(op, forallOp->getResult(0));
    return success();
  }
};
} // namespace
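
/// Applies one of the two extract_slice-of-collapse_shape rewrites above,
/// selected by the pass's `use-foreach` option.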
static LogicalResult
applyRewriteExtractFromCollapseShapePatterns(Operation *rootOp,
                                             bool useForeach) {
  RewritePatternSet patterns(rootOp->getContext());
  if (useForeach)
    patterns.add<RewriteExtractSliceFromCollapseShapeUsingScfForeach>(
        rootOp->getContext());
  else
    patterns.add<RewriteExtractSliceFromCollapseShapeUsingScfFor>(
        rootOp->getContext());
  return applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}

namespace {
class DummyTrackingListener : public transform::TrackingListener {
public:
  using transform::TrackingListener::TrackingListener;

  // Expose `findReplacementOp` as a public function, so that it can be tested.
  Operation *getReplacementOp(Operation *op, ValueRange newValues) const {
    Operation *replacementOp;
    if (!findReplacementOp(replacementOp, op, newValues).succeeded())
      return nullptr;
    return replacementOp;
  }
};
} // namespace
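
/// Checks the tensor TrackingListener by walking the payload IR for an op
/// tagged with a "replaced" attribute and ops tagged with "replacement_<i>"
/// integer attributes (the attribute value selects the result index), then
/// asks the listener which op replaces the tagged op. Illustrative test IR
/// (op names are hypothetical):
///
///   %0 = "test.foo"() {replaced} : () -> tensor<8xf32>
///   %1 = "test.bar"() {replacement_0 = 0 : i64} : () -> tensor<8xf32>
///
/// Here result #0 of the second op stands in for result #0 of the first; the
/// listener's verdict is reported via a remark or an error.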
static LogicalResult testTrackingListenerReplacements(Operation *rootOp) {
  // Find replaced op.
  Operation *replaced = nullptr;
  WalkResult status = rootOp->walk([&](Operation *op) {
    if (op->hasAttr("replaced")) {
      if (replaced) {
        op->emitError("only one 'replaced' op is allowed per test case");
        replaced->emitRemark("other 'replaced' op");
        return WalkResult::interrupt();
      }
      replaced = op;
    }
    return WalkResult::advance();
  });
  if (status.wasInterrupted())
    return failure();
  if (!replaced) {
    rootOp->emitError("could not find 'replaced' op");
    return failure();
  }

  // Find replacements.
  SmallVector<Value> replacements(replaced->getNumResults(), Value());
  status = rootOp->walk([&](Operation *op) {
    for (int64_t i = 0; i < replaced->getNumResults(); ++i) {
      if (auto attr = op->getAttrOfType<IntegerAttr>("replacement_" +
                                                     std::to_string(i))) {
        if (replacements[i]) {
          op->emitError("only one 'replacement_" + std::to_string(i) +
                        "' is allowed per test case");
          replacements[i].getDefiningOp()->emitRemark("other 'replacement_" +
                                                      std::to_string(i) + "'");
          return WalkResult::interrupt();
        }
        replacements[i] = op->getResult(attr.getInt());
      }
    }
    return WalkResult::advance();
  });
  if (status.wasInterrupted())
    return failure();

  if (!llvm::all_of(replacements,
                    [](Value v) { return static_cast<bool>(v); })) {
    replaced->emitError("insufficient replacement values");
    return failure();
  }

  // Find the replacement op (if any) and emit a remark/error.
  transform::TransformState transformState =
      transform::detail::makeTransformStateForTesting(/*region=*/nullptr,
                                                      /*payloadRoot=*/nullptr);
  MLIRContext *context = rootOp->getContext();
  OpBuilder builder(context);
  OwningOpRef<transform::NamedSequenceOp> transformOp =
      builder.create<transform::NamedSequenceOp>(
          rootOp->getLoc(),
          /*sym_name=*/"test_sequence",
          /*function_type=*/
          TypeAttr::get(FunctionType::get(context, TypeRange{}, TypeRange{})),
          /*sym_visibility=*/StringAttr::get(context, "public"),
          /*arg_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()),
          /*res_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()));
  DummyTrackingListener listener(transformState, transformOp.get());
  Operation *replacement = listener.getReplacementOp(replaced, replacements);
  if (!replacement) {
    replaced->emitError("listener could not find replacement op");
    return failure();
  }

  replacement->emitRemark("replacement found");
  return success();
}

void TestTensorTransforms::runOnOperation() {
  Operation *rootOp = getOperation();
  if (testSimplifyPackUnpackPatterns)
    applySimplifyPackUnpackPatterns(rootOp);
  if (testFoldConstantExtractSlice)
    applyFoldConstantExtractSlicePatterns(rootOp);
  if (testFoldConsecutiveInsertExtractSlice)
    applyFoldConsecutiveInsertExtractSlicePatterns(rootOp);
  if (testDropRedundantInsertSliceRankExpansion)
    applyDropRedundantInsertSliceRankExpansionPatterns(rootOp);
  if (testReassociativeReshapeFolding)
    applyReassociativeReshapeFoldingPatterns(rootOp);
  if (testFoldIntoPackAndUnpack)
    applyFoldIntoPackAndUnpackPatterns(rootOp);
  if (testRewriteExtractSliceWithTiledCollapseShape) {
    if (failed(
            applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))
      return signalPassFailure();
  }
  if (testTrackingListener)
    if (failed(testTrackingListenerReplacements(rootOp)))
      return signalPassFailure();
}

namespace mlir {
namespace test {
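// Hooked into the test-pass registration of the mlir-opt test driver
// (typically alongside the other Test* pass registrations).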
void registerTestTensorTransforms() {
  PassRegistration<TestTensorTransforms>();
}
} // namespace test
} // namespace mlir
