//===- Detensorize.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <iterator>
#include <memory>
#include <utility>

23namespace mlir {
24#define GEN_PASS_DEF_LINALGDETENSORIZEPASS
25#include "mlir/Dialect/Linalg/Passes.h.inc"
26} // namespace mlir
27
28using namespace mlir;
29using namespace mlir::linalg;
30
31static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
32 ValueRange inputs, Location loc) {
33 assert(inputs.size() == 1);
34 auto inputType = inputs[0].getType();
35 if (isa<TensorType>(inputType))
36 return nullptr;
37
38 // A detensored value is converted back by creating a new tensor from its
39 // element(s).
40 return builder.create<tensor::FromElementsOp>(
41 loc, RankedTensorType::get({}, inputType), inputs[0]);
42}
43
namespace {
45/// Defines the criteria a TensorType must follow in order to be considered
46/// "detensorable".
47///
48/// NOTE: For now, only 0-D tensors are supported.
49///
50/// Returns true if tensorType can be detensored.
51bool canBeDetensored(TensorType tensorType) {
52 return tensorType.hasRank() && tensorType.getRank() == 0;
53}
54
55bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
56 GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
57 return genericOp &&
58 llvm::all_of(genericOp->getOpOperands(), [&](OpOperand &opOperand) {
59 return !typeConverter.isLegal(type: opOperand.get().getType());
60 });
61}
62
63/// A conversion pattern for detensoring `linalg.generic` ops.
64class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
65public:
66 using OpConversionPattern::OpConversionPattern;
67 LogicalResult
68 matchAndRewrite(GenericOp op, OpAdaptor adaptor,
69 ConversionPatternRewriter &rewriter) const override {
70 Block *originalBlock = op->getBlock();
71
72 // Gather some information about the op before inlining its region.
73 Block *opEntryBlock = &*op.getRegion().begin();
74 YieldOp yieldOp = dyn_cast<YieldOp>(op.getRegion().back().getTerminator());
75
76 // Split the op's region before the op. This way, we have a clear insertion
77 // point in which the op can be inlined.
78 Block *newBlock = rewriter.splitBlock(block: originalBlock, before: Block::iterator(op));
79 rewriter.inlineRegionBefore(op.getRegion(), newBlock);
80 // Now that op's region is inlined, the operands of its YieldOp are mapped
81 // to the materialized target values. Therefore, we can replace the op's
82 // uses with those of its YielOp's operands.
83 rewriter.replaceOp(op, yieldOp->getOperands());
84
85 // No need for these intermediate blocks, merge them into 1.
86 rewriter.mergeBlocks(source: opEntryBlock, dest: originalBlock, argValues: adaptor.getOperands());
87 rewriter.mergeBlocks(source: newBlock, dest: originalBlock, argValues: {});
88
89 rewriter.eraseOp(op: &*Block::iterator(yieldOp));
90
91 return success();
92 }
93};
94
95/// A conversion pattern for detensoring internal (non-entry) blocks within a
96/// function.
97struct FunctionNonEntryBlockConversion
98 : public OpInterfaceConversionPattern<FunctionOpInterface> {
99 FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter,
100 DenseSet<BlockArgument> blockArgsToDetensor)
101 : OpInterfaceConversionPattern(converter, ctx),
102 blockArgsToDetensor(std::move(blockArgsToDetensor)) {}
103
104 LogicalResult
105 matchAndRewrite(FunctionOpInterface op, ArrayRef<Value> operands,
106 ConversionPatternRewriter &rewriter) const override {
107 rewriter.startOpModification(op: op);
108 Region &region = op.getFunctionBody();
109 SmallVector<TypeConverter::SignatureConversion, 2> conversions;
110
111 for (Block &block : llvm::drop_begin(region, 1)) {
112 conversions.emplace_back(block.getNumArguments());
113 TypeConverter::SignatureConversion &back = conversions.back();
114
115 for (BlockArgument blockArgument : block.getArguments()) {
116 int idx = blockArgument.getArgNumber();
117
118 if (blockArgsToDetensor.count(blockArgument))
119 back.addInputs(idx, {getTypeConverter()->convertType(
120 block.getArgumentTypes()[idx])});
121 else
122 back.addInputs(idx, {block.getArgumentTypes()[idx]});
123 }
124 }
125
126 if (failed(result: rewriter.convertNonEntryRegionTypes(region: &region, converter: *typeConverter,
127 blockConversions: conversions))) {
128 rewriter.cancelOpModification(op: op);
129 return failure();
130 }
131
132 rewriter.finalizeOpModification(op: op);
133 return success();
134 }
135
136private:
137 const DenseSet<BlockArgument> blockArgsToDetensor;
138};
139
140class DetensorizeTypeConverter : public TypeConverter {
141public:
142 DetensorizeTypeConverter() {
143 addConversion(callback: [](Type type) { return type; });
144
145 // A TensorType that can be detensored, is converted to the underlying
146 // element type.
147 addConversion(callback: [](TensorType tensorType) -> Type {
148 if (canBeDetensored(tensorType))
149 return tensorType.getElementType();
150
151 return tensorType;
152 });
153
154 // A tensor value is detensoried by extracting its element(s).
155 addTargetMaterialization(callback: [](OpBuilder &builder, Type type,
156 ValueRange inputs, Location loc) -> Value {
157 return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
158 });
159
160 addSourceMaterialization(callback&: sourceMaterializationCallback);
161 addArgumentMaterialization(callback&: sourceMaterializationCallback);
162 }
163};
164
165/// @see LinalgDetensorize in Linalg/Passes.td for more details.
166struct LinalgDetensorize
167 : public impl::LinalgDetensorizePassBase<LinalgDetensorize> {
168 using impl::LinalgDetensorizePassBase<
169 LinalgDetensorize>::LinalgDetensorizePassBase;
170 LinalgDetensorize() = default;
171
172 class CostModel {
173 public:
174 virtual ~CostModel() = default;
175
176 /// A cost model algorithm computes the following outputs:
177 ///
178 /// - opsToDetensor: the list of linalg ops that should be
179 /// detensored.
180 ///
181 /// - blockArgsToDetensor: since the operands and results of detensored
182 /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come
183 /// from a BB argument and a linalg op's output can be passed to successor
184 /// BBs), we need to maintain the sub-set of arguments that should be
185 /// detensored (i.e. converted by typeConverter) for each affected BB.
186 ///
187 /// Example:
188 ///
189 /// For the following snippet:
190 /// ...
191 /// ^bb1(%6: tensor<i32>, %9: tensor<i32>):
192 /// %7 = tensor.empty() : tensor<i32>
193 /// %8 = linalg.generic #attrs
194 /// ins(%6, %6 : tensor<i32>, tensor<i32>)
195 /// outs(%7 : tensor<i32>) {
196 /// ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
197 /// %9 = arith.addi %arg0, %arg1 : i32
198 /// linalg.yield %9 : i32
199 /// } -> tensor<i32>
200 /// %10 = "some.op"(%9)
201 /// br ^bb2(%8 : tensor<i32>)
202 /// ...
203 ///
204 /// if the cost model decides that the linalg.generic op should be
205 /// detensored, then:
206 /// - opsToDetensor should be = {linalg.generic{add}}.
207 /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}.
208 virtual void compute(FunctionOpInterface func,
209 DetensorizeTypeConverter typeConverter,
210 DenseSet<Operation *> &opsToDetensor,
211 DenseSet<BlockArgument> &blockArgsToDetensor) = 0;
212
213 /// From the blockArgsToDetensor set computed by a CostModel
214 /// implementation, this method computes the corresponding branch op
215 /// detensoring. The result is a map from a branch op to a subset of indices
216 /// of its operands. The indices specify which of the branch op's operands
217 /// should be detensored.
218 ///
219 /// For the previous example, this method would compute: {bb2 -> {0}}.
220 static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring(
221 const DenseSet<BlockArgument> &blockArgsToDetensor) {
222 DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
223
224 for (auto blockArgumentElem : blockArgsToDetensor) {
225 Block *block = blockArgumentElem.getOwner();
226
227 for (PredecessorIterator pred = block->pred_begin();
228 pred != block->pred_end(); ++pred) {
229 BranchOpInterface terminator =
230 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
231 auto blockOperands =
232 terminator.getSuccessorOperands(pred.getSuccessorIndex());
233
234 if (blockOperands.empty() ||
235 blockOperands.isOperandProduced(blockArgumentElem.getArgNumber()))
236 continue;
237
238 detensorableBranchOps[terminator].insert(
239 blockOperands.getOperandIndex(blockArgumentElem.getArgNumber()));
240 }
241 }
242
243 return detensorableBranchOps;
244 }
245 };
246
247 /// Detensorize linalg ops involved in control-flow within a function.
248 ///
249 /// This model starts from BranchOps and CondBranchOps within a function. For
250 /// each such branch, the model then walks the use-def chain for the branch's
251 /// condition backwards in order to understand where the condition's value
252 /// comes from. If the condition value is (indirectly) computed by a linalg op
253 /// that can be detensored, the model then continues walking the use-def chain
254 /// in order to understand where the linalg op's operands come from. This
255 /// leads to discovering a "detensoring component". A detensoring component is
256 /// the set of operations + block arguments that are involved in control-flow
257 /// AND can be detensored.
258 class ControlFlowDetectionModel : public CostModel {
259 public:
260 void compute(FunctionOpInterface func,
261 DetensorizeTypeConverter typeConverter,
262 DenseSet<Operation *> &opsToDetensor,
263 DenseSet<BlockArgument> &blockArgsToDetensor) override {
264 SmallVector<Value> workList;
265
266 func->walk([&](cf::CondBranchOp condBr) {
267 llvm::append_range(workList, condBr.getOperands());
268 });
269
270 func->walk([&](cf::BranchOp br) {
271 llvm::append_range(workList, br.getOperands());
272 });
273
274 DenseSet<Value> visitedValues;
275 DenseSet<Operation *> visitedOps;
276
277 // For a (to-be-detesored) value, check if it "escapes" the block by being
278 // passed to terminator. If it does, then workList is updated with the
279 // corresponding argument to the successor block.
280 auto updateWorkListWithSuccessorArguments =
281 [&](Value value, BranchOpInterface terminator) {
282 if (!terminator)
283 return;
284
285 for (auto operandIdx :
286 llvm::seq<unsigned>(0, terminator->getOperands().size())) {
287 Value operand = terminator->getOperand(operandIdx);
288
289 if (operand == value) {
290 auto succBlockArg =
291 terminator.getSuccessorBlockArgument(operandIdx);
292
293 if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
294 workList.push_back(*succBlockArg);
295 }
296 }
297 };
298
299 while (!workList.empty()) {
300 Value currentItem = workList.pop_back_val();
301
302 if (!visitedValues.insert(V: currentItem).second)
303 continue;
304
305 // 1 - Look forward:
306 // 1.1 - If currentItem escapes to one or more successors, add
307 // the corresponding successor arguments to workList.
308 updateWorkListWithSuccessorArguments(
309 currentItem, dyn_cast<BranchOpInterface>(
310 currentItem.getParentBlock()->getTerminator()));
311
312 // 1.2 - For each user of currentItem, add the defined values to
313 // workList. This way, the user ops can be inspected later if they are
314 // detensorable and if so, their operands will be added to workList to
315 // potentially discover other parts of the detensorable component.
316 for (auto *user : currentItem.getUsers())
317 llvm::append_range(C&: workList, R: user->getResults());
318
319 // 2 - Look backward:
320 // 2.1 - The current item is defined by a block argument. If the owner
321 // block is a non-entry one, then:
322 // * Add the argument to blockArgsToDetensor.
323 // * Walk the use-def chain backwards to add each predecessor's
324 // terminator-operands corresponding to currentItem to workList.
325 if (dyn_cast<BlockArgument>(Val&: currentItem)) {
326 BlockArgument currentItemBlockArgument =
327 cast<BlockArgument>(Val&: currentItem);
328 Block *ownerBlock = currentItemBlockArgument.getOwner();
329
330 // Function arguments are not detensored/converted.
331 if (&*ownerBlock->getParent()->begin() == ownerBlock)
332 continue;
333
334 // This inner-block argument is involved in control-flow, it should be
335 // detensored.
336 blockArgsToDetensor.insert(V: currentItemBlockArgument);
337
338 for (PredecessorIterator pred = ownerBlock->pred_begin();
339 pred != ownerBlock->pred_end(); ++pred) {
340 BranchOpInterface predTerminator =
341 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
342
343 // TODO: For now, we give up if any of the control-flow components
344 // in a function is not detensorable. Fix that.
345 if (!predTerminator) {
346 opsToDetensor.clear();
347 blockArgsToDetensor.clear();
348 return;
349 }
350
351 auto ownerBlockOperands =
352 predTerminator.getSuccessorOperands(pred.getSuccessorIndex());
353
354 if (ownerBlockOperands.empty() ||
355 ownerBlockOperands.isOperandProduced(
356 currentItemBlockArgument.getArgNumber()))
357 continue;
358
359 // For each predecessor, add the value it passes to that argument to
360 // workList to find out how it's computed.
361 workList.push_back(
362 Elt: ownerBlockOperands[currentItemBlockArgument.getArgNumber()]);
363 }
364
365 continue;
366 }
367
368 Operation *currentItemDefiningOp = currentItem.getDefiningOp();
369
370 if (!visitedOps.insert(V: currentItemDefiningOp).second)
371 continue;
372
373 // 2.2 - The current item is computed by a GenericOp. If the op should
374 // be detensored, then:
375 // * Add it to opsToDetensor.
376 // * Add its operands to workList to discover other parts of the
377 // potentially detensorable component.
378 if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
379 // The op was encountered already, no need to inspect it again.
380 if (opsToDetensor.count(V: genericOp))
381 continue;
382
383 // The op should not be detensored, give up on it but continue with
384 // discovering the rest of the control-flow component.
385 if (!shouldBeDetensored(genericOp, typeConverter)) {
386 continue;
387 }
388
389 opsToDetensor.insert(genericOp);
390 llvm::append_range(workList, genericOp.getInputs());
391 continue;
392 }
393
394 // 2.3 - The current item is the result of a FromElementsOp, it will be
395 // trivially detensored later as part of canonicalization patterns
396 // applied at the end of detensoring.
397 //
398 // Note: No need to check whether the result type of this op is
399 // detensorable since if it wasn't we wouldn't reach that point in the
400 // work list.
401 if (isa<tensor::FromElementsOp>(currentItemDefiningOp))
402 continue;
403
404 // 2.4 - The current item is the result of a scalar op, add all its
405 // operands to the work list.
406 if (llvm::all_of(
407 Range: currentItemDefiningOp->getResultTypes(),
408 P: [&](Type resultType) { return resultType.isIntOrFloat(); }))
409 llvm::append_range(C&: workList, R: currentItemDefiningOp->getOperands());
410 }
411
412 // Since the cost model gives up on some ops (see the details of step 2.2
413 // above), block arguments that correspond to the values produced by those
414 // ops should not be detensored as well.
415
416 DenseSet<BlockArgument> blockArgsToRemove;
417
418 for (auto &blockArg : blockArgsToDetensor) {
419 Block *block = blockArg.getParentBlock();
420
421 // For the potentially detensorable block argument, find the
422 // correpsonding operands in predecessor blocks.
423 for (PredecessorIterator pred = block->pred_begin();
424 pred != block->pred_end(); ++pred) {
425 BranchOpInterface terminator =
426 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
427 auto blockOperands =
428 terminator.getSuccessorOperands(pred.getSuccessorIndex());
429
430 if (blockOperands.empty() ||
431 blockOperands.isOperandProduced(blockArg.getArgNumber()))
432 continue;
433
434 Operation *definingOp =
435 blockOperands[blockArg.getArgNumber()].getDefiningOp();
436
437 // If the operand is defined by a GenericOp that will not be
438 // detensored, then do not detensor the corresponding block argument.
439 if (isa_and_nonnull<GenericOp>(definingOp) &&
440 opsToDetensor.count(definingOp) == 0) {
441 blockArgsToRemove.insert(V: blockArg);
442 break;
443 }
444 }
445 }
446
447 for (auto &blockArg : blockArgsToRemove) {
448 blockArgsToDetensor.erase(V: blockArg);
449 }
450 }
451 };
452
453 /// Detensorize everything that can detensored.
454 class AggressiveDetensoringModel : public CostModel {
455 public:
456 void compute(FunctionOpInterface func,
457 DetensorizeTypeConverter typeConverter,
458 DenseSet<Operation *> &opsToDetensor,
459 DenseSet<BlockArgument> &blockArgsToDetensor) override {
460 func->walk([&](GenericOp genericOp) {
461 if (shouldBeDetensored(genericOp, typeConverter))
462 opsToDetensor.insert(genericOp);
463 });
464
465 for (Block &block : llvm::drop_begin(func.getFunctionBody(), 1))
466 for (BlockArgument blockArgument : block.getArguments())
467 blockArgsToDetensor.insert(blockArgument);
468 }
469 };
470
471 void runOnOperation() override {
472 MLIRContext *context = &getContext();
473 DetensorizeTypeConverter typeConverter;
474 RewritePatternSet patterns(context);
475 ConversionTarget target(*context);
476 DenseSet<Operation *> opsToDetensor;
477 DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
478 DenseSet<BlockArgument> blockArgsToDetensor;
479 FunctionOpInterface funcOp = getOperation();
480
481 if (funcOp.getFunctionBody().empty())
482 return;
483
484 // Make sure the entry block of the function doesn't contain any Linalg ops.
485 // Otherwise, it may lead to the signature of the block being changed by the
486 // dialect conversion below, which would make the function op invalid
487 // because its type shouldn't change.
488 IRRewriter rewriter(funcOp->getContext());
489 Block *entryBlock = &funcOp.getFunctionBody().front();
490 Block *postEntryBlock =
491 rewriter.splitBlock(block: entryBlock, before: entryBlock->begin());
492 rewriter.setInsertionPointToStart(entryBlock);
493 auto branch =
494 rewriter.create<cf::BranchOp>(rewriter.getUnknownLoc(), postEntryBlock);
495
496 if (aggressiveMode.getValue()) {
497 AggressiveDetensoringModel costModel;
498 costModel.compute(func: funcOp, typeConverter, opsToDetensor,
499 blockArgsToDetensor);
500 } else {
501 ControlFlowDetectionModel costModel;
502 costModel.compute(func: funcOp, typeConverter, opsToDetensor,
503 blockArgsToDetensor);
504 }
505
506 detensorableBranchOps =
507 CostModel::computeBranchOpDetensoring(blockArgsToDetensor);
508
509 target.addDynamicallyLegalOp<GenericOp>(
510 [&](GenericOp op) { return !opsToDetensor.count(op); });
511
512 target.markUnknownOpDynamicallyLegal([&](Operation *op) {
513 // A function is legal if all of its non-entry blocks are legal. We
514 // don't legalize the entry block (i.e. the function's signature)
515 // since detensoring can't happen along external calling convention
516 // boundaries, which we conservatively approximate as all function
517 // signatures.
518 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
519 Region &body = funcOp.getFunctionBody();
520 return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) {
521 return !llvm::any_of(
522 blockArgsToDetensor, [&](BlockArgument blockArgument) {
523 return blockArgument.getOwner() == &block &&
524 !typeConverter.isLegal(blockArgument.getType());
525 });
526 });
527 }
528
529 if (isNotBranchOpInterfaceOrReturnLikeOp(op) ||
530 isLegalForReturnOpTypeConversionPattern(op, typeConverter,
531 /*returnOpAlwaysLegal*/ true))
532 return true;
533
534 if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
535 if (!detensorableBranchOps.count(branchOp))
536 return true;
537
538 for (auto operandIdx : detensorableBranchOps[branchOp])
539 if (!typeConverter.isLegal(
540 branchOp->getOperand(operandIdx).getType()))
541 return false;
542
543 return true;
544 }
545
546 return false;
547 });
548
549 patterns.add<DetensorizeGenericOp>(arg&: typeConverter, args&: context);
550 patterns.add<FunctionNonEntryBlockConversion>(arg&: context, args&: typeConverter,
551 args&: blockArgsToDetensor);
552 // Since non-entry block arguments get detensorized, we also need to
553 // update the control flow inside the function to reflect the correct
554 // types.
555 auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
556 int operandIdx) -> bool {
557 return detensorableBranchOps.count(Val: branchOp) &&
558 detensorableBranchOps[branchOp].count(operandIdx);
559 };
560
561 populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
562 shouldConvertBranchOperand);
563
564 if (failed(
565 applyFullConversion(getOperation(), target, std::move(patterns))))
566 signalPassFailure();
567
568 RewritePatternSet canonPatterns(context);
569 tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context);
570 if (failed(applyPatternsAndFoldGreedily(getOperation(),
571 std::move(canonPatterns))))
572 signalPassFailure();
573
574 // Get rid of the dummy entry block we created in the beginning to work
575 // around dialect conversion signature rewriting.
576 rewriter.eraseOp(op: branch);
577 rewriter.mergeBlocks(source: postEntryBlock, dest: entryBlock);
578 }
579};
} // namespace

// source: mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp