| 1 | //===- ConstantFold.cpp - Implementation of constant folding on Linalg ops ===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements constant folding on Linalg operations. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 14 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
| 15 | #include "mlir/IR/Matchers.h" |
| 16 | #include "mlir/IR/PatternMatch.h" |
| 17 | #include "mlir/Support/LLVM.h" |
| 18 | #include <optional> |
| 19 | |
| 20 | using namespace mlir; |
| 21 | using namespace mlir::linalg; |
| 22 | |
| 23 | namespace { |
| 24 | /// Base class for constant folding linalg structured ops with N inputs, 1 |
| 25 | /// output, and permutation indexing maps. |
| 26 | /// |
| 27 | /// `ConcreteType` should provide methods with signatures |
| 28 | /// |
| 29 | /// ```c++ |
| 30 | /// bool matchIndexingMaps(LinalgOp linalgOp) const; |
| 31 | /// RegionComputationFn getRegionComputeFn(LinalgOp) const; |
| 32 | /// ``` |
| 33 | /// |
| 34 | /// The latter inspects the region and returns the computation inside as a |
| 35 | /// functor. The functor will be invoked with constant elements for all inputs |
| 36 | /// and should return the corresponding computed constant element for output. |
| 37 | template <typename ConcreteType> |
| 38 | class FoldConstantBase : public OpInterfaceRewritePattern<LinalgOp> { |
| 39 | public: |
| 40 | struct APIntOrFloat { |
| 41 | std::optional<APInt> apInt; |
| 42 | std::optional<APFloat> apFloat; |
| 43 | }; |
| 44 | struct APIntOrFloatArray { |
| 45 | SmallVector<APInt> apInts; |
| 46 | SmallVector<APFloat> apFloats; |
| 47 | }; |
| 48 | using RegionComputationFn = |
| 49 | std::function<APIntOrFloat(const APIntOrFloatArray &)>; |
| 50 | |
| 51 | FoldConstantBase(MLIRContext *context, const ControlFusionFn &controlFn, |
| 52 | PatternBenefit benefit = 1) |
| 53 | : OpInterfaceRewritePattern<LinalgOp>(context, benefit), |
| 54 | controlFn(controlFn) {} |
| 55 | |
| 56 | LogicalResult matchAndRewrite(LinalgOp linalgOp, |
| 57 | PatternRewriter &rewriter) const override { |
| 58 | // Mixed and buffer sematics aren't supported. |
| 59 | if (!linalgOp.hasPureTensorSemantics()) |
| 60 | return failure(); |
| 61 | |
| 62 | // Only support ops generating one output for now. |
| 63 | if (linalgOp.getNumDpsInits() != 1) |
| 64 | return failure(); |
| 65 | |
| 66 | auto outputType = dyn_cast<ShapedType>(Val: linalgOp->getResultTypes().front()); |
| 67 | // Require the output types to be static given that we are generating |
| 68 | // constants. |
| 69 | if (!outputType || !outputType.hasStaticShape()) |
| 70 | return failure(); |
| 71 | |
| 72 | if (!llvm::all_of(linalgOp.getDpsInputs(), [](Value input) { |
| 73 | return isa<ShapedType>(Val: input.getType()); |
| 74 | })) |
| 75 | return failure(); |
| 76 | |
| 77 | // Make sure all element types are the same. |
| 78 | auto getOperandElementType = [](Value value) { |
| 79 | return cast<ShapedType>(Val: value.getType()).getElementType(); |
| 80 | }; |
| 81 | if (!llvm::all_equal( |
| 82 | llvm::map_range(linalgOp->getOperands(), getOperandElementType))) |
| 83 | return failure(); |
| 84 | |
| 85 | // We can only handle the case where we have int/float elements. |
| 86 | auto elementType = outputType.getElementType(); |
| 87 | if (!elementType.isIntOrFloat()) |
| 88 | return failure(); |
| 89 | |
| 90 | // Require all indexing maps to be permutations for now. This is common and |
| 91 | // it simplifies input/output access greatly: we can do the data shuffling |
| 92 | // entirely in the compiler, without needing to turn all indices into |
| 93 | // Values, and then do affine apply on them, and then match back the |
| 94 | // constant again. |
| 95 | if (!llvm::all_of(linalgOp.getIndexingMapsArray(), |
| 96 | [](AffineMap map) { return map.isPermutation(); })) |
| 97 | return failure(); |
| 98 | |
| 99 | for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { |
| 100 | if (linalgOp.payloadUsesValueFromOperand(opOperand: &operand)) |
| 101 | return failure(); |
| 102 | } |
| 103 | |
| 104 | // Further check the indexing maps are okay for the ConcreteType. |
| 105 | if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(linalgOp)) |
| 106 | return failure(); |
| 107 | |
| 108 | // Defer to the concrete type to check the region and discover the |
| 109 | // computation inside. |
| 110 | RegionComputationFn computeFn = |
| 111 | static_cast<const ConcreteType *>(this)->getRegionComputeFn(linalgOp); |
| 112 | if (!computeFn) |
| 113 | return failure(); |
| 114 | |
| 115 | // All inputs should be constants. |
| 116 | int numInputs = linalgOp.getNumDpsInputs(); |
| 117 | SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs); |
| 118 | for (const auto &en : llvm::enumerate(First: linalgOp.getDpsInputOperands())) { |
| 119 | if (!matchPattern(value: en.value()->get(), |
| 120 | pattern: m_Constant(bind_value: &inputValues[en.index()]))) |
| 121 | return failure(); |
| 122 | } |
| 123 | |
| 124 | // Identified this as a potential candidate for folding. Now check the |
| 125 | // policy to see whether we are allowed to proceed. |
| 126 | for (OpOperand *operand : linalgOp.getDpsInputOperands()) { |
| 127 | if (!controlFn(operand)) |
| 128 | return failure(); |
| 129 | } |
| 130 | |
| 131 | SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges(); |
| 132 | int64_t numElements = outputType.getNumElements(); |
| 133 | |
| 134 | // Use APInt/APFloat instead of Attribute here for constructing the output. |
| 135 | // This helps to avoid blowing up compiler memory usage: Attributes would |
| 136 | // unify the following cases but they have lifetime as the MLIRContext. |
| 137 | SmallVector<APInt> intOutputValues; |
| 138 | SmallVector<APFloat> fpOutputValues; |
| 139 | if (isa<FloatType>(Val: elementType)) |
| 140 | fpOutputValues.resize(N: numElements, NV: APFloat(0.f)); |
| 141 | else |
| 142 | intOutputValues.resize(N: numElements); |
| 143 | |
| 144 | // Return the constant dim positions from the given permutation map. |
| 145 | auto getDimPositions = [](AffineMap map) { |
| 146 | SmallVector<unsigned> dims; |
| 147 | dims.reserve(N: map.getNumResults()); |
| 148 | for (AffineExpr result : map.getResults()) { |
| 149 | dims.push_back(Elt: cast<AffineDimExpr>(Val&: result).getPosition()); |
| 150 | } |
| 151 | return dims; |
| 152 | }; |
| 153 | |
| 154 | SmallVector<SmallVector<unsigned>> inputDims; |
| 155 | for (int i = 0; i < numInputs; ++i) |
| 156 | inputDims.push_back(getDimPositions(linalgOp.getIndexingMapsArray()[i])); |
| 157 | auto outputDims = getDimPositions(linalgOp.getIndexingMapsArray().back()); |
| 158 | auto outputShape = outputType.getShape(); |
| 159 | |
| 160 | // Allocate small vectors for index delinearization. Initial values do not |
| 161 | // matter here as they will be overwritten later. |
| 162 | SmallVector<uint64_t> indices(loopBounds.size(), 0); |
| 163 | SmallVector<uint64_t> dstIndices(loopBounds.size(), 0); |
| 164 | SmallVector<SmallVector<uint64_t>> srcIndices( |
| 165 | numInputs, SmallVector<uint64_t>(loopBounds.size(), 0)); |
| 166 | SmallVector<uint64_t> srcLinearIndices(numInputs, 0); |
| 167 | uint64_t dstLinearIndex = 0; |
| 168 | |
| 169 | // Allocate spaces for compute function inputs. Initial values do not matter |
| 170 | // here as they will be overwritten later. |
| 171 | APIntOrFloatArray computeFnInputs; |
| 172 | |
| 173 | auto inputShapes = llvm::to_vector<4>( |
| 174 | llvm::map_range(linalgOp.getDpsInputs(), [](Value value) { |
| 175 | return cast<ShapedType>(Val: value.getType()).getShape(); |
| 176 | })); |
| 177 | |
| 178 | // Given a `linearIndex`, remap it to a linear index to access linalg op |
| 179 | // inputs/ouputs. This mutates `indices`, `srcIndices`, `dstIndices`, |
| 180 | // `srcLinearIndices`, `dstLinearIndex` in place. |
| 181 | auto computeRemappedLinearIndex = [&](int linearIndex) { |
| 182 | int totalCount = linearIndex; |
| 183 | for (int dim = loopBounds.size() - 1; dim >= 0; --dim) { |
| 184 | indices[dim] = totalCount % loopBounds[dim]; |
| 185 | totalCount /= loopBounds[dim]; |
| 186 | } |
| 187 | |
| 188 | for (int dim = loopBounds.size() - 1; dim >= 0; --dim) { |
| 189 | for (int i = 0; i < numInputs; ++i) |
| 190 | srcIndices[i][dim] = indices[inputDims[i][dim]]; |
| 191 | dstIndices[dim] = indices[outputDims[dim]]; |
| 192 | } |
| 193 | |
| 194 | dstLinearIndex = dstIndices.front(); |
| 195 | for (int i = 0; i < numInputs; ++i) |
| 196 | srcLinearIndices[i] = srcIndices[i].front(); |
| 197 | |
| 198 | for (int dim = 1; dim < outputType.getRank(); ++dim) { |
| 199 | dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim]; |
| 200 | for (int i = 0; i < numInputs; ++i) |
| 201 | srcLinearIndices[i] = |
| 202 | srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim]; |
| 203 | } |
| 204 | }; |
| 205 | |
| 206 | bool isFloat = isa<FloatType>(Val: elementType); |
| 207 | if (isFloat) { |
| 208 | SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges; |
| 209 | for (int i = 0; i < numInputs; ++i) |
| 210 | inFpRanges.push_back(Elt: inputValues[i].getValues<APFloat>()); |
| 211 | |
| 212 | computeFnInputs.apFloats.resize(numInputs, APFloat(0.f)); |
| 213 | |
| 214 | // Transpose the input constant. Because we don't know its rank in |
| 215 | // advance, we need to loop over the range [0, element count) and |
| 216 | // delinearize the index. |
| 217 | for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) { |
| 218 | computeRemappedLinearIndex(linearIndex); |
| 219 | |
| 220 | // Collect constant elements for all inputs at this loop iteration. |
| 221 | for (int i = 0; i < numInputs; ++i) |
| 222 | computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]]; |
| 223 | |
| 224 | // Invoke the computation to get the corresponding constant output |
| 225 | // element. |
| 226 | fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat; |
| 227 | } |
| 228 | } else { |
| 229 | SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges; |
| 230 | for (int i = 0; i < numInputs; ++i) |
| 231 | inIntRanges.push_back(Elt: inputValues[i].getValues<APInt>()); |
| 232 | |
| 233 | computeFnInputs.apInts.resize(numInputs); |
| 234 | |
| 235 | // Transpose the input constant. Because we don't know its rank in |
| 236 | // advance, we need to loop over the range [0, element count) and |
| 237 | // delinearize the index. |
| 238 | for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) { |
| 239 | computeRemappedLinearIndex(linearIndex); |
| 240 | |
| 241 | // Collect constant elements for all inputs at this loop iteration. |
| 242 | for (int i = 0; i < numInputs; ++i) |
| 243 | computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]]; |
| 244 | |
| 245 | // Invoke the computation to get the corresponding constant output |
| 246 | // element. |
| 247 | intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt; |
| 248 | } |
| 249 | } |
| 250 | |
| 251 | DenseElementsAttr outputAttr = |
| 252 | isFloat ? DenseElementsAttr::get(type: outputType, values: fpOutputValues) |
| 253 | : DenseElementsAttr::get(type: outputType, values: intOutputValues); |
| 254 | |
| 255 | rewriter.replaceOpWithNewOp<arith::ConstantOp>(op: linalgOp, args&: outputAttr); |
| 256 | return success(); |
| 257 | } |
| 258 | |
| 259 | private: |
| 260 | ControlFusionFn controlFn; |
| 261 | }; |
| 262 | |
| 263 | // Folds linalg.transpose (and linalg.generic ops that are actually transposes) |
| 264 | // on constant values. |
| 265 | struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> { |
| 266 | |
| 267 | using FoldConstantBase::FoldConstantBase; |
| 268 | |
| 269 | bool matchIndexingMaps(LinalgOp linalgOp) const { |
| 270 | // We should have one input and one output. |
| 271 | return linalgOp.getIndexingMapsArray().size() == 2; |
| 272 | } |
| 273 | |
| 274 | RegionComputationFn getRegionComputeFn(LinalgOp linalgOp) const { |
| 275 | // Make sure the region only contains a yield op. |
| 276 | Block &body = linalgOp->getRegion(index: 0).front(); |
| 277 | if (!llvm::hasSingleElement(C&: body)) |
| 278 | return nullptr; |
| 279 | auto yieldOp = dyn_cast<linalg::YieldOp>(Val: body.getTerminator()); |
| 280 | if (!yieldOp) |
| 281 | return nullptr; |
| 282 | |
| 283 | // The yield op should return the block argument corresponds to the input. |
| 284 | for (Value yieldVal : yieldOp.getValues()) { |
| 285 | auto yieldArg = dyn_cast<BlockArgument>(Val&: yieldVal); |
| 286 | if (!yieldArg || yieldArg.getOwner() != &body) |
| 287 | return nullptr; |
| 288 | if (yieldArg.getArgNumber() != 0) |
| 289 | return nullptr; |
| 290 | } |
| 291 | |
| 292 | // No computation; just return the orginal value. |
| 293 | return [](const APIntOrFloatArray &inputs) { |
| 294 | if (inputs.apFloats.empty()) |
| 295 | return APIntOrFloat{.apInt: inputs.apInts.front(), .apFloat: std::nullopt}; |
| 296 | return APIntOrFloat{.apInt: std::nullopt, .apFloat: inputs.apFloats.front()}; |
| 297 | }; |
| 298 | } |
| 299 | |
| 300 | ControlFusionFn controlFn; |
| 301 | }; |
| 302 | } // namespace |
| 303 | |
| 304 | void mlir::linalg::populateConstantFoldLinalgOperations( |
| 305 | RewritePatternSet &patterns, const ControlFusionFn &controlFn) { |
| 306 | MLIRContext *context = patterns.getContext(); |
| 307 | patterns.insert<FoldConstantTranspose>(arg&: context, args: controlFn); |
| 308 | } |
| 309 | |