//===- ConstantFold.cpp - Implementation of constant folding on Linalg ops ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements constant folding on Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Base class for constant folding linalg.generic ops with N inputs, 1 output,
/// and permutation indexing maps.
///
/// `ConcreteType` should provide methods with signatures
///
/// ```c++
///   bool matchIndexingMaps(GenericOp genericOp) const;
///   RegionComputationFn getRegionComputeFn(GenericOp) const;
/// ```
///
/// The latter inspects the region and returns the computation inside as a
/// functor. The functor will be invoked with constant elements for all inputs
/// and should return the corresponding computed constant element for output.
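///
/// The functor receives its inputs as an `APIntOrFloatArray` whose `apInts`
/// or `apFloats` member is populated, depending on whether the common element
/// type is an integer or a float, and it is expected to fill in the matching
/// member of the returned `APIntOrFloat`.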
template <typename ConcreteType>
class FoldConstantBase : public OpRewritePattern<GenericOp> {
public:
  struct APIntOrFloat {
    std::optional<APInt> apInt;
    std::optional<APFloat> apFloat;
  };
  struct APIntOrFloatArray {
    SmallVector<APInt> apInts;
    SmallVector<APFloat> apFloats;
  };
  using RegionComputationFn =
      std::function<APIntOrFloat(const APIntOrFloatArray &)>;

  FoldConstantBase(MLIRContext *context, const ControlFusionFn &controlFn,
                   PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Mixed and buffer semantics aren't supported.
    if (!genericOp.hasPureTensorSemantics())
      return failure();

    // Only support ops generating one output for now.
    if (genericOp.getNumDpsInits() != 1)
      return failure();

    auto outputType = dyn_cast<ShapedType>(genericOp.getResultTypes().front());
    // Require the output type to be static given that we are generating
    // constants.
    if (!outputType || !outputType.hasStaticShape())
      return failure();

    if (!llvm::all_of(genericOp.getInputs(), [](Value input) {
          return isa<ShapedType>(input.getType());
        }))
      return failure();

    // Make sure all element types are the same.
    auto getOperandElementType = [](Value value) {
      return cast<ShapedType>(value.getType()).getElementType();
    };
    if (!llvm::all_equal(
            llvm::map_range(genericOp->getOperands(), getOperandElementType)))
      return failure();

    // We can only handle the case where we have int/float elements.
    auto elementType = outputType.getElementType();
    if (!elementType.isIntOrFloat())
      return failure();

    // Require all indexing maps to be permutations for now. This is common and
    // it simplifies input/output access greatly: we can do the data shuffling
    // entirely in the compiler, without needing to turn all indices into
    // Values, apply the affine maps to them, and then match the constants back
    // again.
    if (!llvm::all_of(genericOp.getIndexingMapsArray(),
                      [](AffineMap map) { return map.isPermutation(); }))
      return failure();

    for (OpOperand &operand : genericOp.getDpsInitsMutable()) {
      if (genericOp.payloadUsesValueFromOperand(&operand))
        return failure();
    }

    // Further check the indexing maps are okay for the ConcreteType.
    if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
      return failure();

    // Defer to the concrete type to check the region and discover the
    // computation inside.
    RegionComputationFn computeFn =
        static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
    if (!computeFn)
      return failure();

    // All inputs should be constants.
    int numInputs = genericOp.getNumDpsInputs();
    SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
    for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
      if (!matchPattern(en.value()->get(),
                        m_Constant(&inputValues[en.index()])))
        return failure();
    }

    // Identified this as a potential candidate for folding. Now check the
    // policy to see whether we are allowed to proceed.
    for (OpOperand *operand : genericOp.getDpsInputOperands()) {
      if (!controlFn(operand))
        return failure();
    }

    auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
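    // All indexing maps are permutations, so the product of the static loop
    // bounds equals the number of output elements: iterating over
    // [0, numElements) below visits each point of the iteration space exactly
    // once.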
    SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
    int64_t numElements = outputType.getNumElements();

    // Use APInt/APFloat instead of Attribute here for constructing the output.
    // This helps to avoid blowing up compiler memory usage: Attributes would
    // unify the integer and float cases below, but they are uniqued and live
    // as long as the MLIRContext.
    SmallVector<APInt> intOutputValues;
    SmallVector<APFloat> fpOutputValues;
    if (isa<FloatType>(elementType))
      fpOutputValues.resize(numElements, APFloat(0.f));
    else
      intOutputValues.resize(numElements);

    // Return the constant dim positions from the given permutation map.
    auto getDimPositions = [](AffineMap map) {
      SmallVector<unsigned> dims;
      dims.reserve(map.getNumResults());
      for (AffineExpr result : map.getResults()) {
        dims.push_back(cast<AffineDimExpr>(result).getPosition());
      }
      return dims;
    };

    SmallVector<SmallVector<unsigned>> inputDims;
    for (int i = 0; i < numInputs; ++i)
      inputDims.push_back(getDimPositions(genericOp.getIndexingMapsArray()[i]));
    auto outputDims = getDimPositions(genericOp.getIndexingMapsArray().back());
    auto outputShape = outputType.getShape();

    // Allocate small vectors for index delinearization. Initial values do not
    // matter here as they will be overwritten later.
    SmallVector<uint64_t> indices(loopBounds.size(), 0);
    SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
    SmallVector<SmallVector<uint64_t>> srcIndices(
        numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
    SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
    uint64_t dstLinearIndex = 0;

    // Allocate space for compute function inputs. Initial values do not matter
    // here as they will be overwritten later.
    APIntOrFloatArray computeFnInputs;

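    // Cache the static shapes of all inputs; they are needed below to
    // relinearize the per-input coordinates.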
    auto inputShapes = llvm::to_vector<4>(
        llvm::map_range(genericOp.getInputs(), [](Value value) {
          return cast<ShapedType>(value.getType()).getShape();
        }));

    // Given a `linearIndex`, remap it to linear indices for accessing the
    // linalg op's inputs/outputs. This mutates `indices`, `srcIndices`,
    // `dstIndices`, `srcLinearIndices`, and `dstLinearIndex` in place.
    auto computeRemappedLinearIndex = [&](int linearIndex) {
      int totalCount = linearIndex;
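      // Delinearize `linearIndex` into per-loop-dimension indices, treating
      // the loop nest as row-major (the last loop dimension varies fastest).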
      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
        indices[dim] = totalCount % loopBounds[dim];
        totalCount /= loopBounds[dim];
      }

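      // Map the loop indices to tensor coordinates for each input and for
      // the output by going through the permutation indexing maps.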
      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
        for (int i = 0; i < numInputs; ++i)
          srcIndices[i][dim] = indices[inputDims[i][dim]];
        dstIndices[dim] = indices[outputDims[dim]];
      }

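      // Relinearize the permuted coordinates against the corresponding
      // tensor shapes (row-major) to obtain flat element indices.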
      dstLinearIndex = dstIndices.front();
      for (int i = 0; i < numInputs; ++i)
        srcLinearIndices[i] = srcIndices[i].front();

      for (int dim = 1; dim < outputType.getRank(); ++dim) {
        dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
        for (int i = 0; i < numInputs; ++i)
          srcLinearIndices[i] =
              srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
      }
    };

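    // Fold element by element. The float and integer cases are handled
    // separately so that the input DenseElementsAttr values can be iterated
    // directly as APFloat or APInt.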
    bool isFloat = isa<FloatType>(elementType);
    if (isFloat) {
      SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
      for (int i = 0; i < numInputs; ++i)
        inFpRanges.push_back(inputValues[i].getValues<APFloat>());

      computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));

      // Transpose the input constant. Because we don't know its rank in
      // advance, we need to loop over the range [0, element count) and
      // delinearize the index.
      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);

        // Collect constant elements for all inputs at this loop iteration.
        for (int i = 0; i < numInputs; ++i)
          computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];

        // Invoke the computation to get the corresponding constant output
        // element.
        fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
      }
    } else {
      SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges;
      for (int i = 0; i < numInputs; ++i)
        inIntRanges.push_back(inputValues[i].getValues<APInt>());

      computeFnInputs.apInts.resize(numInputs);

      // Transpose the input constant. Because we don't know its rank in
      // advance, we need to loop over the range [0, element count) and
      // delinearize the index.
      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);

        // Collect constant elements for all inputs at this loop iteration.
        for (int i = 0; i < numInputs; ++i)
          computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];

        // Invoke the computation to get the corresponding constant output
        // element.
        intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
      }
    }

    DenseElementsAttr outputAttr =
        isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
                : DenseElementsAttr::get(outputType, intOutputValues);

    rewriter.replaceOpWithNewOp<arith::ConstantOp>(genericOp, outputAttr);
    return success();
  }

private:
  ControlFusionFn controlFn;
};

// Folds linalg.generic ops that are actually transposes on constant values.
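// As an illustrative sketch (the value names and shapes below are made up,
// not taken from any test), the kind of op this pattern folds looks like:
//
//   %res = linalg.generic
//     {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
//                       affine_map<(d0, d1) -> (d0, d1)>],
//      iterator_types = ["parallel", "parallel"]}
//     ins(%cst : tensor<4x8xf32>) outs(%init : tensor<8x4xf32>) {
//   ^bb0(%in: f32, %out: f32):
//     linalg.yield %in : f32
//   } -> tensor<8x4xf32>
//
// where %cst is defined by an arith.constant; the whole generic is then
// replaced by a single arith.constant holding the transposed values.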
struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
  using FoldConstantBase::FoldConstantBase;

  bool matchIndexingMaps(GenericOp genericOp) const {
    // We should have one input and one output.
    return genericOp.getIndexingMapsArray().size() == 2;
  }

  RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
    // Make sure the region only contains a yield op.
    Block &body = genericOp.getRegion().front();
    if (!llvm::hasSingleElement(body))
      return nullptr;
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return nullptr;

    // The yield op should return the block argument corresponding to the
    // input.
    for (Value yieldVal : yieldOp.getValues()) {
      auto yieldArg = dyn_cast<BlockArgument>(yieldVal);
      if (!yieldArg || yieldArg.getOwner() != &body)
        return nullptr;
      if (yieldArg.getArgNumber() != 0)
        return nullptr;
    }

    // No computation; just return the original value.
    return [](const APIntOrFloatArray &inputs) {
      if (inputs.apFloats.empty())
        return APIntOrFloat{inputs.apInts.front(), std::nullopt};
      return APIntOrFloat{std::nullopt, inputs.apFloats.front()};
    };
  }

  ControlFusionFn controlFn;
};
} // namespace

void mlir::linalg::populateConstantFoldLinalgOperations(
    RewritePatternSet &patterns, const ControlFusionFn &controlFn) {
  MLIRContext *context = patterns.getContext();
  patterns.insert<FoldConstantTranspose>(context, controlFn);
}