//===- TestTensorTransforms.cpp - Test Tensor transformation patterns -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic for testing Tensor transformations.
//
//===----------------------------------------------------------------------===//
| 12 | |
| 13 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 14 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 15 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 16 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 17 | #include "mlir/Dialect/Tensor/Transforms/TransformUtils.h" |
| 18 | #include "mlir/Dialect/Tensor/Transforms/Transforms.h" |
| 19 | #include "mlir/Dialect/Transform/IR/TransformOps.h" |
| 20 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
| 21 | #include "mlir/Pass/Pass.h" |
| 22 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 23 | |
| 24 | using namespace mlir; |
| 25 | |
| 26 | namespace { |
| 27 | struct TestTensorTransforms |
| 28 | : public PassWrapper<TestTensorTransforms, OperationPass<>> { |
| 29 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTensorTransforms) |
| 30 | |
| 31 | TestTensorTransforms() = default; |
| 32 | TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {} |
| 33 | |
| 34 | void getDependentDialects(DialectRegistry ®istry) const override { |
| 35 | registry.insert<arith::ArithDialect, scf::SCFDialect, linalg::LinalgDialect, |
| 36 | transform::TransformDialect>(); |
| 37 | } |
| 38 | |
| 39 | StringRef getArgument() const final { |
| 40 | return "test-tensor-transform-patterns" ; |
| 41 | } |
| 42 | StringRef getDescription() const final { |
| 43 | return "Test Tensor transformation patterns by applying them greedily." ; |
| 44 | } |
| 45 | |
| 46 | void runOnOperation() override; |
| 47 | |
| 48 | Option<bool> { |
| 49 | *this, "test-fold-constant-extract-slice" , |
| 50 | llvm::cl::desc("Test folding arith.constant and tensor.extract_slice" ), |
| 51 | llvm::cl::init(Val: false)}; |
| 52 | |
| 53 | Option<bool> { |
| 54 | *this, "test-fold-consecutive-insert-extract-slice" , |
| 55 | llvm::cl::desc( |
| 56 | "Test folding consecutive tensor.insert_slice/tensor.extract_slice" ), |
| 57 | llvm::cl::init(Val: false)}; |
| 58 | |
| 59 | Option<bool> { |
| 60 | *this, "test-rewrite-extract-slice-from-collapse-shape" , |
| 61 | llvm::cl::desc("Test swapping tensor.extract_slice of a collapse_shape " |
| 62 | "with loop nest" ), |
| 63 | llvm::cl::init(Val: false)}; |
| 64 | |
| 65 | Option<bool> testDropRedundantInsertSliceRankExpansion{ |
| 66 | *this, "test-drop-redundant-insert-slice-rank-expansion" , |
| 67 | llvm::cl::desc("Test dropping redundant insert_slice rank expansions" ), |
| 68 | llvm::cl::init(Val: false)}; |
| 69 | |
| 70 | Option<bool> testReassociativeReshapeFolding{ |
| 71 | *this, "test-reassociative-reshape-folding" , |
| 72 | llvm::cl::desc("Test folding of expand_shape/collapse_shape" ), |
| 73 | llvm::cl::init(Val: false)}; |
| 74 | |
| 75 | Option<bool> testBubbleUpExpandShapePatterns{ |
| 76 | *this, "test-expand-shape-bubbling" , |
| 77 | llvm::cl::desc("Test folding of expand_shape/collapse_shape" ), |
| 78 | llvm::cl::init(Val: false)}; |
| 79 | |
| 80 | Option<bool> { |
| 81 | *this, "test-fold-extract-from-collapse-shape" , |
| 82 | llvm::cl::desc("Test folding of extract from collapse_shape" ), |
| 83 | llvm::cl::init(Val: false)}; |
| 84 | |
| 85 | Option<bool> useForeach{ |
| 86 | *this, "use-foreach" , |
| 87 | llvm::cl::desc( |
| 88 | "Use the scf.forall operation when generating loop nests for " |
| 89 | "the extract_slice of collapse_shape pattern" ), |
| 90 | llvm::cl::init(Val: false)}; |
| 91 | |
| 92 | Option<bool> testTrackingListener{ |
| 93 | *this, "test-tracking-listener" , |
| 94 | llvm::cl::desc("Test tensor TrackingListener for the transform dialect" ), |
| 95 | llvm::cl::init(Val: false)}; |
| 96 | }; |
| 97 | } // namespace |
| 98 | |
| 99 | static void applyReassociativeReshapeFoldingPatterns(Operation *rootOp) { |
| 100 | RewritePatternSet patterns(rootOp->getContext()); |
| 101 | tensor::populateReassociativeReshapeFoldingPatterns(patterns); |
| 102 | (void)applyPatternsGreedily(rootOp, std::move(patterns)); |
| 103 | } |
| 104 | |
| 105 | static void applyBubbleUpExpandShapePatterns(Operation *rootOp) { |
| 106 | RewritePatternSet patterns(rootOp->getContext()); |
| 107 | tensor::populateBubbleUpExpandShapePatterns(patterns); |
| 108 | (void)applyPatternsGreedily(rootOp, std::move(patterns)); |
| 109 | } |
| 110 | |
| 111 | static void (Operation *rootOp) { |
| 112 | RewritePatternSet patterns(rootOp->getContext()); |
| 113 | tensor::ControlConstantExtractSliceFusionFn controlFn = |
| 114 | [](tensor::ExtractSliceOp op) { |
| 115 | if (!op.getSource().hasOneUse()) |
| 116 | return false; |
| 117 | |
| 118 | auto resultType = cast<ShapedType>(op.getResult().getType()); |
| 119 | constexpr int64_t kConstantFoldingMaxNumElements = 1024; |
| 120 | return resultType.getNumElements() <= kConstantFoldingMaxNumElements; |
| 121 | }; |
| 122 | |
| 123 | tensor::populateFoldConstantExtractSlicePatterns(patterns, controlFn); |
| 124 | (void)applyPatternsGreedily(rootOp, std::move(patterns)); |
| 125 | } |
| 126 | |
| 127 | static void (Operation *rootOp) { |
| 128 | RewritePatternSet patterns(rootOp->getContext()); |
| 129 | tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); |
| 130 | (void)applyPatternsGreedily(rootOp, std::move(patterns)); |
| 131 | } |
| 132 | |
| 133 | static void |
| 134 | applyDropRedundantInsertSliceRankExpansionPatterns(Operation *rootOp) { |
| 135 | RewritePatternSet patterns(rootOp->getContext()); |
| 136 | tensor::populateDropRedundantInsertSliceRankExpansionPatterns(patterns); |
| 137 | (void)applyPatternsGreedily(rootOp, std::move(patterns)); |
| 138 | } |
| 139 | |
| 140 | static void (Operation *rootOp) { |
| 141 | RewritePatternSet patterns(rootOp->getContext()); |
| 142 | tensor::populateFoldCollapseExtractPatterns(patterns); |
| 143 | (void)applyPatternsGreedily(rootOp, std::move(patterns)); |
| 144 | } |
| 145 | |
| 146 | namespace { |
| 147 | /// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`. |
| 148 | /// The `tensor.extract_slice` is replaced by a loop or gather operation that |
| 149 | /// stitches together the desired tile from slices of the source of the collapse |
| 150 | /// shape op. |
| 151 | struct |
| 152 | : public OpRewritePattern<tensor::ExtractSliceOp> { |
| 153 | (MLIRContext *context) |
| 154 | : mlir::OpRewritePattern<tensor::ExtractSliceOp>(context) {} |
| 155 | |
| 156 | /// Emit a loop or gather operation that uses `helper` to take each point in |
| 157 | /// the parallel iteration space bounds, extract a slice from the source |
| 158 | /// tensor and insert it into `dest`. For examples, see below for `scf.for` |
| 159 | /// and `scf.foreach`. |
| 160 | virtual LogicalResult |
| 161 | (tensor::ExtractSliceOp op, Value dest, |
| 162 | tensor::ExtractSliceFromCollapseHelper &helper, |
| 163 | PatternRewriter &rewriter) const = 0; |
| 164 | |
| 165 | LogicalResult matchAndRewrite(tensor::ExtractSliceOp op, |
| 166 | PatternRewriter &rewriter) const override { |
| 167 | auto collapseOp = op.getSource().getDefiningOp<tensor::CollapseShapeOp>(); |
| 168 | if (!collapseOp) |
| 169 | return rewriter.notifyMatchFailure( |
| 170 | op, "producer is not a tensor.collapse_shape op" ); |
| 171 | |
| 172 | // Try to simplify the collapse shape using a rank-reducing slice, if |
| 173 | // possible. |
| 174 | FailureOr<Operation *> simplifiedCollapseShapeResult = |
| 175 | tensor::simplifyCollapseShapeWithRankReducingExtractSlice(collapseOp, |
| 176 | rewriter); |
| 177 | if (succeeded(Result: simplifiedCollapseShapeResult)) { |
| 178 | auto newCollapseOp = |
| 179 | dyn_cast<tensor::CollapseShapeOp>(*simplifiedCollapseShapeResult); |
| 180 | // The collapse shape op might have been simplified away, so we can just |
| 181 | // return. |
| 182 | if (!newCollapseOp) |
| 183 | return success(); |
| 184 | collapseOp = newCollapseOp; |
| 185 | } |
| 186 | |
| 187 | // Materialize the output shape values of the slice operation. |
| 188 | ReifiedRankedShapedTypeDims reifiedShapes; |
| 189 | if (failed(reifyResultShapes(rewriter, op, reifiedShapes))) |
| 190 | return rewriter.notifyMatchFailure(op, "failed to reify result shapes" ); |
| 191 | |
| 192 | // Create the destination tensor using the above values. |
| 193 | Type elementType = op.getSourceType().getElementType(); |
| 194 | SmallVector<OpFoldResult> outputShape = reifiedShapes[0]; |
| 195 | Value dest = rewriter.create<tensor::EmptyOp>(op->getLoc(), outputShape, |
| 196 | elementType); |
| 197 | |
| 198 | // Calculate the parameters for the tile loop nest. |
| 199 | FailureOr<tensor::ExtractSliceFromCollapseHelper> params = |
| 200 | tensor::ExtractSliceFromCollapseHelper::create(rewriter, collapseOp, |
| 201 | op); |
| 202 | if (failed(params)) |
| 203 | return rewriter.notifyMatchFailure( |
| 204 | op, "could not calculate tiling parameters" ); |
| 205 | return emitReplacement(op, dest, *params, rewriter); |
| 206 | } |
| 207 | }; |
| 208 | |
| 209 | struct |
| 210 | : public RewriteExtractSliceFromCollapseShapeBase { |
| 211 | (MLIRContext *context) |
| 212 | : RewriteExtractSliceFromCollapseShapeBase(context) {} |
| 213 | LogicalResult (tensor::ExtractSliceOp op, Value dest, |
| 214 | tensor::ExtractSliceFromCollapseHelper &helper, |
| 215 | PatternRewriter &rewriter) const override { |
| 216 | Location loc = op.getLoc(); |
| 217 | const unsigned numTiledDims = helper.getIterationSpaceSizes().size(); |
| 218 | auto zero = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0); |
| 219 | auto one = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1); |
| 220 | SmallVector<Value> lbs(numTiledDims, zero); |
| 221 | SmallVector<Value> steps(numTiledDims, one); |
| 222 | |
| 223 | scf::LoopNest nest = scf::buildLoopNest( |
| 224 | builder&: rewriter, loc, lbs, ubs: helper.getIterationSpaceSizes(), steps, iterArgs: dest, |
| 225 | bodyBuilder: [&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs, |
| 226 | ValueRange iterArgs) -> scf::ValueVector { |
| 227 | auto [tile, insertParams] = |
| 228 | helper.emitLoopNestBody(builder&: nestedBuilder, loc, tileInductionVars: outputIvs); |
| 229 | |
| 230 | // Insert the slice into the destination. |
| 231 | return {nestedBuilder.create<tensor::InsertSliceOp>( |
| 232 | loc, tile, iterArgs[0], insertParams)}; |
| 233 | }); |
| 234 | rewriter.replaceOp(op, nest.results); |
| 235 | |
| 236 | return success(); |
| 237 | } |
| 238 | }; |
| 239 | |
| 240 | struct |
| 241 | : public RewriteExtractSliceFromCollapseShapeBase { |
| 242 | (MLIRContext *context) |
| 243 | : RewriteExtractSliceFromCollapseShapeBase(context) {} |
| 244 | LogicalResult (tensor::ExtractSliceOp op, Value dest, |
| 245 | tensor::ExtractSliceFromCollapseHelper &helper, |
| 246 | PatternRewriter &rewriter) const override { |
| 247 | Location loc = op.getLoc(); |
| 248 | auto forallOp = rewriter.create<scf::ForallOp>( |
| 249 | loc, /*numThreads=*/getAsOpFoldResult(helper.getIterationSpaceSizes()), |
| 250 | /*outputs=*/dest, |
| 251 | /*mapping=*/std::nullopt, |
| 252 | [&](OpBuilder &nestedBuilder, Location loc, ValueRange regionArgs) { |
| 253 | unsigned numThreadIdRegionArgs = |
| 254 | helper.getIterationSpaceSizes().size(); |
| 255 | unsigned numOutputRegionArgs = |
| 256 | regionArgs.size() - numThreadIdRegionArgs; |
| 257 | ValueRange outputIvs = regionArgs.take_front(numThreadIdRegionArgs); |
| 258 | ValueRange outputArgs = regionArgs.take_back(numOutputRegionArgs); |
| 259 | assert(outputArgs.size() == 1 && |
| 260 | "there should only be one output region argument" ); |
| 261 | auto [tile, insertParams] = |
| 262 | helper.emitLoopNestBody(nestedBuilder, loc, outputIvs); |
| 263 | // Insert the slice into the destination. |
| 264 | auto term = nestedBuilder.create<scf::InParallelOp>(loc); |
| 265 | nestedBuilder.setInsertionPointToStart(term.getBody()); |
| 266 | nestedBuilder.create<tensor::ParallelInsertSliceOp>( |
| 267 | loc, tile, outputArgs[0], insertParams); |
| 268 | }); |
| 269 | rewriter.replaceOp(op, forallOp->getResult(0)); |
| 270 | return success(); |
| 271 | } |
| 272 | }; |
| 273 | } // namespace |
| 274 | |
| 275 | static LogicalResult |
| 276 | (Operation *rootOp, |
| 277 | bool useForeach) { |
| 278 | RewritePatternSet patterns(rootOp->getContext()); |
| 279 | if (useForeach) |
| 280 | patterns.add<RewriteExtractSliceFromCollapseShapeUsingScfForeach>( |
| 281 | arg: rootOp->getContext()); |
| 282 | else |
| 283 | patterns.add<RewriteExtractSliceFromCollapseShapeUsingScfFor>( |
| 284 | arg: rootOp->getContext()); |
| 285 | return applyPatternsGreedily(rootOp, std::move(patterns)); |
| 286 | } |
| 287 | |
| 288 | namespace { |
| 289 | class DummyTrackingListener : public transform::TrackingListener { |
| 290 | public: |
| 291 | using transform::TrackingListener::TrackingListener; |
| 292 | |
| 293 | // Expose `findReplacementOp` as a public function, so that it can be tested. |
| 294 | Operation *getReplacementOp(Operation *op, ValueRange newValues) const { |
| 295 | Operation *replacementOp; |
| 296 | if (!findReplacementOp(replacementOp, op, newValues).succeeded()) |
| 297 | return nullptr; |
| 298 | return replacementOp; |
| 299 | } |
| 300 | }; |
| 301 | } // namespace |
| 302 | |
| 303 | static LogicalResult testTrackingListenerReplacements(Operation *rootOp) { |
| 304 | // Find replaced op. |
| 305 | Operation *replaced = nullptr; |
| 306 | WalkResult status = rootOp->walk(callback: [&](Operation *op) { |
| 307 | if (op->hasAttr(name: "replaced" )) { |
| 308 | if (replaced) { |
| 309 | op->emitError(message: "only one 'replaced' op is allowed per test case" ); |
| 310 | replaced->emitRemark(message: "other 'replaced' op" ); |
| 311 | return WalkResult::interrupt(); |
| 312 | } |
| 313 | replaced = op; |
| 314 | } |
| 315 | return WalkResult::advance(); |
| 316 | }); |
| 317 | if (status.wasInterrupted()) |
| 318 | return failure(); |
| 319 | if (!replaced) { |
| 320 | rootOp->emitError(message: "could not find 'replaced' op" ); |
| 321 | return failure(); |
| 322 | } |
| 323 | |
| 324 | // Find replacements. |
| 325 | SmallVector<Value> replacements(replaced->getNumResults(), Value()); |
| 326 | status = rootOp->walk(callback: [&](Operation *op) { |
| 327 | for (int64_t i = 0; i < replaced->getNumResults(); ++i) { |
| 328 | if (auto attr = op->getAttrOfType<IntegerAttr>("replacement_" + |
| 329 | std::to_string(i))) { |
| 330 | if (replacements[i]) { |
| 331 | op->emitError(message: "only one 'replacement_" + std::to_string(val: i) + |
| 332 | "' is allowed per test case" ); |
| 333 | replacements[i].getDefiningOp()->emitRemark(message: "other 'replacement_" + |
| 334 | std::to_string(val: i) + "'" ); |
| 335 | return WalkResult::interrupt(); |
| 336 | } |
| 337 | replacements[i] = op->getResult(idx: attr.getInt()); |
| 338 | } |
| 339 | } |
| 340 | return WalkResult::advance(); |
| 341 | }); |
| 342 | if (status.wasInterrupted()) |
| 343 | return failure(); |
| 344 | |
| 345 | if (!llvm::all_of(Range&: replacements, |
| 346 | P: [](Value v) { return static_cast<bool>(v); })) { |
| 347 | replaced->emitError(message: "insufficient replacement values" ); |
| 348 | return failure(); |
| 349 | } |
| 350 | |
| 351 | // Find the replacement op (if any) and emit a remark/error. |
| 352 | transform::TransformState transformState = |
| 353 | transform::detail::makeTransformStateForTesting(/*region=*/nullptr, |
| 354 | /*payloadRoot=*/nullptr); |
| 355 | MLIRContext *context = rootOp->getContext(); |
| 356 | OpBuilder builder(context); |
| 357 | OwningOpRef<transform::NamedSequenceOp> transformOp = |
| 358 | builder.create<transform::NamedSequenceOp>( |
| 359 | rootOp->getLoc(), |
| 360 | /*sym_name=*/"test_sequence" , |
| 361 | /*function_type=*/ |
| 362 | TypeAttr::get(FunctionType::get(context, TypeRange{}, TypeRange{})), |
| 363 | /*sym_visibility*/ StringAttr::get(context, "public" ), |
| 364 | /*arg_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()), |
| 365 | /*res_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>())); |
| 366 | DummyTrackingListener listener(transformState, transformOp.get()); |
| 367 | Operation *replacement = listener.getReplacementOp(op: replaced, newValues: replacements); |
| 368 | if (!replacement) { |
| 369 | replaced->emitError(message: "listener could not find replacement op" ); |
| 370 | return failure(); |
| 371 | } |
| 372 | |
| 373 | replacement->emitRemark(message: "replacement found" ); |
| 374 | return success(); |
| 375 | } |
| 376 | |
| 377 | void TestTensorTransforms::runOnOperation() { |
| 378 | Operation *rootOp = getOperation(); |
| 379 | if (testFoldConstantExtractSlice) |
| 380 | applyFoldConstantExtractSlicePatterns(rootOp); |
| 381 | if (testFoldConsecutiveInsertExtractSlice) |
| 382 | applyFoldConsecutiveInsertExtractSlicePatterns(rootOp); |
| 383 | if (testDropRedundantInsertSliceRankExpansion) |
| 384 | applyDropRedundantInsertSliceRankExpansionPatterns(rootOp); |
| 385 | if (testReassociativeReshapeFolding) |
| 386 | applyReassociativeReshapeFoldingPatterns(rootOp); |
| 387 | if (testBubbleUpExpandShapePatterns) |
| 388 | applyBubbleUpExpandShapePatterns(rootOp); |
| 389 | if (testRewriteExtractSliceWithTiledCollapseShape) { |
| 390 | if (failed( |
| 391 | Result: applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach))) |
| 392 | return signalPassFailure(); |
| 393 | } |
| 394 | if (testFoldExtractFromCollapseShape) |
| 395 | applyFoldExtractFromCollapseShapePatterns(rootOp); |
| 396 | if (testTrackingListener) |
| 397 | if (failed(Result: testTrackingListenerReplacements(rootOp))) |
| 398 | return signalPassFailure(); |
| 399 | } |
| 400 | |
| 401 | namespace mlir { |
| 402 | namespace test { |
| 403 | void registerTestTensorTransforms() { |
| 404 | PassRegistration<TestTensorTransforms>(); |
| 405 | } |
| 406 | } // namespace test |
| 407 | } // namespace mlir |
| 408 | |