//===----- FlattenMemRefs.cpp - MemRef ops flattener pass ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns for flattening multi-rank memref ops into
// 1-d memref ops.
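//
// For example, a rank-2 load such as
//
//   %value = memref.load %mem[%i, %j] : memref<4x8xf32>
//
// is rewritten (sketched here; the actual output also recovers the offset,
// sizes, and strides via memref.extract_strided_metadata) into a load from a
// rank-1 view with a linearized index, roughly:
//
//   %flat = memref.reinterpret_cast %mem to offset: [0], sizes: [32],
//       strides: [1] : memref<4x8xf32> to memref<32xf32>
//   %idx = ...  // %i * 8 + %j
//   %value = memref.load %flat[%idx] : memref<32xf32>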
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

#include <numeric>

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FLATTENMEMREFSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;

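/// Materializes an OpFoldResult as a Value: attributes are turned into
/// arith.constant index ops, and values are returned unchanged.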
static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
                                      OpFoldResult in) {
  if (Attribute offsetAttr = dyn_cast<Attribute>(in)) {
    return rewriter.create<arith::ConstantIndexOp>(
        loc, cast<IntegerAttr>(offsetAttr).getInt());
  }
  return cast<Value>(in);
}

/// Returns a collapsed memref and the linearized index to access the element
/// at the specified indices.
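///
/// For example (a sketch, assuming a contiguous rank-2 source of shape MxN):
/// the element at indices (%i, %j) maps to linear index %i * N + %j in the
/// collapsed rank-1 memref.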
static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
                                                         Location loc,
                                                         Value source,
                                                         ValueRange indices) {
  int64_t sourceOffset;
  SmallVector<int64_t, 4> sourceStrides;
  auto sourceType = cast<MemRefType>(source.getType());
  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset))) {
    assert(false);
  }

  memref::ExtractStridedMetadataOp stridedMetadata =
      rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);

  auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
  OpFoldResult linearizedIndices;
  memref::LinearizedMemRefInfo linearizedInfo;
  std::tie(linearizedInfo, linearizedIndices) =
      memref::getLinearizedMemRefOffsetAndSize(
          rewriter, loc, typeBit, typeBit,
          stridedMetadata.getConstifiedMixedOffset(),
          stridedMetadata.getConstifiedMixedSizes(),
          stridedMetadata.getConstifiedMixedStrides(),
          getAsOpFoldResult(indices));

  return std::make_pair(
      rewriter.create<memref::ReinterpretCastOp>(
          loc, source,
          /* offset = */ linearizedInfo.linearizedOffset,
          /* shapes = */
          ArrayRef<OpFoldResult>{linearizedInfo.linearizedSize},
          /* strides = */
          ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)}),
      getValueFromOpFoldResult(rewriter, loc, linearizedIndices));
}

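/// Returns true if the memref type has rank greater than 1 and therefore
/// needs to be flattened.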
static bool needFlattening(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getRank() > 1;
}

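/// Returns true if the memref's layout is one the flattening patterns
/// understand: the identity layout or an explicit strided layout.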
static bool checkLayout(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getLayout().isIdentity() ||
         isa<StridedLayoutAttr>(type.getLayout());
}

namespace {
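/// Returns the memref operand of one of the supported memref or vector ops,
/// or a null Value for unsupported ops.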
static Value getTargetMemref(Operation *op) {
  return llvm::TypeSwitch<Operation *, Value>(op)
      .template Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp,
                     memref::AllocOp>([](auto op) { return op.getMemref(); })
      .template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
                     vector::MaskedStoreOp, vector::TransferReadOp,
                     vector::TransferWriteOp>(
          [](auto op) { return op.getBase(); })
      .Default([](auto) { return Value{}; });
}

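/// Replaces the original alloc/alloca `oper` with a memref.reinterpret_cast
/// that views the flattened allocation `newOper` under the original
/// multi-rank type, so existing users of the result stay valid.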
template <typename T>
static void castAllocResult(T oper, T newOper, Location loc,
                            PatternRewriter &rewriter) {
  memref::ExtractStridedMetadataOp stridedMetadata =
      rewriter.create<memref::ExtractStridedMetadataOp>(loc, oper);
  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
      oper, cast<MemRefType>(oper.getType()), newOper,
      /*offset=*/rewriter.getIndexAttr(0),
      stridedMetadata.getConstifiedMixedSizes(),
      stridedMetadata.getConstifiedMixedStrides());
}

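/// Rewrites `op` into the equivalent op on the flattened memref, addressing
/// it with the single linearized `offset` index.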
template <typename T>
static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
                      Value offset) {
  Location loc = op->getLoc();
  llvm::TypeSwitch<Operation *>(op.getOperation())
      .template Case<memref::AllocOp>([&](auto oper) {
        auto newAlloc = rewriter.create<memref::AllocOp>(
            loc, cast<MemRefType>(flatMemref.getType()),
            oper.getAlignmentAttr());
        castAllocResult(oper, newAlloc, loc, rewriter);
      })
      .template Case<memref::AllocaOp>([&](auto oper) {
        auto newAlloca = rewriter.create<memref::AllocaOp>(
            loc, cast<MemRefType>(flatMemref.getType()),
            oper.getAlignmentAttr());
        castAllocResult(oper, newAlloca, loc, rewriter);
      })
      .template Case<memref::LoadOp>([&](auto op) {
        auto newLoad = rewriter.create<memref::LoadOp>(
            loc, op->getResultTypes(), flatMemref, ValueRange{offset});
        newLoad->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newLoad.getResult());
      })
      .template Case<memref::StoreOp>([&](auto op) {
        auto newStore = rewriter.create<memref::StoreOp>(
            loc, op->getOperands().front(), flatMemref, ValueRange{offset});
        newStore->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newStore);
      })
      .template Case<vector::LoadOp>([&](auto op) {
        auto newLoad = rewriter.create<vector::LoadOp>(
            loc, op->getResultTypes(), flatMemref, ValueRange{offset});
        newLoad->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newLoad.getResult());
      })
      .template Case<vector::StoreOp>([&](auto op) {
        auto newStore = rewriter.create<vector::StoreOp>(
            loc, op->getOperands().front(), flatMemref, ValueRange{offset});
        newStore->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newStore);
      })
      .template Case<vector::MaskedLoadOp>([&](auto op) {
        auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>(
            loc, op.getType(), flatMemref, ValueRange{offset}, op.getMask(),
            op.getPassThru());
        newMaskedLoad->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newMaskedLoad.getResult());
      })
      .template Case<vector::MaskedStoreOp>([&](auto op) {
        auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>(
            loc, flatMemref, ValueRange{offset}, op.getMask(),
            op.getValueToStore());
        newMaskedStore->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newMaskedStore);
      })
      .template Case<vector::TransferReadOp>([&](auto op) {
        auto newTransferRead = rewriter.create<vector::TransferReadOp>(
            loc, op.getType(), flatMemref, ValueRange{offset}, op.getPadding());
        rewriter.replaceOp(op, newTransferRead.getResult());
      })
      .template Case<vector::TransferWriteOp>([&](auto op) {
        auto newTransferWrite = rewriter.create<vector::TransferWriteOp>(
            loc, op.getVector(), flatMemref, ValueRange{offset});
        rewriter.replaceOp(op, newTransferWrite);
      })
      .Default([&](auto op) {
        op->emitOpError("unimplemented: do not know how to replace op.");
      });
}

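/// Returns the access indices of `op`; allocation ops carry no indices, so an
/// empty range is returned for them.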
template <typename T>
static ValueRange getIndices(T op) {
  if constexpr (std::is_same_v<T, memref::AllocaOp> ||
                std::is_same_v<T, memref::AllocOp>) {
    return ValueRange{};
  } else {
    return op.getIndices();
  }
}

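/// Returns success if `op` satisfies the op-specific preconditions for
/// flattening. Only vector.transfer_read/write impose extra constraints; all
/// other supported ops trivially qualify.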
template <typename T>
static LogicalResult canBeFlattened(T op, PatternRewriter &rewriter) {
  return llvm::TypeSwitch<Operation *, LogicalResult>(op.getOperation())
      .template Case<vector::TransferReadOp, vector::TransferWriteOp>(
          [&](auto oper) {
            // For vector.transfer_read/write, we must make sure that:
            // 1. all accesses are in-bounds, and
            // 2. the permutation map is an identity or minor identity.
            auto permutationMap = oper.getPermutationMap();
            if (!permutationMap.isIdentity() &&
                !permutationMap.isMinorIdentity()) {
              return rewriter.notifyMatchFailure(
                  oper, "only identity or minor identity permutation maps "
                        "are supported");
            }
            mlir::ArrayAttr inbounds = oper.getInBounds();
            if (llvm::any_of(inbounds, [](Attribute attr) {
                  return !cast<BoolAttr>(attr).getValue();
                })) {
              return rewriter.notifyMatchFailure(
                  oper, "only in-bounds accesses are supported");
            }
            return success();
          })
      .Default([&](auto op) { return success(); });
}

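/// Flattens the memref operand of an op of type `T` and rewrites the op to
/// operate on the resulting rank-1 view.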
template <typename T>
struct MemRefRewritePattern : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;
  LogicalResult matchAndRewrite(T op,
                                PatternRewriter &rewriter) const override {
    LogicalResult canFlatten = canBeFlattened(op, rewriter);
    if (failed(canFlatten)) {
      return canFlatten;
    }

    Value memref = getTargetMemref(op);
    if (!needFlattening(memref) || !checkLayout(memref))
      return failure();
    auto &&[flatMemref, offset] = getFlattenMemrefAndOffset(
        rewriter, op->getLoc(), memref, getIndices<T>(op));
    replaceOp<T>(op, rewriter, flatMemref, offset);
    return success();
  }
};

struct FlattenMemrefsPass
    : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
  using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect, arith::ArithDialect,
                    memref::MemRefDialect, vector::VectorDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    memref::populateFlattenMemrefsPatterns(patterns);

    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
  patterns.insert<MemRefRewritePattern<memref::LoadOp>,
                  MemRefRewritePattern<memref::StoreOp>,
                  MemRefRewritePattern<memref::AllocOp>,
                  MemRefRewritePattern<memref::AllocaOp>,
                  MemRefRewritePattern<vector::LoadOp>,
                  MemRefRewritePattern<vector::StoreOp>,
                  MemRefRewritePattern<vector::TransferReadOp>,
                  MemRefRewritePattern<vector::TransferWriteOp>,
                  MemRefRewritePattern<vector::MaskedLoadOp>,
                  MemRefRewritePattern<vector::MaskedStoreOp>>(
      patterns.getContext());
}