| 1 | //===- ConstantFold.cpp - Implementation of constant folding on Linalg ops ===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements constant folding on Linalg operations. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 14 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 15 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
| 16 | #include "mlir/IR/Matchers.h" |
| 17 | #include "mlir/IR/PatternMatch.h" |
| 18 | #include "mlir/Support/LLVM.h" |
| 19 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 20 | #include <optional> |
| 21 | |
| 22 | using namespace mlir; |
| 23 | using namespace mlir::linalg; |
| 24 | |
| 25 | namespace { |
| 26 | /// Base class for constant folding linalg structured ops with N inputs, 1 |
| 27 | /// output, and permutation indexing maps. |
| 28 | /// |
| 29 | /// `ConcreteType` should provide methods with signatures |
| 30 | /// |
| 31 | /// ```c++ |
| 32 | /// bool matchIndexingMaps(LinalgOp linalgOp) const; |
| 33 | /// RegionComputationFn getRegionComputeFn(LinalgOp) const; |
| 34 | /// ``` |
| 35 | /// |
| 36 | /// The latter inspects the region and returns the computation inside as a |
| 37 | /// functor. The functor will be invoked with constant elements for all inputs |
| 38 | /// and should return the corresponding computed constant element for output. |
| 39 | template <typename ConcreteType> |
| 40 | class FoldConstantBase : public OpInterfaceRewritePattern<LinalgOp> { |
| 41 | public: |
| 42 | struct APIntOrFloat { |
| 43 | std::optional<APInt> apInt; |
| 44 | std::optional<APFloat> apFloat; |
| 45 | }; |
| 46 | struct APIntOrFloatArray { |
| 47 | SmallVector<APInt> apInts; |
| 48 | SmallVector<APFloat> apFloats; |
| 49 | }; |
| 50 | using RegionComputationFn = |
| 51 | std::function<APIntOrFloat(const APIntOrFloatArray &)>; |
| 52 | |
| 53 | FoldConstantBase(MLIRContext *context, const ControlFusionFn &controlFn, |
| 54 | PatternBenefit benefit = 1) |
| 55 | : OpInterfaceRewritePattern<LinalgOp>(context, benefit), |
| 56 | controlFn(controlFn) {} |
| 57 | |
| 58 | LogicalResult matchAndRewrite(LinalgOp linalgOp, |
| 59 | PatternRewriter &rewriter) const override { |
| 60 | // Mixed and buffer sematics aren't supported. |
| 61 | if (!linalgOp.hasPureTensorSemantics()) |
| 62 | return failure(); |
| 63 | |
| 64 | // Only support ops generating one output for now. |
| 65 | if (linalgOp.getNumDpsInits() != 1) |
| 66 | return failure(); |
| 67 | |
| 68 | auto outputType = dyn_cast<ShapedType>(linalgOp->getResultTypes().front()); |
| 69 | // Require the output types to be static given that we are generating |
| 70 | // constants. |
| 71 | if (!outputType || !outputType.hasStaticShape()) |
| 72 | return failure(); |
| 73 | |
| 74 | if (!llvm::all_of(linalgOp.getDpsInputs(), [](Value input) { |
| 75 | return isa<ShapedType>(Val: input.getType()); |
| 76 | })) |
| 77 | return failure(); |
| 78 | |
| 79 | // Make sure all element types are the same. |
| 80 | auto getOperandElementType = [](Value value) { |
| 81 | return cast<ShapedType>(value.getType()).getElementType(); |
| 82 | }; |
| 83 | if (!llvm::all_equal( |
| 84 | llvm::map_range(linalgOp->getOperands(), getOperandElementType))) |
| 85 | return failure(); |
| 86 | |
| 87 | // We can only handle the case where we have int/float elements. |
| 88 | auto elementType = outputType.getElementType(); |
| 89 | if (!elementType.isIntOrFloat()) |
| 90 | return failure(); |
| 91 | |
| 92 | // Require all indexing maps to be permutations for now. This is common and |
| 93 | // it simplifies input/output access greatly: we can do the data shuffling |
| 94 | // entirely in the compiler, without needing to turn all indices into |
| 95 | // Values, and then do affine apply on them, and then match back the |
| 96 | // constant again. |
| 97 | if (!llvm::all_of(linalgOp.getIndexingMapsArray(), |
| 98 | [](AffineMap map) { return map.isPermutation(); })) |
| 99 | return failure(); |
| 100 | |
| 101 | for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { |
| 102 | if (linalgOp.payloadUsesValueFromOperand(&operand)) |
| 103 | return failure(); |
| 104 | } |
| 105 | |
| 106 | // Further check the indexing maps are okay for the ConcreteType. |
| 107 | if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(linalgOp)) |
| 108 | return failure(); |
| 109 | |
| 110 | // Defer to the concrete type to check the region and discover the |
| 111 | // computation inside. |
| 112 | RegionComputationFn computeFn = |
| 113 | static_cast<const ConcreteType *>(this)->getRegionComputeFn(linalgOp); |
| 114 | if (!computeFn) |
| 115 | return failure(); |
| 116 | |
| 117 | // All inputs should be constants. |
| 118 | int numInputs = linalgOp.getNumDpsInputs(); |
| 119 | SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs); |
| 120 | for (const auto &en : llvm::enumerate(linalgOp.getDpsInputOperands())) { |
| 121 | if (!matchPattern(en.value()->get(), |
| 122 | m_Constant(&inputValues[en.index()]))) |
| 123 | return failure(); |
| 124 | } |
| 125 | |
| 126 | // Identified this as a potential candidate for folding. Now check the |
| 127 | // policy to see whether we are allowed to proceed. |
| 128 | for (OpOperand *operand : linalgOp.getDpsInputOperands()) { |
| 129 | if (!controlFn(operand)) |
| 130 | return failure(); |
| 131 | } |
| 132 | |
| 133 | SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges(); |
| 134 | int64_t numElements = outputType.getNumElements(); |
| 135 | |
| 136 | // Use APInt/APFloat instead of Attribute here for constructing the output. |
| 137 | // This helps to avoid blowing up compiler memory usage: Attributes would |
| 138 | // unify the following cases but they have lifetime as the MLIRContext. |
| 139 | SmallVector<APInt> intOutputValues; |
| 140 | SmallVector<APFloat> fpOutputValues; |
| 141 | if (isa<FloatType>(elementType)) |
| 142 | fpOutputValues.resize(N: numElements, NV: APFloat(0.f)); |
| 143 | else |
| 144 | intOutputValues.resize(N: numElements); |
| 145 | |
| 146 | // Return the constant dim positions from the given permutation map. |
| 147 | auto getDimPositions = [](AffineMap map) { |
| 148 | SmallVector<unsigned> dims; |
| 149 | dims.reserve(N: map.getNumResults()); |
| 150 | for (AffineExpr result : map.getResults()) { |
| 151 | dims.push_back(Elt: cast<AffineDimExpr>(Val&: result).getPosition()); |
| 152 | } |
| 153 | return dims; |
| 154 | }; |
| 155 | |
| 156 | SmallVector<SmallVector<unsigned>> inputDims; |
| 157 | for (int i = 0; i < numInputs; ++i) |
| 158 | inputDims.push_back(getDimPositions(linalgOp.getIndexingMapsArray()[i])); |
| 159 | auto outputDims = getDimPositions(linalgOp.getIndexingMapsArray().back()); |
| 160 | auto outputShape = outputType.getShape(); |
| 161 | |
| 162 | // Allocate small vectors for index delinearization. Initial values do not |
| 163 | // matter here as they will be overwritten later. |
| 164 | SmallVector<uint64_t> indices(loopBounds.size(), 0); |
| 165 | SmallVector<uint64_t> dstIndices(loopBounds.size(), 0); |
| 166 | SmallVector<SmallVector<uint64_t>> srcIndices( |
| 167 | numInputs, SmallVector<uint64_t>(loopBounds.size(), 0)); |
| 168 | SmallVector<uint64_t> srcLinearIndices(numInputs, 0); |
| 169 | uint64_t dstLinearIndex = 0; |
| 170 | |
| 171 | // Allocate spaces for compute function inputs. Initial values do not matter |
| 172 | // here as they will be overwritten later. |
| 173 | APIntOrFloatArray computeFnInputs; |
| 174 | |
| 175 | auto inputShapes = llvm::to_vector<4>( |
| 176 | llvm::map_range(linalgOp.getDpsInputs(), [](Value value) { |
| 177 | return cast<ShapedType>(value.getType()).getShape(); |
| 178 | })); |
| 179 | |
| 180 | // Given a `linearIndex`, remap it to a linear index to access linalg op |
| 181 | // inputs/ouputs. This mutates `indices`, `srcIndices`, `dstIndices`, |
| 182 | // `srcLinearIndices`, `dstLinearIndex` in place. |
| 183 | auto computeRemappedLinearIndex = [&](int linearIndex) { |
| 184 | int totalCount = linearIndex; |
| 185 | for (int dim = loopBounds.size() - 1; dim >= 0; --dim) { |
| 186 | indices[dim] = totalCount % loopBounds[dim]; |
| 187 | totalCount /= loopBounds[dim]; |
| 188 | } |
| 189 | |
| 190 | for (int dim = loopBounds.size() - 1; dim >= 0; --dim) { |
| 191 | for (int i = 0; i < numInputs; ++i) |
| 192 | srcIndices[i][dim] = indices[inputDims[i][dim]]; |
| 193 | dstIndices[dim] = indices[outputDims[dim]]; |
| 194 | } |
| 195 | |
| 196 | dstLinearIndex = dstIndices.front(); |
| 197 | for (int i = 0; i < numInputs; ++i) |
| 198 | srcLinearIndices[i] = srcIndices[i].front(); |
| 199 | |
| 200 | for (int dim = 1; dim < outputType.getRank(); ++dim) { |
| 201 | dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim]; |
| 202 | for (int i = 0; i < numInputs; ++i) |
| 203 | srcLinearIndices[i] = |
| 204 | srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim]; |
| 205 | } |
| 206 | }; |
| 207 | |
| 208 | bool isFloat = isa<FloatType>(elementType); |
| 209 | if (isFloat) { |
| 210 | SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges; |
| 211 | for (int i = 0; i < numInputs; ++i) |
| 212 | inFpRanges.push_back(inputValues[i].getValues<APFloat>()); |
| 213 | |
| 214 | computeFnInputs.apFloats.resize(numInputs, APFloat(0.f)); |
| 215 | |
| 216 | // Transpose the input constant. Because we don't know its rank in |
| 217 | // advance, we need to loop over the range [0, element count) and |
| 218 | // delinearize the index. |
| 219 | for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) { |
| 220 | computeRemappedLinearIndex(linearIndex); |
| 221 | |
| 222 | // Collect constant elements for all inputs at this loop iteration. |
| 223 | for (int i = 0; i < numInputs; ++i) |
| 224 | computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]]; |
| 225 | |
| 226 | // Invoke the computation to get the corresponding constant output |
| 227 | // element. |
| 228 | fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat; |
| 229 | } |
| 230 | } else { |
| 231 | SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges; |
| 232 | for (int i = 0; i < numInputs; ++i) |
| 233 | inIntRanges.push_back(inputValues[i].getValues<APInt>()); |
| 234 | |
| 235 | computeFnInputs.apInts.resize(numInputs); |
| 236 | |
| 237 | // Transpose the input constant. Because we don't know its rank in |
| 238 | // advance, we need to loop over the range [0, element count) and |
| 239 | // delinearize the index. |
| 240 | for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) { |
| 241 | computeRemappedLinearIndex(linearIndex); |
| 242 | |
| 243 | // Collect constant elements for all inputs at this loop iteration. |
| 244 | for (int i = 0; i < numInputs; ++i) |
| 245 | computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]]; |
| 246 | |
| 247 | // Invoke the computation to get the corresponding constant output |
| 248 | // element. |
| 249 | intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt; |
| 250 | } |
| 251 | } |
| 252 | |
| 253 | DenseElementsAttr outputAttr = |
| 254 | isFloat ? DenseElementsAttr::get(outputType, fpOutputValues) |
| 255 | : DenseElementsAttr::get(outputType, intOutputValues); |
| 256 | |
| 257 | rewriter.replaceOpWithNewOp<arith::ConstantOp>(linalgOp, outputAttr); |
| 258 | return success(); |
| 259 | } |
| 260 | |
| 261 | private: |
| 262 | ControlFusionFn controlFn; |
| 263 | }; |
| 264 | |
| 265 | // Folds linalg.transpose (and linalg.generic ops that are actually transposes) |
| 266 | // on constant values. |
| 267 | struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> { |
| 268 | |
| 269 | using FoldConstantBase::FoldConstantBase; |
| 270 | |
| 271 | bool matchIndexingMaps(LinalgOp linalgOp) const { |
| 272 | // We should have one input and one output. |
| 273 | return linalgOp.getIndexingMapsArray().size() == 2; |
| 274 | } |
| 275 | |
| 276 | RegionComputationFn getRegionComputeFn(LinalgOp linalgOp) const { |
| 277 | // Make sure the region only contains a yield op. |
| 278 | Block &body = linalgOp->getRegion(0).front(); |
| 279 | if (!llvm::hasSingleElement(C&: body)) |
| 280 | return nullptr; |
| 281 | auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator()); |
| 282 | if (!yieldOp) |
| 283 | return nullptr; |
| 284 | |
| 285 | // The yield op should return the block argument corresponds to the input. |
| 286 | for (Value yieldVal : yieldOp.getValues()) { |
| 287 | auto yieldArg = dyn_cast<BlockArgument>(yieldVal); |
| 288 | if (!yieldArg || yieldArg.getOwner() != &body) |
| 289 | return nullptr; |
| 290 | if (yieldArg.getArgNumber() != 0) |
| 291 | return nullptr; |
| 292 | } |
| 293 | |
| 294 | // No computation; just return the orginal value. |
| 295 | return [](const APIntOrFloatArray &inputs) { |
| 296 | if (inputs.apFloats.empty()) |
| 297 | return APIntOrFloat{.apInt: inputs.apInts.front(), .apFloat: std::nullopt}; |
| 298 | return APIntOrFloat{.apInt: std::nullopt, .apFloat: inputs.apFloats.front()}; |
| 299 | }; |
| 300 | } |
| 301 | |
| 302 | ControlFusionFn controlFn; |
| 303 | }; |
| 304 | } // namespace |
| 305 | |
| 306 | void mlir::linalg::populateConstantFoldLinalgOperations( |
| 307 | RewritePatternSet &patterns, const ControlFusionFn &controlFn) { |
| 308 | MLIRContext *context = patterns.getContext(); |
| 309 | patterns.insert<FoldConstantTranspose>(arg&: context, args: controlFn); |
| 310 | } |
| 311 | |