//===- EmulateNarrowType.cpp - Narrow type emulation ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9
10#include "mlir/Dialect/Affine/IR/AffineOps.h"
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
13#include "mlir/Dialect/Arith/Transforms/Passes.h"
14#include "mlir/Dialect/Arith/Utils/Utils.h"
15#include "mlir/Dialect/MemRef/IR/MemRef.h"
16#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
17#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
18#include "mlir/Dialect/Vector/IR/VectorOps.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/BuiltinTypes.h"
21#include "mlir/IR/OpDefinition.h"
22#include "mlir/Transforms/DialectConversion.h"
23#include "llvm/Support/FormatVariadic.h"
24#include "llvm/Support/MathExtras.h"
25#include <cassert>
26#include <type_traits>
27
28using namespace mlir;
29
30//===----------------------------------------------------------------------===//
31// Utility functions
32//===----------------------------------------------------------------------===//
33
34/// Converts a memref::ReinterpretCastOp to the converted type. The result
35/// MemRefType of the old op must have a rank and stride of 1, with static
36/// offset and size. The number of bits in the offset must evenly divide the
37/// bitwidth of the new converted type.
38static LogicalResult
39convertCastingOp(ConversionPatternRewriter &rewriter,
40 memref::ReinterpretCastOp::Adaptor adaptor,
41 memref::ReinterpretCastOp op, MemRefType newTy) {
42 auto convertedElementType = newTy.getElementType();
43 auto oldElementType = op.getType().getElementType();
44 int srcBits = oldElementType.getIntOrFloatBitWidth();
45 int dstBits = convertedElementType.getIntOrFloatBitWidth();
46 if (dstBits % srcBits != 0) {
47 return rewriter.notifyMatchFailure(arg&: op,
48 msg: "only dstBits % srcBits == 0 supported");
49 }
50
51 // Only support stride of 1.
52 if (llvm::any_of(Range: op.getStaticStrides(),
53 P: [](int64_t stride) { return stride != 1; })) {
54 return rewriter.notifyMatchFailure(arg: op->getLoc(),
55 msg: "stride != 1 is not supported");
56 }
57
58 auto sizes = op.getStaticSizes();
59 int64_t offset = op.getStaticOffset(idx: 0);
60 // Only support static sizes and offsets.
61 if (llvm::is_contained(Range&: sizes, Element: ShapedType::kDynamic) ||
62 offset == ShapedType::kDynamic) {
63 return rewriter.notifyMatchFailure(
64 arg&: op, msg: "dynamic size or offset is not supported");
65 }
66
67 int elementsPerByte = dstBits / srcBits;
68 if (offset % elementsPerByte != 0) {
69 return rewriter.notifyMatchFailure(
70 arg&: op, msg: "offset not multiple of elementsPerByte is not supported");
71 }
72
73 SmallVector<int64_t> size;
74 if (sizes.size())
75 size.push_back(Elt: llvm::divideCeilSigned(Numerator: sizes[0], Denominator: elementsPerByte));
76 offset = offset / elementsPerByte;
77
78 rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
79 op, args&: newTy, args: adaptor.getSource(), args&: offset, args&: size, args: op.getStaticStrides());
80 return success();
81}
82
83/// When data is loaded/stored in `targetBits` granularity, but is used in
84/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
85/// treated as an array of elements of width `sourceBits`.
86/// Return the bit offset of the value at position `srcIdx`. For example, if
87/// `sourceBits` equals to 4 and `targetBits` equals to 8, the x-th element is
88/// located at (x % 2) * 4. Because there are two elements in one i8, and one
89/// element has 4 bits.
90static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
91 int sourceBits, int targetBits,
92 OpBuilder &builder) {
93 assert(targetBits % sourceBits == 0);
94 AffineExpr s0;
95 bindSymbols(ctx: builder.getContext(), exprs&: s0);
96 int scaleFactor = targetBits / sourceBits;
97 AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
98 OpFoldResult offsetVal =
99 affine::makeComposedFoldedAffineApply(b&: builder, loc, expr: offsetExpr, operands: {srcIdx});
100 Value bitOffset = getValueOrCreateConstantIndexOp(b&: builder, loc, ofr: offsetVal);
101 IntegerType dstType = builder.getIntegerType(width: targetBits);
102 return builder.create<arith::IndexCastOp>(location: loc, args&: dstType, args&: bitOffset);
103}
104
105/// When writing a subbyte size, masked bitwise operations are used to only
106/// modify the relevant bits. This function returns an and mask for clearing
107/// the destination bits in a subbyte write. E.g., when writing to the second
108/// i4 in an i32, 0xFFFFFF0F is created.
109static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
110 int64_t srcBits, int64_t dstBits,
111 Value bitwidthOffset, OpBuilder &builder) {
112 auto dstIntegerType = builder.getIntegerType(width: dstBits);
113 auto maskRightAlignedAttr =
114 builder.getIntegerAttr(type: dstIntegerType, value: (1 << srcBits) - 1);
115 Value maskRightAligned = builder.create<arith::ConstantOp>(
116 location: loc, args&: dstIntegerType, args&: maskRightAlignedAttr);
117 Value writeMaskInverse =
118 builder.create<arith::ShLIOp>(location: loc, args&: maskRightAligned, args&: bitwidthOffset);
119 auto flipValAttr = builder.getIntegerAttr(type: dstIntegerType, value: -1);
120 Value flipVal =
121 builder.create<arith::ConstantOp>(location: loc, args&: dstIntegerType, args&: flipValAttr);
122 return builder.create<arith::XOrIOp>(location: loc, args&: writeMaskInverse, args&: flipVal);
123}
124
125/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
126/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
127/// the returned index has the granularity of `dstBits`
128static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
129 OpFoldResult linearizedIndex,
130 int64_t srcBits, int64_t dstBits) {
131 AffineExpr s0;
132 bindSymbols(ctx: builder.getContext(), exprs&: s0);
133 int64_t scaler = dstBits / srcBits;
134 OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
135 b&: builder, loc, expr: s0.floorDiv(v: scaler), operands: {linearizedIndex});
136 return getValueOrCreateConstantIndexOp(b&: builder, loc, ofr: scaledLinearizedIndices);
137}
138
139static OpFoldResult
140getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
141 const SmallVector<OpFoldResult> &indices,
142 Value memref) {
143 auto stridedMetadata =
144 builder.create<memref::ExtractStridedMetadataOp>(location: loc, args&: memref);
145 OpFoldResult linearizedIndices;
146 std::tie(args: std::ignore, args&: linearizedIndices) =
147 memref::getLinearizedMemRefOffsetAndSize(
148 builder, loc, srcBits, dstBits: srcBits,
149 offset: stridedMetadata.getConstifiedMixedOffset(),
150 sizes: stridedMetadata.getConstifiedMixedSizes(),
151 strides: stridedMetadata.getConstifiedMixedStrides(), indices);
152 return linearizedIndices;
153}
154
155namespace {
156
157//===----------------------------------------------------------------------===//
158// ConvertMemRefAllocation
159//===----------------------------------------------------------------------===//
160
161template <typename OpTy>
162struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
163 using OpConversionPattern<OpTy>::OpConversionPattern;
164
165 LogicalResult
166 matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
167 ConversionPatternRewriter &rewriter) const override {
168 static_assert(std::is_same<OpTy, memref::AllocOp>() ||
169 std::is_same<OpTy, memref::AllocaOp>(),
170 "expected only memref::AllocOp or memref::AllocaOp");
171 auto currentType = cast<MemRefType>(op.getMemref().getType());
172 auto newResultType =
173 this->getTypeConverter()->template convertType<MemRefType>(
174 op.getType());
175 if (!newResultType) {
176 return rewriter.notifyMatchFailure(
177 op->getLoc(),
178 llvm::formatv("failed to convert memref type: {0}", op.getType()));
179 }
180
181 // Special case zero-rank memrefs.
182 if (currentType.getRank() == 0) {
183 rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
184 adaptor.getSymbolOperands(),
185 adaptor.getAlignmentAttr());
186 return success();
187 }
188
189 Location loc = op.getLoc();
190 OpFoldResult zero = rewriter.getIndexAttr(value: 0);
191
192 // Get linearized type.
193 int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
194 int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
195 SmallVector<OpFoldResult> sizes = op.getMixedSizes();
196
197 memref::LinearizedMemRefInfo linearizedMemRefInfo =
198 memref::getLinearizedMemRefOffsetAndSize(
199 builder&: rewriter, loc, srcBits, dstBits, /*offset =*/zero, sizes);
200 SmallVector<Value> dynamicLinearizedSize;
201 if (!newResultType.hasStaticShape()) {
202 dynamicLinearizedSize.push_back(Elt: getValueOrCreateConstantIndexOp(
203 b&: rewriter, loc, ofr: linearizedMemRefInfo.linearizedSize));
204 }
205
206 rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
207 adaptor.getSymbolOperands(),
208 adaptor.getAlignmentAttr());
209 return success();
210 }
211};
212
213//===----------------------------------------------------------------------===//
214// ConvertMemRefAssumeAlignment
215//===----------------------------------------------------------------------===//
216
217struct ConvertMemRefAssumeAlignment final
218 : OpConversionPattern<memref::AssumeAlignmentOp> {
219 using OpConversionPattern::OpConversionPattern;
220
221 LogicalResult
222 matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
223 ConversionPatternRewriter &rewriter) const override {
224 Type newTy = getTypeConverter()->convertType(t: op.getMemref().getType());
225 if (!newTy) {
226 return rewriter.notifyMatchFailure(
227 arg: op->getLoc(), msg: llvm::formatv(Fmt: "failed to convert memref type: {0}",
228 Vals: op.getMemref().getType()));
229 }
230
231 rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
232 op, args&: newTy, args: adaptor.getMemref(), args: adaptor.getAlignmentAttr());
233 return success();
234 }
235};
236
237//===----------------------------------------------------------------------===//
238// ConvertMemRefCopy
239//===----------------------------------------------------------------------===//
240
241struct ConvertMemRefCopy final : OpConversionPattern<memref::CopyOp> {
242 using OpConversionPattern::OpConversionPattern;
243
244 LogicalResult
245 matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
246 ConversionPatternRewriter &rewriter) const override {
247 auto maybeRankedSource = dyn_cast<MemRefType>(Val: op.getSource().getType());
248 auto maybeRankedDest = dyn_cast<MemRefType>(Val: op.getTarget().getType());
249 if (maybeRankedSource && maybeRankedDest &&
250 maybeRankedSource.getLayout() != maybeRankedDest.getLayout())
251 return rewriter.notifyMatchFailure(
252 arg&: op, msg: llvm::formatv(Fmt: "memref.copy emulation with distinct layouts ({0} "
253 "and {1}) is currently unimplemented",
254 Vals: maybeRankedSource.getLayout(),
255 Vals: maybeRankedDest.getLayout()));
256 rewriter.replaceOpWithNewOp<memref::CopyOp>(op, args: adaptor.getSource(),
257 args: adaptor.getTarget());
258 return success();
259 }
260};
261
262//===----------------------------------------------------------------------===//
263// ConvertMemRefDealloc
264//===----------------------------------------------------------------------===//
265
266struct ConvertMemRefDealloc final : OpConversionPattern<memref::DeallocOp> {
267 using OpConversionPattern::OpConversionPattern;
268
269 LogicalResult
270 matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
271 ConversionPatternRewriter &rewriter) const override {
272 rewriter.replaceOpWithNewOp<memref::DeallocOp>(op, args: adaptor.getMemref());
273 return success();
274 }
275};
276
277//===----------------------------------------------------------------------===//
278// ConvertMemRefLoad
279//===----------------------------------------------------------------------===//
280
281struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
282 using OpConversionPattern::OpConversionPattern;
283
284 LogicalResult
285 matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
286 ConversionPatternRewriter &rewriter) const override {
287 auto convertedType = cast<MemRefType>(Val: adaptor.getMemref().getType());
288 auto convertedElementType = convertedType.getElementType();
289 auto oldElementType = op.getMemRefType().getElementType();
290 int srcBits = oldElementType.getIntOrFloatBitWidth();
291 int dstBits = convertedElementType.getIntOrFloatBitWidth();
292 if (dstBits % srcBits != 0) {
293 return rewriter.notifyMatchFailure(
294 arg&: op, msg: "only dstBits % srcBits == 0 supported");
295 }
296
297 Location loc = op.getLoc();
298 // Special case 0-rank memref loads.
299 Value bitsLoad;
300 if (convertedType.getRank() == 0) {
301 bitsLoad = rewriter.create<memref::LoadOp>(location: loc, args: adaptor.getMemref(),
302 args: ValueRange{});
303 } else {
304 // Linearize the indices of the original load instruction. Do not account
305 // for the scaling yet. This will be accounted for later.
306 OpFoldResult linearizedIndices = getLinearizedSrcIndices(
307 builder&: rewriter, loc, srcBits, indices: adaptor.getIndices(), memref: op.getMemRef());
308
309 Value newLoad = rewriter.create<memref::LoadOp>(
310 location: loc, args: adaptor.getMemref(),
311 args: getIndicesForLoadOrStore(builder&: rewriter, loc, linearizedIndex: linearizedIndices, srcBits,
312 dstBits));
313
314 // Get the offset and shift the bits to the rightmost.
315 // Note, currently only the big-endian is supported.
316 Value bitwidthOffset = getOffsetForBitwidth(loc, srcIdx: linearizedIndices,
317 sourceBits: srcBits, targetBits: dstBits, builder&: rewriter);
318 bitsLoad = rewriter.create<arith::ShRSIOp>(location: loc, args&: newLoad, args&: bitwidthOffset);
319 }
320
321 // Get the corresponding bits. If the arith computation bitwidth equals
322 // to the emulated bitwidth, we apply a mask to extract the low bits.
323 // It is not clear if this case actually happens in practice, but we keep
324 // the operations just in case. Otherwise, if the arith computation bitwidth
325 // is different from the emulated bitwidth we truncate the result.
326 Value result;
327 auto resultTy = getTypeConverter()->convertType(t: oldElementType);
328 auto conversionTy =
329 resultTy.isInteger()
330 ? resultTy
331 : IntegerType::get(context: rewriter.getContext(),
332 width: resultTy.getIntOrFloatBitWidth());
333 if (conversionTy == convertedElementType) {
334 auto mask = rewriter.create<arith::ConstantOp>(
335 location: loc, args&: convertedElementType,
336 args: rewriter.getIntegerAttr(type: convertedElementType, value: (1 << srcBits) - 1));
337
338 result = rewriter.create<arith::AndIOp>(location: loc, args&: bitsLoad, args&: mask);
339 } else {
340 result = rewriter.create<arith::TruncIOp>(location: loc, args&: conversionTy, args&: bitsLoad);
341 }
342
343 if (conversionTy != resultTy) {
344 result = rewriter.create<arith::BitcastOp>(location: loc, args&: resultTy, args&: result);
345 }
346
347 rewriter.replaceOp(op, newValues: result);
348 return success();
349 }
350};
351
352//===----------------------------------------------------------------------===//
353// ConvertMemRefMemorySpaceCast
354//===----------------------------------------------------------------------===//
355
356struct ConvertMemRefMemorySpaceCast final
357 : OpConversionPattern<memref::MemorySpaceCastOp> {
358 using OpConversionPattern::OpConversionPattern;
359
360 LogicalResult
361 matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
362 ConversionPatternRewriter &rewriter) const override {
363 Type newTy = getTypeConverter()->convertType(t: op.getDest().getType());
364 if (!newTy) {
365 return rewriter.notifyMatchFailure(
366 arg: op->getLoc(), msg: llvm::formatv(Fmt: "failed to convert memref type: {0}",
367 Vals: op.getDest().getType()));
368 }
369
370 rewriter.replaceOpWithNewOp<memref::MemorySpaceCastOp>(op, args&: newTy,
371 args: adaptor.getSource());
372 return success();
373 }
374};
375
376//===----------------------------------------------------------------------===//
377// ConvertMemRefReinterpretCast
378//===----------------------------------------------------------------------===//
379
380/// Output types should be at most one dimensional, so only the 0 or 1
381/// dimensional cases are supported.
382struct ConvertMemRefReinterpretCast final
383 : OpConversionPattern<memref::ReinterpretCastOp> {
384 using OpConversionPattern::OpConversionPattern;
385
386 LogicalResult
387 matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
388 ConversionPatternRewriter &rewriter) const override {
389 MemRefType newTy =
390 getTypeConverter()->convertType<MemRefType>(t: op.getType());
391 if (!newTy) {
392 return rewriter.notifyMatchFailure(
393 arg: op->getLoc(),
394 msg: llvm::formatv(Fmt: "failed to convert memref type: {0}", Vals: op.getType()));
395 }
396
397 // Only support for 0 or 1 dimensional cases.
398 if (op.getType().getRank() > 1) {
399 return rewriter.notifyMatchFailure(
400 arg: op->getLoc(), msg: "subview with rank > 1 is not supported");
401 }
402
403 return convertCastingOp(rewriter, adaptor, op, newTy);
404 }
405};
406
407//===----------------------------------------------------------------------===//
408// ConvertMemrefStore
409//===----------------------------------------------------------------------===//
410
411struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
412 using OpConversionPattern::OpConversionPattern;
413
414 LogicalResult
415 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
416 ConversionPatternRewriter &rewriter) const override {
417 auto convertedType = cast<MemRefType>(Val: adaptor.getMemref().getType());
418 int srcBits = op.getMemRefType().getElementTypeBitWidth();
419 int dstBits = convertedType.getElementTypeBitWidth();
420 auto dstIntegerType = rewriter.getIntegerType(width: dstBits);
421 if (dstBits % srcBits != 0) {
422 return rewriter.notifyMatchFailure(
423 arg&: op, msg: "only dstBits % srcBits == 0 supported");
424 }
425
426 Location loc = op.getLoc();
427
428 // Pad the input value with 0s on the left.
429 Value input = adaptor.getValue();
430 if (!input.getType().isInteger()) {
431 input = rewriter.create<arith::BitcastOp>(
432 location: loc,
433 args: IntegerType::get(context: rewriter.getContext(),
434 width: input.getType().getIntOrFloatBitWidth()),
435 args&: input);
436 }
437 Value extendedInput =
438 rewriter.create<arith::ExtUIOp>(location: loc, args&: dstIntegerType, args&: input);
439
440 // Special case 0-rank memref stores. No need for masking.
441 if (convertedType.getRank() == 0) {
442 rewriter.create<memref::AtomicRMWOp>(location: loc, args: arith::AtomicRMWKind::assign,
443 args&: extendedInput, args: adaptor.getMemref(),
444 args: ValueRange{});
445 rewriter.eraseOp(op);
446 return success();
447 }
448
449 OpFoldResult linearizedIndices = getLinearizedSrcIndices(
450 builder&: rewriter, loc, srcBits, indices: adaptor.getIndices(), memref: op.getMemRef());
451 Value storeIndices = getIndicesForLoadOrStore(
452 builder&: rewriter, loc, linearizedIndex: linearizedIndices, srcBits, dstBits);
453 Value bitwidthOffset = getOffsetForBitwidth(loc, srcIdx: linearizedIndices, sourceBits: srcBits,
454 targetBits: dstBits, builder&: rewriter);
455 Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
456 dstBits, bitwidthOffset, builder&: rewriter);
457 // Align the value to write with the destination bits
458 Value alignedVal =
459 rewriter.create<arith::ShLIOp>(location: loc, args&: extendedInput, args&: bitwidthOffset);
460
461 // Clear destination bits
462 rewriter.create<memref::AtomicRMWOp>(location: loc, args: arith::AtomicRMWKind::andi,
463 args&: writeMask, args: adaptor.getMemref(),
464 args&: storeIndices);
465 // Write srcs bits to destination
466 rewriter.create<memref::AtomicRMWOp>(location: loc, args: arith::AtomicRMWKind::ori,
467 args&: alignedVal, args: adaptor.getMemref(),
468 args&: storeIndices);
469 rewriter.eraseOp(op);
470 return success();
471 }
472};
473
474//===----------------------------------------------------------------------===//
475// ConvertMemRefSubview
476//===----------------------------------------------------------------------===//
477
478/// Emulating narrow ints on subview have limited support, supporting only
479/// static offset and size and stride of 1. Ideally, the subview should be
480/// folded away before running narrow type emulation, and this pattern should
481/// only run for cases that can't be folded.
482struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
483 using OpConversionPattern::OpConversionPattern;
484
485 LogicalResult
486 matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
487 ConversionPatternRewriter &rewriter) const override {
488 MemRefType newTy =
489 getTypeConverter()->convertType<MemRefType>(t: subViewOp.getType());
490 if (!newTy) {
491 return rewriter.notifyMatchFailure(
492 arg: subViewOp->getLoc(),
493 msg: llvm::formatv(Fmt: "failed to convert memref type: {0}",
494 Vals: subViewOp.getType()));
495 }
496
497 Location loc = subViewOp.getLoc();
498 Type convertedElementType = newTy.getElementType();
499 Type oldElementType = subViewOp.getType().getElementType();
500 int srcBits = oldElementType.getIntOrFloatBitWidth();
501 int dstBits = convertedElementType.getIntOrFloatBitWidth();
502 if (dstBits % srcBits != 0)
503 return rewriter.notifyMatchFailure(
504 arg&: subViewOp, msg: "only dstBits % srcBits == 0 supported");
505
506 // Only support stride of 1.
507 if (llvm::any_of(Range: subViewOp.getStaticStrides(),
508 P: [](int64_t stride) { return stride != 1; })) {
509 return rewriter.notifyMatchFailure(arg: subViewOp->getLoc(),
510 msg: "stride != 1 is not supported");
511 }
512
513 if (!memref::isStaticShapeAndContiguousRowMajor(type: subViewOp.getType())) {
514 return rewriter.notifyMatchFailure(
515 arg&: subViewOp, msg: "the result memref type is not contiguous");
516 }
517
518 auto sizes = subViewOp.getStaticSizes();
519 int64_t lastOffset = subViewOp.getStaticOffsets().back();
520 // Only support static sizes and offsets.
521 if (llvm::is_contained(Range&: sizes, Element: ShapedType::kDynamic) ||
522 lastOffset == ShapedType::kDynamic) {
523 return rewriter.notifyMatchFailure(
524 arg: subViewOp->getLoc(), msg: "dynamic size or offset is not supported");
525 }
526
527 // Transform the offsets, sizes and strides according to the emulation.
528 auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
529 location: loc, args: subViewOp.getViewSource());
530
531 OpFoldResult linearizedIndices;
532 auto strides = stridedMetadata.getConstifiedMixedStrides();
533 memref::LinearizedMemRefInfo linearizedInfo;
534 std::tie(args&: linearizedInfo, args&: linearizedIndices) =
535 memref::getLinearizedMemRefOffsetAndSize(
536 builder&: rewriter, loc, srcBits, dstBits,
537 offset: stridedMetadata.getConstifiedMixedOffset(),
538 sizes: subViewOp.getMixedSizes(), strides,
539 indices: getMixedValues(staticValues: adaptor.getStaticOffsets(), dynamicValues: adaptor.getOffsets(),
540 b&: rewriter));
541
542 rewriter.replaceOpWithNewOp<memref::SubViewOp>(
543 op: subViewOp, args&: newTy, args: adaptor.getSource(), args&: linearizedIndices,
544 args&: linearizedInfo.linearizedSize, args&: strides.back());
545 return success();
546 }
547};
548
549//===----------------------------------------------------------------------===//
550// ConvertMemRefCollapseShape
551//===----------------------------------------------------------------------===//
552
553/// Emulating a `memref.collapse_shape` becomes a no-op after emulation given
554/// that we flatten memrefs to a single dimension as part of the emulation and
555/// there is no dimension to collapse any further.
556struct ConvertMemRefCollapseShape final
557 : OpConversionPattern<memref::CollapseShapeOp> {
558 using OpConversionPattern::OpConversionPattern;
559
560 LogicalResult
561 matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
562 ConversionPatternRewriter &rewriter) const override {
563 Value srcVal = adaptor.getSrc();
564 auto newTy = dyn_cast<MemRefType>(Val: srcVal.getType());
565 if (!newTy)
566 return failure();
567
568 if (newTy.getRank() != 1)
569 return failure();
570
571 rewriter.replaceOp(op: collapseShapeOp, newValues: srcVal);
572 return success();
573 }
574};
575
576/// Emulating a `memref.expand_shape` becomes a no-op after emulation given
577/// that we flatten memrefs to a single dimension as part of the emulation and
578/// the expansion would just have been undone.
579struct ConvertMemRefExpandShape final
580 : OpConversionPattern<memref::ExpandShapeOp> {
581 using OpConversionPattern::OpConversionPattern;
582
583 LogicalResult
584 matchAndRewrite(memref::ExpandShapeOp expandShapeOp, OpAdaptor adaptor,
585 ConversionPatternRewriter &rewriter) const override {
586 Value srcVal = adaptor.getSrc();
587 auto newTy = dyn_cast<MemRefType>(Val: srcVal.getType());
588 if (!newTy)
589 return failure();
590
591 if (newTy.getRank() != 1)
592 return failure();
593
594 rewriter.replaceOp(op: expandShapeOp, newValues: srcVal);
595 return success();
596 }
597};
598} // end anonymous namespace
599
600//===----------------------------------------------------------------------===//
601// Public Interface Definition
602//===----------------------------------------------------------------------===//
603
604void memref::populateMemRefNarrowTypeEmulationPatterns(
605 const arith::NarrowTypeEmulationConverter &typeConverter,
606 RewritePatternSet &patterns) {
607
608 // Populate `memref.*` conversion patterns.
609 patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
610 ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCopy,
611 ConvertMemRefDealloc, ConvertMemRefCollapseShape,
612 ConvertMemRefExpandShape, ConvertMemRefLoad, ConvertMemrefStore,
613 ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast,
614 ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
615 arg: typeConverter, args: patterns.getContext());
616 memref::populateResolveExtractStridedMetadataPatterns(patterns);
617}
618
619static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
620 int dstBits) {
621 if (ty.getRank() == 0)
622 return {};
623
624 int64_t linearizedShape = 1;
625 for (auto shape : ty.getShape()) {
626 if (shape == ShapedType::kDynamic)
627 return {ShapedType::kDynamic};
628 linearizedShape *= shape;
629 }
630 int scale = dstBits / srcBits;
631 // Scale the size to the ceilDiv(linearizedShape, scale)
632 // to accomodate all the values.
633 linearizedShape = (linearizedShape + scale - 1) / scale;
634 return {linearizedShape};
635}
636
637void memref::populateMemRefNarrowTypeEmulationConversions(
638 arith::NarrowTypeEmulationConverter &typeConverter) {
639 typeConverter.addConversion(
640 callback: [&typeConverter](MemRefType ty) -> std::optional<Type> {
641 Type elementType = ty.getElementType();
642 if (!elementType.isIntOrFloat())
643 return ty;
644
645 unsigned width = elementType.getIntOrFloatBitWidth();
646 unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
647 if (width >= loadStoreWidth)
648 return ty;
649
650 // Currently only handle innermost stride being 1, checking
651 SmallVector<int64_t> strides;
652 int64_t offset;
653 if (failed(Result: ty.getStridesAndOffset(strides, offset)))
654 return nullptr;
655 if (!strides.empty() && strides.back() != 1)
656 return nullptr;
657
658 auto newElemTy = IntegerType::get(
659 context: ty.getContext(), width: loadStoreWidth,
660 signedness: elementType.isInteger()
661 ? cast<IntegerType>(Val&: elementType).getSignedness()
662 : IntegerType::SignednessSemantics::Signless);
663 if (!newElemTy)
664 return nullptr;
665
666 StridedLayoutAttr layoutAttr;
667 // If the offset is 0, we do not need a strided layout as the stride is
668 // 1, so we only use the strided layout if the offset is not 0.
669 if (offset != 0) {
670 if (offset == ShapedType::kDynamic) {
671 layoutAttr = StridedLayoutAttr::get(context: ty.getContext(), offset,
672 strides: ArrayRef<int64_t>{1});
673 } else {
674 // Check if the number of bytes are a multiple of the loadStoreWidth
675 // and if so, divide it by the loadStoreWidth to get the offset.
676 if ((offset * width) % loadStoreWidth != 0)
677 return std::nullopt;
678 offset = (offset * width) / loadStoreWidth;
679
680 layoutAttr = StridedLayoutAttr::get(context: ty.getContext(), offset,
681 strides: ArrayRef<int64_t>{1});
682 }
683 }
684
685 return MemRefType::get(shape: getLinearizedShape(ty, srcBits: width, dstBits: loadStoreWidth),
686 elementType: newElemTy, layout: layoutAttr, memorySpace: ty.getMemorySpace());
687 });
688}
689

// Source: mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp