//===- ConstantFold.cpp - Implementation of constant folding on Linalg ops ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements constant folding on Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Base class for constant folding linalg.generic ops with N inputs, 1 output,
/// and permutation indexing maps.
///
/// `ConcreteType` should provide methods with signatures
///
/// ```c++
///   bool matchIndexingMaps(GenericOp genericOp) const;
///   RegionComputationFn getRegionComputeFn(GenericOp) const;
/// ```
///
/// The latter inspects the region and returns the computation inside as a
/// functor. The functor will be invoked with constant elements for all inputs
/// and should return the corresponding computed constant element for output.
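///
/// For example, FoldConstantTranspose below simply forwards its single input
/// element, since a transpose performs no per-element computation. A
/// hypothetical pattern folding element-wise float negation could instead
/// return a functor along the lines of:
///
/// ```c++
///   return [](const APIntOrFloatArray &inputs) {
///     return APIntOrFloat{std::nullopt, -inputs.apFloats.front()};
///   };
/// ```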
template <typename ConcreteType>
class FoldConstantBase : public OpRewritePattern<GenericOp> {
public:
  struct APIntOrFloat {
    std::optional<APInt> apInt;
    std::optional<APFloat> apFloat;
  };
  struct APIntOrFloatArray {
    SmallVector<APInt> apInts;
    SmallVector<APFloat> apFloats;
  };
  using RegionComputationFn =
      std::function<APIntOrFloat(const APIntOrFloatArray &)>;

  FoldConstantBase(MLIRContext *context, const ControlFusionFn &controlFn,
                   PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Mixed and buffer semantics aren't supported.
    if (!genericOp.hasPureTensorSemantics())
      return failure();

    // Only support ops generating one output for now.
    if (genericOp.getNumDpsInits() != 1)
      return failure();

    auto outputType = dyn_cast<ShapedType>(genericOp.getResultTypes().front());
    // Require the output types to be static given that we are generating
    // constants.
    if (!outputType || !outputType.hasStaticShape())
      return failure();

    if (!llvm::all_of(genericOp.getInputs(), [](Value input) {
          return isa<ShapedType>(input.getType());
        }))
      return failure();

    // Make sure all element types are the same.
    auto getOperandElementType = [](Value value) {
      return cast<ShapedType>(value.getType()).getElementType();
    };
    if (!llvm::all_equal(
            llvm::map_range(genericOp->getOperands(), getOperandElementType)))
      return failure();

    // We can only handle the case where we have int/float elements.
    auto elementType = outputType.getElementType();
    if (!elementType.isIntOrFloat())
      return failure();

    // Require all indexing maps to be permutations for now. This is common and
    // it simplifies input/output access greatly: we can do the data shuffling
    // entirely in the compiler, without needing to turn all indices into
    // Values, apply affine maps to them, and then match the result back to a
    // constant.
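    // For example, (d0, d1) -> (d1, d0) is a permutation, while
    // (d0, d1) -> (d0) (a reduction) or (d0, d1) -> (d0 + d1) is not.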
    if (!llvm::all_of(genericOp.getIndexingMapsArray(),
                      [](AffineMap map) { return map.isPermutation(); }))
      return failure();

    for (OpOperand &operand : genericOp.getDpsInitsMutable()) {
      if (genericOp.payloadUsesValueFromOperand(&operand))
        return failure();
    }

    // Further check the indexing maps are okay for the ConcreteType.
    if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
      return failure();

    // Defer to the concrete type to check the region and discover the
    // computation inside.
    RegionComputationFn computeFn =
        static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
    if (!computeFn)
      return failure();

    // All inputs should be constants.
    int numInputs = genericOp.getNumDpsInputs();
    SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
    for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
      if (!matchPattern(en.value()->get(),
                        m_Constant(&inputValues[en.index()])))
        return failure();
    }

    // Identified this as a potential candidate for folding. Now check the
    // policy to see whether we are allowed to proceed.
    for (OpOperand *operand : genericOp.getDpsInputOperands()) {
      if (!controlFn(operand))
        return failure();
    }

    auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
    SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
    int64_t numElements = outputType.getNumElements();

    // Use APInt/APFloat instead of Attribute here for constructing the output.
    // This helps to avoid blowing up compiler memory usage: Attributes would
    // unify the following cases, but their lifetime is tied to the
    // MLIRContext.
    SmallVector<APInt> intOutputValues;
    SmallVector<APFloat> fpOutputValues;
    if (isa<FloatType>(elementType))
      fpOutputValues.resize(numElements, APFloat(0.f));
    else
      intOutputValues.resize(numElements);

    // Return the constant dim positions from the given permutation map.
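    // For example, the map (d0, d1) -> (d1, d0) yields positions {1, 0}.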
    auto getDimPositions = [](AffineMap map) {
      SmallVector<unsigned> dims;
      dims.reserve(map.getNumResults());
      for (AffineExpr result : map.getResults()) {
        dims.push_back(cast<AffineDimExpr>(result).getPosition());
      }
      return dims;
    };

    SmallVector<SmallVector<unsigned>> inputDims;
    for (int i = 0; i < numInputs; ++i)
      inputDims.push_back(getDimPositions(genericOp.getIndexingMapsArray()[i]));
    auto outputDims = getDimPositions(genericOp.getIndexingMapsArray().back());
    auto outputShape = outputType.getShape();

    // Allocate small vectors for index delinearization. Initial values do not
    // matter here as they will be overwritten later.
    SmallVector<uint64_t> indices(loopBounds.size(), 0);
    SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
    SmallVector<SmallVector<uint64_t>> srcIndices(
        numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
    SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
    uint64_t dstLinearIndex = 0;

    // Allocate space for the compute function inputs. Initial values do not
    // matter here as they will be overwritten later.
    APIntOrFloatArray computeFnInputs;

    auto inputShapes = llvm::to_vector<4>(
        llvm::map_range(genericOp.getInputs(), [](Value value) {
          return cast<ShapedType>(value.getType()).getShape();
        }));

    // Given a `linearIndex`, remap it to a linear index to access linalg op
    // inputs/outputs. This mutates `indices`, `srcIndices`, `dstIndices`,
    // `srcLinearIndices`, and `dstLinearIndex` in place.
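    // As a worked example, consider folding a tensor<2x3> -> tensor<3x2>
    // transpose with input map (d0, d1) -> (d1, d0) and identity output map,
    // giving loop bounds [3, 2]. linearIndex 4 delinearizes to loop indices
    // (2, 0), which maps to source element (0, 2) (srcLinearIndex 2) and
    // destination element (2, 0) (dstLinearIndex 4).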
    auto computeRemappedLinearIndex = [&](int linearIndex) {
      int totalCount = linearIndex;
      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
        indices[dim] = totalCount % loopBounds[dim];
        totalCount /= loopBounds[dim];
      }

      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
        for (int i = 0; i < numInputs; ++i)
          srcIndices[i][dim] = indices[inputDims[i][dim]];
        dstIndices[dim] = indices[outputDims[dim]];
      }

      dstLinearIndex = dstIndices.front();
      for (int i = 0; i < numInputs; ++i)
        srcLinearIndices[i] = srcIndices[i].front();

      for (int dim = 1; dim < outputType.getRank(); ++dim) {
        dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
        for (int i = 0; i < numInputs; ++i)
          srcLinearIndices[i] =
              srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
      }
    };

    bool isFloat = isa<FloatType>(elementType);
    if (isFloat) {
      SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
      for (int i = 0; i < numInputs; ++i)
        inFpRanges.push_back(inputValues[i].getValues<APFloat>());

      computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));

      // Transpose the input constant. Because we don't know its rank in
      // advance, we need to loop over the range [0, element count) and
      // delinearize the index.
      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);

        // Collect constant elements for all inputs at this loop iteration.
        for (int i = 0; i < numInputs; ++i)
          computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];

        // Invoke the computation to get the corresponding constant output
        // element.
        fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
      }
    } else {
      SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges;
      for (int i = 0; i < numInputs; ++i)
        inIntRanges.push_back(inputValues[i].getValues<APInt>());

      computeFnInputs.apInts.resize(numInputs);

      // Transpose the input constant. Because we don't know its rank in
      // advance, we need to loop over the range [0, element count) and
      // delinearize the index.
      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);

        // Collect constant elements for all inputs at this loop iteration.
        for (int i = 0; i < numInputs; ++i)
          computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];

        // Invoke the computation to get the corresponding constant output
        // element.
        intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
      }
    }

    DenseElementsAttr outputAttr =
        isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
                : DenseElementsAttr::get(outputType, intOutputValues);

    rewriter.replaceOpWithNewOp<arith::ConstantOp>(genericOp, outputAttr);
    return success();
  }

private:
  ControlFusionFn controlFn;
};

// Folds linalg.generic ops that are actually transposes on constant values.
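//
// For example (illustrative IR; %cst is a constant and %init an empty
// tensor):
//
//   %transposed = linalg.generic {
//       indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
//                        affine_map<(d0, d1) -> (d0, d1)>],
//       iterator_types = ["parallel", "parallel"]}
//       ins(%cst : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
//     ^bb0(%in: f32, %out: f32):
//       linalg.yield %in : f32
//   } -> tensor<3x2xf32>
//
// is replaced by an arith.constant holding the transposed elements.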
struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
  using FoldConstantBase::FoldConstantBase;

  bool matchIndexingMaps(GenericOp genericOp) const {
    // We should have one input and one output.
    return genericOp.getIndexingMapsArray().size() == 2;
  }

  RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
    // Make sure the region only contains a yield op.
    Block &body = genericOp.getRegion().front();
    if (!llvm::hasSingleElement(body))
      return nullptr;
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return nullptr;

    // The yield op should return the block argument corresponding to the
    // input.
    for (Value yieldVal : yieldOp.getValues()) {
      auto yieldArg = dyn_cast<BlockArgument>(yieldVal);
      if (!yieldArg || yieldArg.getOwner() != &body)
        return nullptr;
      if (yieldArg.getArgNumber() != 0)
        return nullptr;
    }

    // No computation; just return the original value.
    return [](const APIntOrFloatArray &inputs) {
      if (inputs.apFloats.empty())
        return APIntOrFloat{inputs.apInts.front(), std::nullopt};
      return APIntOrFloat{std::nullopt, inputs.apFloats.front()};
    };
  }

  ControlFusionFn controlFn;
};
} // namespace

void mlir::linalg::populateConstantFoldLinalgOperations(
    RewritePatternSet &patterns, const ControlFusionFn &controlFn) {
  MLIRContext *context = patterns.getContext();
  patterns.insert<FoldConstantTranspose>(context, controlFn);
}
