1//===- CodegenEnv.h - Code generation environment class ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This header file defines the code generation environment class.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENENV_H_
14#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENENV_H_
15
16#include "CodegenUtils.h"
17#include "LoopEmitter.h"
18
19#include "mlir/Dialect/Linalg/IR/Linalg.h"
20#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
21#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
22#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
23#include <optional>
24
25namespace mlir {
26namespace sparse_tensor {
27
28/// The code generation environment class aggregates a number of data
29/// structures that are needed during the code generation phase of
30/// sparsification. This environment simplifies passing around such
31/// data during sparsification (rather than passing around all the
32/// individual compoments where needed). Furthermore, it provides
33/// convience methods that keep implementation details transparent
34/// to sparsification while asserting on internal consistency.
35class CodegenEnv {
36public:
37 /// Constructs a code generation environment which can be
38 /// passed around during sparsification for bookkeeping
39 /// together with some consistency asserts.
40 CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
41 unsigned numTensors, unsigned numLoops, unsigned maxRank);
42
43 //
44 // General methods.
45 //
46
47 LogicalResult initTensorExp();
48 ExprId getExprId() const { return tensorExp; }
49
50 linalg::GenericOp op() const { return linalgOp; }
51 const SparsificationOptions &options() const { return sparseOptions; }
52 Merger &merger() { return latticeMerger; }
53 LoopEmitter &emitter() { return loopEmitter; }
54
55 void startEmit(SparseEmitStrategy emitStrategy);
56
57 /// Generates loop boundary statements (entering/exiting loops). The function
58 /// passes and updates the passed-in parameters.
59 std::optional<Operation *>
60 genLoopBoundary(function_ref<
61 std::optional<Operation *>(MutableArrayRef<Value> parameters)>
62 callback);
63
64 //
65 // Merger delegates.
66 //
67
68 constexpr TensorId makeTensorId(unsigned t) const {
69 return latticeMerger.makeTensorId(t);
70 }
71 constexpr LoopId makeLoopId(unsigned i) const {
72 return latticeMerger.makeLoopId(i);
73 }
74 constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const {
75 return latticeMerger.makeTensorLoopId(t, i);
76 }
77 const TensorExp &exp(ExprId e) const { return latticeMerger.exp(e); }
78 const LatPoint &lat(LatPointId l) const { return latticeMerger.lat(p: l); }
79 ArrayRef<LatPointId> set(LatSetId s) const { return latticeMerger.set(s); }
80 LevelType lt(TensorId t, LoopId i) const {
81 return latticeMerger.getLvlType(t, i);
82 }
83 LevelType lt(TensorLoopId b) const { return latticeMerger.getLvlType(b); }
84
85 unsigned getLoopNum() const { return latticeMerger.getNumLoops(); }
86
87 //
88 // LoopEmitter delegates.
89 //
90
91 TensorLevel makeTensorLevel(TensorId t, Level l) const {
92 // Make sure LoopEmitter, GenericOp, and Merger agree on the number of
93 // tensors.
94 assert(loopEmitter.getNumManifestTensors() == linalgOp->getNumOperands() &&
95 loopEmitter.getNumTensors() == latticeMerger.getNumTensors() &&
96 loopEmitter.getOutTensorId() == latticeMerger.getOutTensorID() &&
97 loopEmitter.getSynTensorId() == latticeMerger.getSynTensorID());
98 return loopEmitter.makeTensorLevel(t, l);
99 }
100 TensorLevel makeTensorLevel(std::pair<TensorId, Level> tlPair) const {
101 return makeTensorLevel(t: tlPair.first, l: tlPair.second);
102 }
103 std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tl) const {
104 return loopEmitter.unpackTensorLevel(tl);
105 }
106 template <class ContainerTy>
107 auto unpackTensorLevelRange(ContainerTy &&c) const {
108 return loopEmitter.unpackTensorLevelRange(std::forward<ContainerTy>(c));
109 }
110
111 unsigned getCurrentDepth() const { return loopEmitter.getCurrentDepth(); }
112
113 //
114 // Code generation environment verify functions.
115 //
116
117 /// Whether the tensor expression is admissible for codegen.
118 /// It also sets the sparseOut if the output tensor is sparse.
119 bool isAdmissibleTensorExp(ExprId e);
120
121 /// Returns the induction-variable for the given loop.
122 Value getLoopVar(LoopId i) const;
123
124 //
125 // Sparse tensor output and expansion methods.
126 //
127
128 bool hasSparseOutput() const { return sparseOut != nullptr; }
129 bool isSparseOutput(OpOperand *o) const { return sparseOut == o; }
130
131 Value getInsertionChain() const { return insChain; }
132 void updateInsertionChain(Value chain);
133
134 bool atExpandLevel(OpOperand *o, unsigned rank, LoopId n) const;
135 void startExpand(Value values, Value filled, Value added, Value count);
136 bool isExpand() const { return expValues != nullptr; }
137 void updateExpandCount(Value count);
138 Value getExpandValues() const { return expValues; }
139 Value getExpandFilled() const { return expFilled; }
140 Value getExpandAdded() const { return expAdded; }
141 Value getExpandCount() const { return expCount; }
142 void endExpand();
143
144 //
145 // Reduction methods.
146 //
147
148 void startReduc(ExprId exp, Value val);
149 bool isReduc() const { return redExp != detail::kInvalidId; }
150 void updateReduc(Value val);
151 Value getReduc() const { return redVal; }
152 Value endReduc();
153
154 void startValidLexInsert(Value val);
155 bool isValidLexInsert() const { return redValidLexInsert != nullptr; }
156 void updateValidLexInsert(Value val);
157 Value getValidLexInsert() const { return redValidLexInsert; }
158 void endValidLexInsert();
159
160 void startCustomReduc(ExprId exp);
161 bool isCustomReduc() const { return redCustom != detail::kInvalidId; }
162 Value getCustomRedId() const;
163 void endCustomReduc();
164
165private:
166 // Linalg operation.
167 linalg::GenericOp linalgOp;
168
169 // Sparsification options.
170 SparsificationOptions sparseOptions;
171
172 // Merger helper class.
173 Merger latticeMerger;
174
175 // Loop emitter helper class.
176 LoopEmitter loopEmitter;
177
178 // Sparse tensor as output. Implemented either through direct injective
179 // insertion in lexicographic index order or through access pattern
180 // expansion in the innermost loop nest (`expValues` through `expCount`).
181 OpOperand *sparseOut;
182 // The count of outer non-filter loops, as defined by `isAdmissibleTopoOrder`.
183 LoopId outerParNest;
184 Value insChain;
185 Value expValues;
186 Value expFilled;
187 Value expAdded;
188 Value expCount;
189
190 // Bookkeeping for reductions (up-to-date value of the reduction, and indices
191 // into the merger's expression tree. When the indices of a tensor reduction
192 // expression are exhausted, all inner loops can use a scalarized reduction.
193 Value redVal;
194 ExprId redExp;
195 ExprId redCustom;
196
197 // Bookkeeping for lex insertion during reductions. Holds the runtime boolean
198 // value of whether any reduction occurred. This is only set during a
199 // reduction and cleared once the reduction is finished.
200 Value redValidLexInsert;
201
202 // The root tensor expression of the kernel.
203 ExprId tensorExp;
204};
205
206} // namespace sparse_tensor
207} // namespace mlir
208
209#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENENV_H_
210

source code of mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h