//===- Detensorize.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <iterator>
#include <memory>
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_LINALGDETENSORIZEPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;
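/// Converts a detensored scalar back to a tensor when some use still requires
/// a tensor value, or returns a null Value if the input is already a tensor.
/// For the 0-D case handled by this pass, the materialization is a single
/// `tensor.from_elements` op, e.g. `%t = tensor.from_elements %v : tensor<i32>`
/// for a detensored `%v : i32`.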
static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
                                           ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  auto inputType = inputs[0].getType();
  if (isa<TensorType>(inputType))
    return nullptr;

  // A detensored value is converted back by creating a new tensor from its
  // element(s).
  return builder.create<tensor::FromElementsOp>(
      loc, RankedTensorType::get({}, inputType), inputs[0]);
}

namespace {
/// Defines the criteria a TensorType must follow in order to be considered
/// "detensorable".
///
/// NOTE: For now, only 0-D tensors are supported.
///
/// Returns true if tensorType can be detensored.
bool canBeDetensored(TensorType tensorType) {
  return tensorType.hasRank() && tensorType.getRank() == 0;
}

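/// Returns true if `op` is a linalg.generic op whose operand types would all
/// be changed by the detensorize type converter, i.e. every operand is a
/// tensor accepted by canBeDetensored(). Such ops are candidates for
/// detensoring.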
bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
  GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
  return genericOp &&
         llvm::all_of(genericOp->getOpOperands(), [&](OpOperand &opOperand) {
           return !typeConverter.isLegal(opOperand.get().getType());
         });
}

/// A conversion pattern for detensoring `linalg.generic` ops.
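/// The op's region is inlined into the enclosing block and the op's results
/// are replaced by the operands of its linalg.yield, so the body's scalar
/// computation is performed directly on the detensored (scalar) operands.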
class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(GenericOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Block *originalBlock = op->getBlock();

    // Gather some information about the op before inlining its region.
    Block *opEntryBlock = &*op.getRegion().begin();
    YieldOp yieldOp = dyn_cast<YieldOp>(op.getRegion().back().getTerminator());

    // Split the op's enclosing block before the op. This way, we have a clear
    // insertion point in which the op's region can be inlined.
    Block *newBlock = rewriter.splitBlock(originalBlock, Block::iterator(op));
    rewriter.inlineRegionBefore(op.getRegion(), newBlock);
    // Now that the op's region is inlined, the operands of its YieldOp are
    // mapped to the materialized target values. Therefore, we can replace the
    // op's uses with those of its YieldOp's operands.
    rewriter.replaceOp(op, yieldOp->getOperands());

    // No need for these intermediate blocks, merge them into 1.
    rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands());
    rewriter.mergeBlocks(newBlock, originalBlock, {});

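    // The inlined linalg.yield is no longer needed: its operands already
    // replace the generic op's results, so simply erase it.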
    rewriter.eraseOp(&*Block::iterator(yieldOp));

    return success();
  }
};

/// A conversion pattern for detensoring internal (non-entry) blocks within a
/// function.
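/// Only non-entry blocks are converted: rewriting the entry block's arguments
/// would change the function's signature, which this pass deliberately leaves
/// untouched (see the legality rules in runOnOperation()).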
struct FunctionNonEntryBlockConversion
    : public OpInterfaceConversionPattern<FunctionOpInterface> {
  FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter,
                                  DenseSet<BlockArgument> blockArgsToDetensor)
      : OpInterfaceConversionPattern(converter, ctx),
        blockArgsToDetensor(std::move(blockArgsToDetensor)) {}

  LogicalResult
  matchAndRewrite(FunctionOpInterface op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.startOpModification(op);
    Region &region = op.getFunctionBody();

    for (Block &block :
         llvm::make_early_inc_range(llvm::drop_begin(region, 1))) {
      TypeConverter::SignatureConversion conversion(
          /*numOrigInputs=*/block.getNumArguments());

      for (BlockArgument blockArgument : block.getArguments()) {
        int idx = blockArgument.getArgNumber();

        if (blockArgsToDetensor.count(blockArgument))
          conversion.addInputs(idx, {getTypeConverter()->convertType(
                                        block.getArgumentTypes()[idx])});
        else
          conversion.addInputs(idx, {block.getArgumentTypes()[idx]});
      }

      rewriter.applySignatureConversion(&block, conversion, getTypeConverter());
    }

    rewriter.finalizeOpModification(op);
    return success();
  }

private:
  const DenseSet<BlockArgument> blockArgsToDetensor;
};

class DetensorizeTypeConverter : public TypeConverter {
public:
  DetensorizeTypeConverter() {
    addConversion([](Type type) { return type; });

    // A TensorType that can be detensored is converted to its underlying
    // element type.
    addConversion([](TensorType tensorType) -> Type {
      if (canBeDetensored(tensorType))
        return tensorType.getElementType();

      return tensorType;
    });

    // A tensor value is detensored by extracting its element(s).
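    // For the 0-D tensors handled here this produces, e.g.,
    // `%e = tensor.extract %t[] : tensor<i32>`.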
    addTargetMaterialization([](OpBuilder &builder, Type type,
                                ValueRange inputs, Location loc) -> Value {
      return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
    });

    addSourceMaterialization(sourceMaterializationCallback);
  }
};

/// @see LinalgDetensorize in Linalg/Passes.td for more details.
struct LinalgDetensorize
    : public impl::LinalgDetensorizePassBase<LinalgDetensorize> {
  using impl::LinalgDetensorizePassBase<
      LinalgDetensorize>::LinalgDetensorizePassBase;
  LinalgDetensorize() = default;

  class CostModel {
  public:
    virtual ~CostModel() = default;

    /// A cost model algorithm computes the following outputs:
    ///
    /// - opsToDetensor: the list of linalg ops that should be
    /// detensored.
    ///
    /// - blockArgsToDetensor: since the operands and results of detensored
    /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come
    /// from a BB argument and a linalg op's output can be passed to successor
    /// BBs), we need to maintain the sub-set of arguments that should be
    /// detensored (i.e. converted by typeConverter) for each affected BB.
    ///
    /// Example:
    ///
    /// For the following snippet:
    /// ...
    /// ^bb1(%6: tensor<i32>, %9: tensor<i32>):
    ///   %7 = tensor.empty() : tensor<i32>
    ///   %8 = linalg.generic #attrs
    ///     ins(%6, %6 : tensor<i32>, tensor<i32>)
    ///     outs(%7 : tensor<i32>) {
    ///   ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
    ///     %9 = arith.addi %arg0, %arg1 : i32
    ///     linalg.yield %9 : i32
    ///   } -> tensor<i32>
    ///   %10 = "some.op"(%9)
    ///   br ^bb2(%8 : tensor<i32>)
    /// ...
    ///
    /// if the cost model decides that the linalg.generic op should be
    /// detensored, then:
    /// - opsToDetensor should be = {linalg.generic{add}}.
    /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}.
    virtual void compute(FunctionOpInterface func,
                         DetensorizeTypeConverter typeConverter,
                         DenseSet<Operation *> &opsToDetensor,
                         DenseSet<BlockArgument> &blockArgsToDetensor) = 0;

    /// From the blockArgsToDetensor set computed by a CostModel
    /// implementation, this method computes the corresponding branch op
    /// detensoring. The result is a map from a branch op to a subset of
    /// indices of its operands. The indices specify which of the branch op's
    /// operands should be detensored.
    ///
    /// For the previous example, this method would compute: {bb2 -> {0}}.
    static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring(
        const DenseSet<BlockArgument> &blockArgsToDetensor) {
      DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;

      for (auto blockArgumentElem : blockArgsToDetensor) {
        Block *block = blockArgumentElem.getOwner();

        for (PredecessorIterator pred = block->pred_begin();
             pred != block->pred_end(); ++pred) {
          BranchOpInterface terminator =
              dyn_cast<BranchOpInterface>((*pred)->getTerminator());
          auto blockOperands =
              terminator.getSuccessorOperands(pred.getSuccessorIndex());

          if (blockOperands.empty() ||
              blockOperands.isOperandProduced(blockArgumentElem.getArgNumber()))
            continue;

          detensorableBranchOps[terminator].insert(
              blockOperands.getOperandIndex(blockArgumentElem.getArgNumber()));
        }
      }

      return detensorableBranchOps;
    }
  };

  /// Detensorize linalg ops involved in control-flow within a function.
  ///
  /// This model starts from BranchOps and CondBranchOps within a function. For
  /// each such branch, the model then walks the use-def chain for the branch's
  /// condition backwards in order to understand where the condition's value
  /// comes from. If the condition value is (indirectly) computed by a linalg op
  /// that can be detensored, the model then continues walking the use-def chain
  /// in order to understand where the linalg op's operands come from. This
  /// leads to discovering a "detensoring component". A detensoring component is
  /// the set of operations + block arguments that are involved in control-flow
  /// AND can be detensored.
  class ControlFlowDetectionModel : public CostModel {
  public:
    void compute(FunctionOpInterface func,
                 DetensorizeTypeConverter typeConverter,
                 DenseSet<Operation *> &opsToDetensor,
                 DenseSet<BlockArgument> &blockArgsToDetensor) override {
      SmallVector<Value> workList;

      func->walk([&](cf::CondBranchOp condBr) {
        llvm::append_range(workList, condBr.getOperands());
      });

      func->walk([&](cf::BranchOp br) {
        llvm::append_range(workList, br.getOperands());
      });

      DenseSet<Value> visitedValues;
      DenseSet<Operation *> visitedOps;

      // For a (to-be-detensored) value, check if it "escapes" the block by
      // being passed to a terminator. If it does, then workList is updated
      // with the corresponding argument of the successor block.
      auto updateWorkListWithSuccessorArguments =
          [&](Value value, BranchOpInterface terminator) {
            if (!terminator)
              return;

            for (auto operandIdx :
                 llvm::seq<unsigned>(0, terminator->getOperands().size())) {
              Value operand = terminator->getOperand(operandIdx);

              if (operand == value) {
                auto succBlockArg =
                    terminator.getSuccessorBlockArgument(operandIdx);

                if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
                  workList.push_back(*succBlockArg);
              }
            }
          };

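      // Process the work list to a fixed point: for each value, look forward
      // (users and successor block arguments) and backward (defining ops and
      // predecessor terminators) to grow the detensoring component.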
      while (!workList.empty()) {
        Value currentItem = workList.pop_back_val();

        if (!visitedValues.insert(currentItem).second)
          continue;

        // 1 - Look forward:
        // 1.1 - If currentItem escapes to one or more successors, add
        // the corresponding successor arguments to workList.
        updateWorkListWithSuccessorArguments(
            currentItem, dyn_cast<BranchOpInterface>(
                             currentItem.getParentBlock()->getTerminator()));

        // 1.2 - For each user of currentItem, add the defined values to
        // workList. This way, the user ops can be inspected later if they are
        // detensorable and if so, their operands will be added to workList to
        // potentially discover other parts of the detensorable component.
        for (auto *user : currentItem.getUsers())
          llvm::append_range(workList, user->getResults());

        // 2 - Look backward:
        // 2.1 - The current item is defined by a block argument. If the owner
        // block is a non-entry one, then:
        // * Add the argument to blockArgsToDetensor.
        // * Walk the use-def chain backwards to add each predecessor's
        //   terminator operands corresponding to currentItem to workList.
        if (auto currentItemBlockArgument =
                dyn_cast<BlockArgument>(currentItem)) {
          Block *ownerBlock = currentItemBlockArgument.getOwner();

          // Function arguments are not detensored/converted.
          if (&*ownerBlock->getParent()->begin() == ownerBlock)
            continue;

          // This inner-block argument is involved in control flow, so it
          // should be detensored.
          blockArgsToDetensor.insert(currentItemBlockArgument);

          for (PredecessorIterator pred = ownerBlock->pred_begin();
               pred != ownerBlock->pred_end(); ++pred) {
            BranchOpInterface predTerminator =
                dyn_cast<BranchOpInterface>((*pred)->getTerminator());

            // TODO: For now, we give up if any of the control-flow components
            // in a function is not detensorable. Fix that.
            if (!predTerminator) {
              opsToDetensor.clear();
              blockArgsToDetensor.clear();
              return;
            }

            auto ownerBlockOperands =
                predTerminator.getSuccessorOperands(pred.getSuccessorIndex());

            if (ownerBlockOperands.empty() ||
                ownerBlockOperands.isOperandProduced(
                    currentItemBlockArgument.getArgNumber()))
              continue;

            // For each predecessor, add the value it passes to that argument
            // to workList to find out how it's computed.
            workList.push_back(
                ownerBlockOperands[currentItemBlockArgument.getArgNumber()]);
          }

          continue;
        }

        Operation *currentItemDefiningOp = currentItem.getDefiningOp();

        if (!visitedOps.insert(currentItemDefiningOp).second)
          continue;

        // 2.2 - The current item is computed by a GenericOp. If the op should
        // be detensored, then:
        // * Add it to opsToDetensor.
        // * Add its operands to workList to discover other parts of the
        //   potentially detensorable component.
        if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
          // The op was encountered already, no need to inspect it again.
          if (opsToDetensor.count(genericOp))
            continue;

          // The op should not be detensored; give up on it but continue with
          // discovering the rest of the control-flow component.
          if (!shouldBeDetensored(genericOp, typeConverter)) {
            continue;
          }

          opsToDetensor.insert(genericOp);
          llvm::append_range(workList, genericOp.getInputs());
          continue;
        }

        // 2.3 - The current item is the result of a FromElementsOp; it will be
        // trivially detensored later as part of the canonicalization patterns
        // applied at the end of detensoring.
        //
        // Note: No need to check whether the result type of this op is
        // detensorable since, if it wasn't, we wouldn't reach this point in
        // the work list.
        if (isa<tensor::FromElementsOp>(currentItemDefiningOp))
          continue;

        // 2.4 - The current item is the result of a scalar op; add all its
        // operands to the work list.
        if (llvm::all_of(
                currentItemDefiningOp->getResultTypes(),
                [&](Type resultType) { return resultType.isIntOrFloat(); }))
          llvm::append_range(workList, currentItemDefiningOp->getOperands());
      }

      // Since the cost model gives up on some ops (see the details of step 2.2
      // above), block arguments that correspond to the values produced by
      // those ops should not be detensored either.

      DenseSet<BlockArgument> blockArgsToRemove;

      for (auto &blockArg : blockArgsToDetensor) {
        Block *block = blockArg.getParentBlock();

        // For the potentially detensorable block argument, find the
        // corresponding operands in predecessor blocks.
        for (PredecessorIterator pred = block->pred_begin();
             pred != block->pred_end(); ++pred) {
          BranchOpInterface terminator =
              dyn_cast<BranchOpInterface>((*pred)->getTerminator());
          auto blockOperands =
              terminator.getSuccessorOperands(pred.getSuccessorIndex());

          if (blockOperands.empty() ||
              blockOperands.isOperandProduced(blockArg.getArgNumber()))
            continue;

          Operation *definingOp =
              blockOperands[blockArg.getArgNumber()].getDefiningOp();

          // If the operand is defined by a GenericOp that will not be
          // detensored, then do not detensor the corresponding block argument.
          if (isa_and_nonnull<GenericOp>(definingOp) &&
              opsToDetensor.count(definingOp) == 0) {
            blockArgsToRemove.insert(blockArg);
            break;
          }
        }
      }

      for (auto &blockArg : blockArgsToRemove) {
        blockArgsToDetensor.erase(blockArg);
      }
    }
  };

  /// Detensorize everything that can be detensored.
  class AggressiveDetensoringModel : public CostModel {
  public:
    void compute(FunctionOpInterface func,
                 DetensorizeTypeConverter typeConverter,
                 DenseSet<Operation *> &opsToDetensor,
                 DenseSet<BlockArgument> &blockArgsToDetensor) override {
      func->walk([&](GenericOp genericOp) {
        if (shouldBeDetensored(genericOp, typeConverter))
          opsToDetensor.insert(genericOp);
      });

      for (Block &block : llvm::drop_begin(func.getFunctionBody(), 1))
        blockArgsToDetensor.insert_range(block.getArguments());
    }
  };

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    DetensorizeTypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);
    DenseSet<Operation *> opsToDetensor;
    DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
    DenseSet<BlockArgument> blockArgsToDetensor;
    FunctionOpInterface funcOp = getOperation();

    if (funcOp.getFunctionBody().empty())
      return;

    // Make sure the entry block of the function doesn't contain any Linalg
    // ops. Otherwise, it may lead to the signature of the block being changed
    // by the dialect conversion below, which would make the function op
    // invalid because its type shouldn't change.
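    // Splitting right at the start gives the function a trivial entry block
    // that only branches to the real body; every original block then becomes a
    // non-entry block whose signature may be rewritten. The split is undone at
    // the end of this function.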
    IRRewriter rewriter(funcOp->getContext());
    Block *entryBlock = &funcOp.getFunctionBody().front();
    Block *postEntryBlock =
        rewriter.splitBlock(entryBlock, entryBlock->begin());
    rewriter.setInsertionPointToStart(entryBlock);
    auto branch =
        rewriter.create<cf::BranchOp>(rewriter.getUnknownLoc(), postEntryBlock);

    if (aggressiveMode.getValue()) {
      AggressiveDetensoringModel costModel;
      costModel.compute(funcOp, typeConverter, opsToDetensor,
                        blockArgsToDetensor);
    } else {
      ControlFlowDetectionModel costModel;
      costModel.compute(funcOp, typeConverter, opsToDetensor,
                        blockArgsToDetensor);
    }

    detensorableBranchOps =
        CostModel::computeBranchOpDetensoring(blockArgsToDetensor);

    target.addDynamicallyLegalOp<GenericOp>(
        [&](GenericOp op) { return !opsToDetensor.count(op); });

    target.markUnknownOpDynamicallyLegal([&](Operation *op) {
      // A function is legal if all of its non-entry blocks are legal. We
      // don't legalize the entry block (i.e. the function's signature)
      // since detensoring can't happen along external calling convention
      // boundaries, which we conservatively approximate as all function
      // signatures.
      if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
        Region &body = funcOp.getFunctionBody();
        return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) {
          return !llvm::any_of(
              blockArgsToDetensor, [&](BlockArgument blockArgument) {
                return blockArgument.getOwner() == &block &&
                       !typeConverter.isLegal(blockArgument.getType());
              });
        });
      }

      if (isNotBranchOpInterfaceOrReturnLikeOp(op) ||
          isLegalForReturnOpTypeConversionPattern(op, typeConverter,
                                                  /*returnOpAlwaysLegal*/ true))
        return true;

      if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
        if (!detensorableBranchOps.count(branchOp))
          return true;

        for (auto operandIdx : detensorableBranchOps[branchOp])
          if (!typeConverter.isLegal(
                  branchOp->getOperand(operandIdx).getType()))
            return false;

        return true;
      }

      return false;
    });

    patterns.add<DetensorizeGenericOp>(typeConverter, context);
    patterns.add<FunctionNonEntryBlockConversion>(context, typeConverter,
                                                  blockArgsToDetensor);
    // Since non-entry block arguments get detensorized, we also need to
    // update the control flow inside the function to reflect the correct
    // types.
    auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
                                          int operandIdx) -> bool {
      return detensorableBranchOps.count(branchOp) &&
             detensorableBranchOps[branchOp].count(operandIdx);
    };

    populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
                                                   shouldConvertBranchOperand);

    if (failed(
            applyFullConversion(getOperation(), target, std::move(patterns))))
      signalPassFailure();

    RewritePatternSet canonPatterns(context);
    tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context);
    if (failed(applyPatternsGreedily(getOperation(), std::move(canonPatterns))))
      signalPassFailure();

    // Get rid of the dummy entry block we created in the beginning to work
    // around dialect conversion signature rewriting.
    rewriter.eraseOp(branch);
    rewriter.mergeBlocks(postEntryBlock, entryBlock);
  }
};
} // namespace
