1 | //===- RewriteInsertsPass.cpp - MLIR conversion pass ----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements a pass to rewrite sequential chains of |
10 | // `spirv::CompositeInsert` operations into `spirv::CompositeConstruct` |
11 | // operations. |
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"

#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"

#include <utility>
20 | |
// Pull in the tablegen-generated pass definition that provides
// `spirv::impl::SPIRVRewriteInsertsPassBase`, the CRTP base class used by
// `RewriteInsertsPass` below. Per the usual MLIR tablegen convention, the
// GEN_PASS_DEF_* macro must be defined right before including the .inc file
// so that this pass's definition is the one emitted.
namespace mlir {
namespace spirv {
#define GEN_PASS_DEF_SPIRVREWRITEINSERTSPASS
#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
} // namespace spirv
} // namespace mlir

using namespace mlir;
29 | |
30 | namespace { |
31 | |
32 | /// Replaces sequential chains of `spirv::CompositeInsertOp` operation into |
33 | /// `spirv::CompositeConstructOp` operation if possible. |
34 | class RewriteInsertsPass |
35 | : public spirv::impl::SPIRVRewriteInsertsPassBase<RewriteInsertsPass> { |
36 | public: |
37 | void runOnOperation() override; |
38 | |
39 | private: |
40 | /// Collects a sequential insertion chain by the given |
41 | /// `spirv::CompositeInsertOp` operation, if the given operation is the last |
42 | /// in the chain. |
43 | LogicalResult |
44 | collectInsertionChain(spirv::CompositeInsertOp op, |
45 | SmallVectorImpl<spirv::CompositeInsertOp> &insertions); |
46 | }; |
47 | |
48 | } // namespace |
49 | |
50 | void RewriteInsertsPass::runOnOperation() { |
51 | SmallVector<SmallVector<spirv::CompositeInsertOp, 4>, 4> workList; |
52 | getOperation().walk([this, &workList](spirv::CompositeInsertOp op) { |
53 | SmallVector<spirv::CompositeInsertOp, 4> insertions; |
54 | if (succeeded(collectInsertionChain(op, insertions))) |
55 | workList.push_back(insertions); |
56 | }); |
57 | |
58 | for (const auto &insertions : workList) { |
59 | auto lastCompositeInsertOp = insertions.back(); |
60 | auto compositeType = lastCompositeInsertOp.getType(); |
61 | auto location = lastCompositeInsertOp.getLoc(); |
62 | |
63 | SmallVector<Value, 4> operands; |
64 | // Collect inserted objects. |
65 | for (auto insertionOp : insertions) |
66 | operands.push_back(insertionOp.getObject()); |
67 | |
68 | OpBuilder builder(lastCompositeInsertOp); |
69 | auto compositeConstructOp = builder.create<spirv::CompositeConstructOp>( |
70 | location, compositeType, operands); |
71 | |
72 | lastCompositeInsertOp.replaceAllUsesWith( |
73 | compositeConstructOp->getResult(0)); |
74 | |
75 | // Erase ops. |
76 | for (auto insertOp : llvm::reverse(insertions)) { |
77 | auto *op = insertOp.getOperation(); |
78 | if (op->use_empty()) |
79 | insertOp.erase(); |
80 | } |
81 | } |
82 | } |
83 | |
84 | LogicalResult RewriteInsertsPass::collectInsertionChain( |
85 | spirv::CompositeInsertOp op, |
86 | SmallVectorImpl<spirv::CompositeInsertOp> &insertions) { |
87 | auto indicesArrayAttr = cast<ArrayAttr>(op.getIndices()); |
88 | // TODO: handle nested composite object. |
89 | if (indicesArrayAttr.size() == 1) { |
90 | auto numElements = cast<spirv::CompositeType>(op.getComposite().getType()) |
91 | .getNumElements(); |
92 | |
93 | auto index = cast<IntegerAttr>(indicesArrayAttr[0]).getInt(); |
94 | // Need a last index to collect a sequential chain. |
95 | if (index + 1 != numElements) |
96 | return failure(); |
97 | |
98 | insertions.resize(numElements); |
99 | while (true) { |
100 | insertions[index] = op; |
101 | |
102 | if (index == 0) |
103 | return success(); |
104 | |
105 | op = op.getComposite().getDefiningOp<spirv::CompositeInsertOp>(); |
106 | if (!op) |
107 | return failure(); |
108 | |
109 | --index; |
110 | indicesArrayAttr = cast<ArrayAttr>(op.getIndices()); |
111 | if ((indicesArrayAttr.size() != 1) || |
112 | (cast<IntegerAttr>(indicesArrayAttr[0]).getInt() != index)) |
113 | return failure(); |
114 | } |
115 | } |
116 | return failure(); |
117 | } |
118 | |