| 1 | //===----- FlattenMemRefs.cpp - MemRef ops flattener pass ----------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
// This file contains patterns for flattening multi-rank memref-related
| 10 | // ops into 1-d memref ops. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 16 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 17 | #include "mlir/Dialect/MemRef/Transforms/Passes.h" |
| 18 | #include "mlir/Dialect/MemRef/Transforms/Transforms.h" |
| 19 | #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" |
| 20 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
| 21 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
| 22 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 23 | #include "mlir/IR/Attributes.h" |
| 24 | #include "mlir/IR/Builders.h" |
| 25 | #include "mlir/IR/BuiltinTypes.h" |
| 26 | #include "mlir/IR/OpDefinition.h" |
| 27 | #include "mlir/IR/PatternMatch.h" |
| 28 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 29 | #include "llvm/ADT/TypeSwitch.h" |
| 30 | |
| 31 | namespace mlir { |
| 32 | namespace memref { |
| 33 | #define GEN_PASS_DEF_FLATTENMEMREFSPASS |
| 34 | #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" |
| 35 | } // namespace memref |
| 36 | } // namespace mlir |
| 37 | |
| 38 | using namespace mlir; |
| 39 | |
| 40 | static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc, |
| 41 | OpFoldResult in) { |
| 42 | if (Attribute offsetAttr = dyn_cast<Attribute>(Val&: in)) { |
| 43 | return rewriter.create<arith::ConstantIndexOp>( |
| 44 | location: loc, args: cast<IntegerAttr>(Val&: offsetAttr).getInt()); |
| 45 | } |
| 46 | return cast<Value>(Val&: in); |
| 47 | } |
| 48 | |
| 49 | /// Returns a collapsed memref and the linearized index to access the element |
| 50 | /// at the specified indices. |
| 51 | static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter, |
| 52 | Location loc, |
| 53 | Value source, |
| 54 | ValueRange indices) { |
| 55 | int64_t sourceOffset; |
| 56 | SmallVector<int64_t, 4> sourceStrides; |
| 57 | auto sourceType = cast<MemRefType>(Val: source.getType()); |
| 58 | if (failed(Result: sourceType.getStridesAndOffset(strides&: sourceStrides, offset&: sourceOffset))) { |
| 59 | assert(false); |
| 60 | } |
| 61 | |
| 62 | memref::ExtractStridedMetadataOp stridedMetadata = |
| 63 | rewriter.create<memref::ExtractStridedMetadataOp>(location: loc, args&: source); |
| 64 | |
| 65 | auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth(); |
| 66 | OpFoldResult linearizedIndices; |
| 67 | memref::LinearizedMemRefInfo linearizedInfo; |
| 68 | std::tie(args&: linearizedInfo, args&: linearizedIndices) = |
| 69 | memref::getLinearizedMemRefOffsetAndSize( |
| 70 | builder&: rewriter, loc, srcBits: typeBit, dstBits: typeBit, |
| 71 | offset: stridedMetadata.getConstifiedMixedOffset(), |
| 72 | sizes: stridedMetadata.getConstifiedMixedSizes(), |
| 73 | strides: stridedMetadata.getConstifiedMixedStrides(), |
| 74 | indices: getAsOpFoldResult(values: indices)); |
| 75 | |
| 76 | return std::make_pair( |
| 77 | x: rewriter.create<memref::ReinterpretCastOp>( |
| 78 | location: loc, args&: source, |
| 79 | /* offset = */ args&: linearizedInfo.linearizedOffset, |
| 80 | /* shapes = */ |
| 81 | args: ArrayRef<OpFoldResult>{linearizedInfo.linearizedSize}, |
| 82 | /* strides = */ |
| 83 | args: ArrayRef<OpFoldResult>{rewriter.getIndexAttr(value: 1)}), |
| 84 | y: getValueFromOpFoldResult(rewriter, loc, in: linearizedIndices)); |
| 85 | } |
| 86 | |
| 87 | static bool needFlattening(Value val) { |
| 88 | auto type = cast<MemRefType>(Val: val.getType()); |
| 89 | return type.getRank() > 1; |
| 90 | } |
| 91 | |
| 92 | static bool checkLayout(Value val) { |
| 93 | auto type = cast<MemRefType>(Val: val.getType()); |
| 94 | return type.getLayout().isIdentity() || |
| 95 | isa<StridedLayoutAttr>(Val: type.getLayout()); |
| 96 | } |
| 97 | |
| 98 | namespace { |
| 99 | static Value getTargetMemref(Operation *op) { |
| 100 | return llvm::TypeSwitch<Operation *, Value>(op) |
| 101 | .template Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp, |
| 102 | memref::AllocOp>(caseFn: [](auto op) { return op.getMemref(); }) |
| 103 | .template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp, |
| 104 | vector::MaskedStoreOp, vector::TransferReadOp, |
| 105 | vector::TransferWriteOp>( |
| 106 | caseFn: [](auto op) { return op.getBase(); }) |
| 107 | .Default(defaultFn: [](auto) { return Value{}; }); |
| 108 | } |
| 109 | |
| 110 | template <typename T> |
| 111 | static void castAllocResult(T oper, T newOper, Location loc, |
| 112 | PatternRewriter &rewriter) { |
| 113 | memref::ExtractStridedMetadataOp stridedMetadata = |
| 114 | rewriter.create<memref::ExtractStridedMetadataOp>(loc, oper); |
| 115 | rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>( |
| 116 | oper, cast<MemRefType>(oper.getType()), newOper, |
| 117 | /*offset=*/rewriter.getIndexAttr(value: 0), |
| 118 | stridedMetadata.getConstifiedMixedSizes(), |
| 119 | stridedMetadata.getConstifiedMixedStrides()); |
| 120 | } |
| 121 | |
| 122 | template <typename T> |
| 123 | static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref, |
| 124 | Value offset) { |
| 125 | Location loc = op->getLoc(); |
| 126 | llvm::TypeSwitch<Operation *>(op.getOperation()) |
| 127 | .template Case<memref::AllocOp>([&](auto oper) { |
| 128 | auto newAlloc = rewriter.create<memref::AllocOp>( |
| 129 | loc, cast<MemRefType>(Val: flatMemref.getType()), |
| 130 | oper.getAlignmentAttr()); |
| 131 | castAllocResult(oper, newAlloc, loc, rewriter); |
| 132 | }) |
| 133 | .template Case<memref::AllocaOp>([&](auto oper) { |
| 134 | auto newAlloca = rewriter.create<memref::AllocaOp>( |
| 135 | loc, cast<MemRefType>(Val: flatMemref.getType()), |
| 136 | oper.getAlignmentAttr()); |
| 137 | castAllocResult(oper, newAlloca, loc, rewriter); |
| 138 | }) |
| 139 | .template Case<memref::LoadOp>([&](auto op) { |
| 140 | auto newLoad = rewriter.create<memref::LoadOp>( |
| 141 | loc, op->getResultTypes(), flatMemref, ValueRange{offset}); |
| 142 | newLoad->setAttrs(op->getAttrs()); |
| 143 | rewriter.replaceOp(op, newLoad.getResult()); |
| 144 | }) |
| 145 | .template Case<memref::StoreOp>([&](auto op) { |
| 146 | auto newStore = rewriter.create<memref::StoreOp>( |
| 147 | loc, op->getOperands().front(), flatMemref, ValueRange{offset}); |
| 148 | newStore->setAttrs(op->getAttrs()); |
| 149 | rewriter.replaceOp(op, newStore); |
| 150 | }) |
| 151 | .template Case<vector::LoadOp>([&](auto op) { |
| 152 | auto newLoad = rewriter.create<vector::LoadOp>( |
| 153 | loc, op->getResultTypes(), flatMemref, ValueRange{offset}); |
| 154 | newLoad->setAttrs(op->getAttrs()); |
| 155 | rewriter.replaceOp(op, newLoad.getResult()); |
| 156 | }) |
| 157 | .template Case<vector::StoreOp>([&](auto op) { |
| 158 | auto newStore = rewriter.create<vector::StoreOp>( |
| 159 | loc, op->getOperands().front(), flatMemref, ValueRange{offset}); |
| 160 | newStore->setAttrs(op->getAttrs()); |
| 161 | rewriter.replaceOp(op, newStore); |
| 162 | }) |
| 163 | .template Case<vector::MaskedLoadOp>([&](auto op) { |
| 164 | auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>( |
| 165 | loc, op.getType(), flatMemref, ValueRange{offset}, op.getMask(), |
| 166 | op.getPassThru()); |
| 167 | newMaskedLoad->setAttrs(op->getAttrs()); |
| 168 | rewriter.replaceOp(op, newMaskedLoad.getResult()); |
| 169 | }) |
| 170 | .template Case<vector::MaskedStoreOp>([&](auto op) { |
| 171 | auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>( |
| 172 | loc, flatMemref, ValueRange{offset}, op.getMask(), |
| 173 | op.getValueToStore()); |
| 174 | newMaskedStore->setAttrs(op->getAttrs()); |
| 175 | rewriter.replaceOp(op, newMaskedStore); |
| 176 | }) |
| 177 | .template Case<vector::TransferReadOp>([&](auto op) { |
| 178 | auto newTransferRead = rewriter.create<vector::TransferReadOp>( |
| 179 | loc, op.getType(), flatMemref, ValueRange{offset}, op.getPadding()); |
| 180 | rewriter.replaceOp(op, newTransferRead.getResult()); |
| 181 | }) |
| 182 | .template Case<vector::TransferWriteOp>([&](auto op) { |
| 183 | auto newTransferWrite = rewriter.create<vector::TransferWriteOp>( |
| 184 | loc, op.getVector(), flatMemref, ValueRange{offset}); |
| 185 | rewriter.replaceOp(op, newTransferWrite); |
| 186 | }) |
| 187 | .Default([&](auto op) { |
| 188 | op->emitOpError("unimplemented: do not know how to replace op." ); |
| 189 | }); |
| 190 | } |
| 191 | |
| 192 | template <typename T> |
| 193 | static ValueRange getIndices(T op) { |
| 194 | if constexpr (std::is_same_v<T, memref::AllocaOp> || |
| 195 | std::is_same_v<T, memref::AllocOp>) { |
| 196 | return ValueRange{}; |
| 197 | } else { |
| 198 | return op.getIndices(); |
| 199 | } |
| 200 | } |
| 201 | |
| 202 | template <typename T> |
| 203 | static LogicalResult canBeFlattened(T op, PatternRewriter &rewriter) { |
| 204 | return llvm::TypeSwitch<Operation *, LogicalResult>(op.getOperation()) |
| 205 | .template Case<vector::TransferReadOp, vector::TransferWriteOp>( |
| 206 | [&](auto oper) { |
| 207 | // For vector.transfer_read/write, must make sure: |
| 208 | // 1. all accesses are inbound, and |
| 209 | // 2. has an identity or minor identity permutation map. |
| 210 | auto permutationMap = oper.getPermutationMap(); |
| 211 | if (!permutationMap.isIdentity() && |
| 212 | !permutationMap.isMinorIdentity()) { |
| 213 | return rewriter.notifyMatchFailure( |
| 214 | oper, "only identity permutation map is supported" ); |
| 215 | } |
| 216 | mlir::ArrayAttr inbounds = oper.getInBounds(); |
| 217 | if (llvm::any_of(inbounds, [](Attribute attr) { |
| 218 | return !cast<BoolAttr>(Val&: attr).getValue(); |
| 219 | })) { |
| 220 | return rewriter.notifyMatchFailure(oper, |
| 221 | "only inbounds are supported" ); |
| 222 | } |
| 223 | return success(); |
| 224 | }) |
| 225 | .Default([&](auto op) { return success(); }); |
| 226 | } |
| 227 | |
| 228 | template <typename T> |
| 229 | struct MemRefRewritePattern : public OpRewritePattern<T> { |
| 230 | using OpRewritePattern<T>::OpRewritePattern; |
| 231 | LogicalResult matchAndRewrite(T op, |
| 232 | PatternRewriter &rewriter) const override { |
| 233 | LogicalResult canFlatten = canBeFlattened(op, rewriter); |
| 234 | if (failed(Result: canFlatten)) { |
| 235 | return canFlatten; |
| 236 | } |
| 237 | |
| 238 | Value memref = getTargetMemref(op); |
| 239 | if (!needFlattening(val: memref) || !checkLayout(val: memref)) |
| 240 | return failure(); |
| 241 | auto &&[flatMemref, offset] = getFlattenMemrefAndOffset( |
| 242 | rewriter, op->getLoc(), memref, getIndices<T>(op)); |
| 243 | replaceOp<T>(op, rewriter, flatMemref, offset); |
| 244 | return success(); |
| 245 | } |
| 246 | }; |
| 247 | |
| 248 | struct FlattenMemrefsPass |
| 249 | : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> { |
| 250 | using Base::Base; |
| 251 | |
| 252 | void getDependentDialects(DialectRegistry ®istry) const override { |
| 253 | registry.insert<affine::AffineDialect, arith::ArithDialect, |
| 254 | memref::MemRefDialect, vector::VectorDialect>(); |
| 255 | } |
| 256 | |
| 257 | void runOnOperation() override { |
| 258 | RewritePatternSet patterns(&getContext()); |
| 259 | |
| 260 | memref::populateFlattenMemrefsPatterns(patterns); |
| 261 | |
| 262 | if (failed(Result: applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns)))) |
| 263 | return signalPassFailure(); |
| 264 | } |
| 265 | }; |
| 266 | |
| 267 | } // namespace |
| 268 | |
| 269 | void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) { |
| 270 | patterns.insert<MemRefRewritePattern<memref::LoadOp>, |
| 271 | MemRefRewritePattern<memref::StoreOp>, |
| 272 | MemRefRewritePattern<memref::AllocOp>, |
| 273 | MemRefRewritePattern<memref::AllocaOp>, |
| 274 | MemRefRewritePattern<vector::LoadOp>, |
| 275 | MemRefRewritePattern<vector::StoreOp>, |
| 276 | MemRefRewritePattern<vector::TransferReadOp>, |
| 277 | MemRefRewritePattern<vector::TransferWriteOp>, |
| 278 | MemRefRewritePattern<vector::MaskedLoadOp>, |
| 279 | MemRefRewritePattern<vector::MaskedStoreOp>>( |
| 280 | arg: patterns.getContext()); |
| 281 | } |
| 282 | |