1//===- Detensorize.cpp - Linalg transformations as patterns ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/Passes.h"
10
11#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
12#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
13#include "mlir/Dialect/Linalg/IR/Linalg.h"
14#include "mlir/Dialect/Tensor/IR/Tensor.h"
15#include "mlir/IR/OpDefinition.h"
16#include "mlir/Transforms/DialectConversion.h"
17#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18#include <utility>
19
20namespace mlir {
21#define GEN_PASS_DEF_LINALGDETENSORIZEPASS
22#include "mlir/Dialect/Linalg/Passes.h.inc"
23} // namespace mlir
24
25using namespace mlir;
26using namespace mlir::linalg;
27
28static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
29 ValueRange inputs, Location loc) {
30 assert(inputs.size() == 1);
31 auto inputType = inputs[0].getType();
32 if (isa<TensorType>(Val: inputType))
33 return nullptr;
34
35 // A detensored value is converted back by creating a new tensor from its
36 // element(s).
37 return builder.create<tensor::FromElementsOp>(
38 location: loc, args: RankedTensorType::get(shape: {}, elementType: inputType), args: inputs[0]);
39}
40
41namespace {
42/// Defines the criteria a TensorType must follow in order to be considered
43/// "detensorable".
44///
45/// NOTE: For now, only 0-D tensors are supported.
46///
47/// Returns true if tensorType can be detensored.
48bool canBeDetensored(TensorType tensorType) {
49 return tensorType.hasRank() && tensorType.getRank() == 0;
50}
51
52bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
53 GenericOp genericOp = dyn_cast_or_null<GenericOp>(Val: op);
54 return genericOp &&
55 llvm::all_of(Range: genericOp->getOpOperands(), P: [&](OpOperand &opOperand) {
56 return !typeConverter.isLegal(type: opOperand.get().getType());
57 });
58}
59
/// A conversion pattern for detensoring `linalg.generic` ops.
///
/// The op is replaced by its own region: the scalar body is inlined into the
/// enclosing block and the op's results are replaced by the operands of the
/// region's `linalg.yield`.
class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(GenericOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Block *originalBlock = op->getBlock();

    // Gather some information about the op before inlining its region.
    Block *opEntryBlock = &*op.getRegion().begin();
    YieldOp yieldOp = dyn_cast<YieldOp>(Val: op.getRegion().back().getTerminator());

    // Split the op's region before the op. This way, we have a clear insertion
    // point in which the op can be inlined.
    Block *newBlock = rewriter.splitBlock(block: originalBlock, before: Block::iterator(op));
    rewriter.inlineRegionBefore(region&: op.getRegion(), before: newBlock);
    // Now that op's region is inlined, the operands of its YieldOp are mapped
    // to the materialized target values. Therefore, we can replace the op's
    // uses with those of its YieldOp's operands.
    rewriter.replaceOp(op, newValues: yieldOp->getOperands());

    // No need for these intermediate blocks, merge them into 1. The entry
    // block's arguments are replaced by the adaptor's (already detensored)
    // operands.
    rewriter.mergeBlocks(source: opEntryBlock, dest: originalBlock, argValues: adaptor.getOperands());
    rewriter.mergeBlocks(source: newBlock, dest: originalBlock, argValues: {});

    // The yield terminator is dead after inlining; erase it.
    rewriter.eraseOp(op: &*Block::iterator(yieldOp));

    return success();
  }
};
91
/// A conversion pattern for detensoring internal (non-entry) blocks within a
/// function.
///
/// Only the block arguments listed in `blockArgsToDetensor` get their types
/// converted; all other arguments keep their original types. The entry block
/// is never touched since that would change the function's signature.
struct FunctionNonEntryBlockConversion
    : public OpInterfaceConversionPattern<FunctionOpInterface> {
  FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter,
                                  DenseSet<BlockArgument> blockArgsToDetensor)
      : OpInterfaceConversionPattern(converter, ctx),
        blockArgsToDetensor(std::move(blockArgsToDetensor)) {}

  LogicalResult
  matchAndRewrite(FunctionOpInterface op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.startOpModification(op);
    Region &region = op.getFunctionBody();

    // Skip the entry block (drop_begin) and use early_inc_range since
    // applySignatureConversion may replace the block being visited.
    for (Block &block :
         llvm::make_early_inc_range(Range: llvm::drop_begin(RangeOrContainer&: region, N: 1))) {
      TypeConverter::SignatureConversion conversion(
          /*numOrigInputs=*/block.getNumArguments());

      for (BlockArgument blockArgument : block.getArguments()) {
        int idx = blockArgument.getArgNumber();

        // Convert only the arguments selected by the cost model; map the
        // rest 1:1 to their original types.
        if (blockArgsToDetensor.count(V: blockArgument))
          conversion.addInputs(origInputNo: idx, types: {getTypeConverter()->convertType(
                                   t: block.getArgumentTypes()[idx])});
        else
          conversion.addInputs(origInputNo: idx, types: {block.getArgumentTypes()[idx]});
      }

      rewriter.applySignatureConversion(block: &block, conversion, converter: getTypeConverter());
    }

    rewriter.finalizeOpModification(op);
    return success();
  }

private:
  // Owned copy of the cost-model output; const since it is never mutated
  // after construction.
  const DenseSet<BlockArgument> blockArgsToDetensor;
};
132
/// Type converter that maps detensorable (0-D) tensor types to their element
/// type and provides the scalar<->tensor materializations.
class DetensorizeTypeConverter : public TypeConverter {
public:
  DetensorizeTypeConverter() {
    // Fallback: any type not handled below is left unchanged.
    addConversion(callback: [](Type type) { return type; });

    // A TensorType that can be detensored, is converted to the underlying
    // element type.
    addConversion(callback: [](TensorType tensorType) -> Type {
      if (canBeDetensored(tensorType))
        return tensorType.getElementType();

      return tensorType;
    });

    // A tensor value is detensored by extracting its element(s).
    addTargetMaterialization(callback: [](OpBuilder &builder, Type type,
                                 ValueRange inputs, Location loc) -> Value {
      return builder.create<tensor::ExtractOp>(location: loc, args: inputs[0], args: ValueRange{});
    });

    // The reverse direction (scalar -> 0-D tensor) is handled by the shared
    // callback defined at the top of this file.
    addSourceMaterialization(callback&: sourceMaterializationCallback);
  }
};
156
/// @see LinalgDetensorize in Linalg/Passes.td for more details.
struct LinalgDetensorize
    : public impl::LinalgDetensorizePassBase<LinalgDetensorize> {
  using impl::LinalgDetensorizePassBase<
      LinalgDetensorize>::LinalgDetensorizePassBase;
  LinalgDetensorize() = default;

  /// Abstract interface for deciding which linalg ops and block arguments
  /// should be detensored.
  class CostModel {
  public:
    virtual ~CostModel() = default;

    /// A cost model algorithm computes the following outputs:
    ///
    /// - opsToDetensor: the list of linalg ops that should be
    /// detensored.
    ///
    /// - blockArgsToDetensor: since the operands and results of detensored
    /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come
    /// from a BB argument and a linalg op's output can be passed to successor
    /// BBs), we need to maintain the sub-set of arguments that should be
    /// detensored (i.e. converted by typeConverter) for each affected BB.
    ///
    /// Example:
    ///
    /// For the following snippet:
    /// ...
    /// ^bb1(%6: tensor<i32>, %9: tensor<i32>):
    ///   %7 = tensor.empty() : tensor<i32>
    ///   %8 = linalg.generic #attrs
    ///     ins(%6, %6 : tensor<i32>, tensor<i32>)
    ///     outs(%7 : tensor<i32>) {
    ///     ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
    ///       %9 = arith.addi %arg0, %arg1 : i32
    ///       linalg.yield %9 : i32
    ///   } -> tensor<i32>
    ///   %10 = "some.op"(%9)
    ///   br ^bb2(%8 : tensor<i32>)
    /// ...
    ///
    /// if the cost model decides that the linalg.generic op should be
    /// detensored, then:
    /// - opsToDetensor should be = {linalg.generic{add}}.
    /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}.
    virtual void compute(FunctionOpInterface func,
                         DetensorizeTypeConverter typeConverter,
                         DenseSet<Operation *> &opsToDetensor,
                         DenseSet<BlockArgument> &blockArgsToDetensor) = 0;

    /// From the blockArgsToDetensor set computed by a CostModel
    /// implementation, this method computes the corresponding branch op
    /// detensoring. The result is a map from a branch op to a subset of indices
    /// of its operands. The indices specify which of the branch op's operands
    /// should be detensored.
    ///
    /// For the previous example, this method would compute: {bb2 -> {0}}.
    static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring(
        const DenseSet<BlockArgument> &blockArgsToDetensor) {
      DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;

      for (auto blockArgumentElem : blockArgsToDetensor) {
        Block *block = blockArgumentElem.getOwner();

        // Inspect every predecessor terminator forwarding a value into this
        // block argument.
        for (PredecessorIterator pred = block->pred_begin();
             pred != block->pred_end(); ++pred) {
          // NOTE(review): `terminator` is used without a null check below;
          // presumably every predecessor of a detensorable block argument has
          // a BranchOpInterface terminator — confirm this holds in aggressive
          // mode, where ControlFlowDetectionModel's give-up path doesn't run.
          BranchOpInterface terminator =
              dyn_cast<BranchOpInterface>(Val: (*pred)->getTerminator());
          auto blockOperands =
              terminator.getSuccessorOperands(index: pred.getSuccessorIndex());

          // Skip arguments whose value is produced by the terminator itself
          // rather than forwarded from one of its operands.
          if (blockOperands.empty() ||
              blockOperands.isOperandProduced(index: blockArgumentElem.getArgNumber()))
            continue;

          detensorableBranchOps[terminator].insert(
              V: blockOperands.getOperandIndex(blockArgumentIndex: blockArgumentElem.getArgNumber()));
        }
      }

      return detensorableBranchOps;
    }
  };

  /// Detensorize linalg ops involved in control-flow within a function.
  ///
  /// This model starts from BranchOps and CondBranchOps within a function. For
  /// each such branch, the model then walks the use-def chain for the branch's
  /// condition backwards in order to understand where the condition's value
  /// comes from. If the condition value is (indirectly) computed by a linalg op
  /// that can be detensored, the model then continues walking the use-def chain
  /// in order to understand where the linalg op's operands come from. This
  /// leads to discovering a "detensoring component". A detensoring component is
  /// the set of operations + block arguments that are involved in control-flow
  /// AND can be detensored.
  class ControlFlowDetectionModel : public CostModel {
  public:
    void compute(FunctionOpInterface func,
                 DetensorizeTypeConverter typeConverter,
                 DenseSet<Operation *> &opsToDetensor,
                 DenseSet<BlockArgument> &blockArgsToDetensor) override {
      SmallVector<Value> workList;

      // Seed the work list with all operands of cf branches.
      func->walk(callback: [&](cf::CondBranchOp condBr) {
        llvm::append_range(C&: workList, R: condBr.getOperands());
      });

      func->walk(callback: [&](cf::BranchOp br) {
        llvm::append_range(C&: workList, R: br.getOperands());
      });

      DenseSet<Value> visitedValues;
      DenseSet<Operation *> visitedOps;

      // For a (to-be-detensored) value, check if it "escapes" the block by
      // being passed to terminator. If it does, then workList is updated with
      // the corresponding argument to the successor block.
      auto updateWorkListWithSuccessorArguments =
          [&](Value value, BranchOpInterface terminator) {
            if (!terminator)
              return;

            for (auto operandIdx :
                 llvm::seq<unsigned>(Begin: 0, End: terminator->getOperands().size())) {
              Value operand = terminator->getOperand(idx: operandIdx);

              if (operand == value) {
                auto succBlockArg =
                    terminator.getSuccessorBlockArgument(operandIndex: operandIdx);

                if (succBlockArg && !blockArgsToDetensor.count(V: *succBlockArg))
                  workList.push_back(Elt: *succBlockArg);
              }
            }
          };

      while (!workList.empty()) {
        Value currentItem = workList.pop_back_val();

        // Each value is processed at most once.
        if (!visitedValues.insert(V: currentItem).second)
          continue;

        // 1 - Look forward:
        // 1.1 - If currentItem escapes to one or more successors, add
        // the corresponding successor arguments to workList.
        updateWorkListWithSuccessorArguments(
            currentItem, dyn_cast<BranchOpInterface>(
                             Val: currentItem.getParentBlock()->getTerminator()));

        // 1.2 - For each user of currentItem, add the defined values to
        // workList. This way, the user ops can be inspected later if they are
        // detensorable and if so, their operands will be added to workList to
        // potentially discover other parts of the detensorable component.
        for (auto *user : currentItem.getUsers())
          llvm::append_range(C&: workList, R: user->getResults());

        // 2 - Look backward:
        // 2.1 - The current item is defined by a block argument. If the owner
        // block is a non-entry one, then:
        //   * Add the argument to blockArgsToDetensor.
        //   * Walk the use-def chain backwards to add each predecessor's
        //   terminator-operands corresponding to currentItem to workList.
        if (auto currentItemBlockArgument =
                dyn_cast<BlockArgument>(Val&: currentItem)) {
          Block *ownerBlock = currentItemBlockArgument.getOwner();

          // Function arguments are not detensored/converted.
          if (&*ownerBlock->getParent()->begin() == ownerBlock)
            continue;

          // This inner-block argument is involved in control-flow, it should be
          // detensored.
          blockArgsToDetensor.insert(V: currentItemBlockArgument);

          for (PredecessorIterator pred = ownerBlock->pred_begin();
               pred != ownerBlock->pred_end(); ++pred) {
            BranchOpInterface predTerminator =
                dyn_cast<BranchOpInterface>(Val: (*pred)->getTerminator());

            // TODO: For now, we give up if any of the control-flow components
            // in a function is not detensorable. Fix that.
            if (!predTerminator) {
              opsToDetensor.clear();
              blockArgsToDetensor.clear();
              return;
            }

            auto ownerBlockOperands =
                predTerminator.getSuccessorOperands(index: pred.getSuccessorIndex());

            // Skip arguments produced by the terminator itself.
            if (ownerBlockOperands.empty() ||
                ownerBlockOperands.isOperandProduced(
                    index: currentItemBlockArgument.getArgNumber()))
              continue;

            // For each predecessor, add the value it passes to that argument to
            // workList to find out how it's computed.
            workList.push_back(
                Elt: ownerBlockOperands[currentItemBlockArgument.getArgNumber()]);
          }

          continue;
        }

        Operation *currentItemDefiningOp = currentItem.getDefiningOp();

        // Each defining op is inspected at most once.
        if (!visitedOps.insert(V: currentItemDefiningOp).second)
          continue;

        // 2.2 - The current item is computed by a GenericOp. If the op should
        // be detensored, then:
        //   * Add it to opsToDetensor.
        //   * Add its operands to workList to discover other parts of the
        //   potentially detensorable component.
        if (auto genericOp = dyn_cast<GenericOp>(Val: currentItemDefiningOp)) {
          // The op was encountered already, no need to inspect it again.
          if (opsToDetensor.count(V: genericOp))
            continue;

          // The op should not be detensored, give up on it but continue with
          // discovering the rest of the control-flow component.
          if (!shouldBeDetensored(op: genericOp, typeConverter)) {
            continue;
          }

          opsToDetensor.insert(V: genericOp);
          llvm::append_range(C&: workList, R: genericOp.getInputs());
          continue;
        }

        // 2.3 - The current item is the result of a FromElementsOp, it will be
        // trivially detensored later as part of canonicalization patterns
        // applied at the end of detensoring.
        //
        // Note: No need to check whether the result type of this op is
        // detensorable since if it wasn't we wouldn't reach that point in the
        // work list.
        if (isa<tensor::FromElementsOp>(Val: currentItemDefiningOp))
          continue;

        // 2.4 - The current item is the result of a scalar op, add all its
        // operands to the work list.
        if (llvm::all_of(
                Range: currentItemDefiningOp->getResultTypes(),
                P: [&](Type resultType) { return resultType.isIntOrFloat(); }))
          llvm::append_range(C&: workList, R: currentItemDefiningOp->getOperands());
      }

      // Since the cost model gives up on some ops (see the details of step 2.2
      // above), block arguments that correspond to the values produced by those
      // ops should not be detensored as well.

      DenseSet<BlockArgument> blockArgsToRemove;

      for (auto &blockArg : blockArgsToDetensor) {
        Block *block = blockArg.getParentBlock();

        // For the potentially detensorable block argument, find the
        // corresponding operands in predecessor blocks.
        for (PredecessorIterator pred = block->pred_begin();
             pred != block->pred_end(); ++pred) {
          BranchOpInterface terminator =
              dyn_cast<BranchOpInterface>(Val: (*pred)->getTerminator());
          auto blockOperands =
              terminator.getSuccessorOperands(index: pred.getSuccessorIndex());

          if (blockOperands.empty() ||
              blockOperands.isOperandProduced(index: blockArg.getArgNumber()))
            continue;

          Operation *definingOp =
              blockOperands[blockArg.getArgNumber()].getDefiningOp();

          // If the operand is defined by a GenericOp that will not be
          // detensored, then do not detensor the corresponding block argument.
          if (isa_and_nonnull<GenericOp>(Val: definingOp) &&
              opsToDetensor.count(V: definingOp) == 0) {
            blockArgsToRemove.insert(V: blockArg);
            break;
          }
        }
      }

      // Erase in a second pass to avoid invalidating the set being iterated.
      for (auto &blockArg : blockArgsToRemove) {
        blockArgsToDetensor.erase(V: blockArg);
      }
    }
  };

  /// Detensorize everything that can be detensored.
  class AggressiveDetensoringModel : public CostModel {
  public:
    void compute(FunctionOpInterface func,
                 DetensorizeTypeConverter typeConverter,
                 DenseSet<Operation *> &opsToDetensor,
                 DenseSet<BlockArgument> &blockArgsToDetensor) override {
      func->walk(callback: [&](GenericOp genericOp) {
        if (shouldBeDetensored(op: genericOp, typeConverter))
          opsToDetensor.insert(V: genericOp);
      });

      // Every non-entry block argument is a candidate in aggressive mode.
      for (Block &block : llvm::drop_begin(RangeOrContainer&: func.getFunctionBody(), N: 1))
        blockArgsToDetensor.insert_range(R: block.getArguments());
    }
  };

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    DetensorizeTypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);
    DenseSet<Operation *> opsToDetensor;
    DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
    DenseSet<BlockArgument> blockArgsToDetensor;
    FunctionOpInterface funcOp = getOperation();

    // Nothing to do for declarations.
    if (funcOp.getFunctionBody().empty())
      return;

    // Make sure the entry block of the function doesn't contain any Linalg ops.
    // Otherwise, it may lead to the signature of the block being changed by the
    // dialect conversion below, which would make the function op invalid
    // because its type shouldn't change.
    IRRewriter rewriter(funcOp->getContext());
    Block *entryBlock = &funcOp.getFunctionBody().front();
    Block *postEntryBlock =
        rewriter.splitBlock(block: entryBlock, before: entryBlock->begin());
    rewriter.setInsertionPointToStart(entryBlock);
    auto branch =
        rewriter.create<cf::BranchOp>(location: rewriter.getUnknownLoc(), args&: postEntryBlock);

    // Run the selected cost model to populate opsToDetensor and
    // blockArgsToDetensor.
    if (aggressiveMode.getValue()) {
      AggressiveDetensoringModel costModel;
      costModel.compute(func: funcOp, typeConverter, opsToDetensor,
                        blockArgsToDetensor);
    } else {
      ControlFlowDetectionModel costModel;
      costModel.compute(func: funcOp, typeConverter, opsToDetensor,
                        blockArgsToDetensor);
    }

    detensorableBranchOps =
        CostModel::computeBranchOpDetensoring(blockArgsToDetensor);

    // A generic op is illegal (and will be rewritten) iff the cost model
    // selected it.
    target.addDynamicallyLegalOp<GenericOp>(
        callback: [&](GenericOp op) { return !opsToDetensor.count(V: op); });

    target.markUnknownOpDynamicallyLegal(fn: [&](Operation *op) {
      // A function is legal if all of its non-entry blocks are legal. We
      // don't legalize the entry block (i.e. the function's signature)
      // since detensoring can't happen along external calling convention
      // boundaries, which we conservatively approximate as all function
      // signatures.
      if (auto funcOp = dyn_cast<FunctionOpInterface>(Val: op)) {
        Region &body = funcOp.getFunctionBody();
        return llvm::all_of(Range: llvm::drop_begin(RangeOrContainer&: body, N: 1), P: [&](Block &block) {
          return !llvm::any_of(
              Range&: blockArgsToDetensor, P: [&](BlockArgument blockArgument) {
                return blockArgument.getOwner() == &block &&
                       !typeConverter.isLegal(type: blockArgument.getType());
              });
        });
      }

      // Ops that are neither branch-like nor return-like are untouched;
      // return-like ops are always legal here since function signatures are
      // never converted.
      if (isNotBranchOpInterfaceOrReturnLikeOp(op) ||
          isLegalForReturnOpTypeConversionPattern(op, converter: typeConverter,
                                                  /*returnOpAlwaysLegal*/ true))
        return true;

      // A branch op is illegal only if one of its detensorable operands still
      // has an illegal (tensor) type.
      if (auto branchOp = dyn_cast<BranchOpInterface>(Val: op)) {
        if (!detensorableBranchOps.count(Val: branchOp))
          return true;

        for (auto operandIdx : detensorableBranchOps[branchOp])
          if (!typeConverter.isLegal(
                  type: branchOp->getOperand(idx: operandIdx).getType()))
            return false;

        return true;
      }

      return false;
    });

    patterns.add<DetensorizeGenericOp>(arg&: typeConverter, args&: context);
    patterns.add<FunctionNonEntryBlockConversion>(arg&: context, args&: typeConverter,
                                                  args&: blockArgsToDetensor);
    // Since non-entry block arguments get detensorized, we also need to
    // update the control flow inside the function to reflect the correct
    // types.
    auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
                                          int operandIdx) -> bool {
      return detensorableBranchOps.count(Val: branchOp) &&
             detensorableBranchOps[branchOp].count(V: operandIdx);
    };

    populateBranchOpInterfaceTypeConversionPattern(patterns, converter: typeConverter,
                                                   shouldConvertBranchOperand);

    if (failed(
            Result: applyFullConversion(op: getOperation(), target, patterns: std::move(patterns))))
      signalPassFailure();

    // Fold away the FromElementsOp/ExtractOp pairs introduced by the
    // materializations above.
    RewritePatternSet canonPatterns(context);
    tensor::FromElementsOp::getCanonicalizationPatterns(results&: canonPatterns, context);
    if (failed(Result: applyPatternsGreedily(op: getOperation(), patterns: std::move(canonPatterns))))
      signalPassFailure();

    // Get rid of the dummy entry block we created in the beginning to work
    // around dialect conversion signature rewriting.
    rewriter.eraseOp(op: branch);
    rewriter.mergeBlocks(source: postEntryBlock, dest: entryBlock);
  }
};
569} // namespace
570

// Source: mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp