//===- EmulateNarrowType.cpp - Narrow type emulation ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <type_traits>

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Converts a memref::ReinterpretCastOp to the converted type. The result
/// MemRefType of the old op must have a rank and stride of 1, with static
/// offset and size. The offset, measured in bits of the source element type,
/// must be a multiple of the bitwidth of the new converted element type.
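///
/// Illustrative example (not part of the original comment): with i4 source
/// elements and an i8 converted element type, elementsPerByte = 8 / 4 = 2, so
/// a cast with static offset 4 and static size 5 is rewritten to use offset
/// 4 / 2 = 2 and size ceil(5 / 2) = 3 on the i8 memref.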
static LogicalResult
convertCastingOp(ConversionPatternRewriter &rewriter,
                 memref::ReinterpretCastOp::Adaptor adaptor,
                 memref::ReinterpretCastOp op, MemRefType newTy) {
  auto convertedElementType = newTy.getElementType();
  auto oldElementType = op.getType().getElementType();
  int srcBits = oldElementType.getIntOrFloatBitWidth();
  int dstBits = convertedElementType.getIntOrFloatBitWidth();
  if (dstBits % srcBits != 0) {
    return rewriter.notifyMatchFailure(op,
                                       "only dstBits % srcBits == 0 supported");
  }

  // Only support stride of 1.
  if (llvm::any_of(op.getStaticStrides(),
                   [](int64_t stride) { return stride != 1; })) {
    return rewriter.notifyMatchFailure(op->getLoc(),
                                       "stride != 1 is not supported");
  }

  auto sizes = op.getStaticSizes();
  int64_t offset = op.getStaticOffset(0);
  // Only support static sizes and offsets.
  if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
      offset == ShapedType::kDynamic) {
    return rewriter.notifyMatchFailure(
        op, "dynamic size or offset is not supported");
  }

  int elementsPerByte = dstBits / srcBits;
  if (offset % elementsPerByte != 0) {
    return rewriter.notifyMatchFailure(
        op, "offset not multiple of elementsPerByte is not supported");
  }

  SmallVector<int64_t> size;
  if (sizes.size())
    size.push_back(llvm::divideCeilSigned(sizes[0], elementsPerByte));
  offset = offset / elementsPerByte;

  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
      op, newTy, adaptor.getSource(), offset, size, op.getStaticStrides());
  return success();
}

/// When data is loaded/stored at `targetBits` granularity but used at
/// `sourceBits` granularity (`sourceBits` < `targetBits`), each `targetBits`
/// unit is treated as an array of elements of width `sourceBits`.
/// Returns the bit offset of the value at position `srcIdx`. For example, if
/// `sourceBits` is 4 and `targetBits` is 8, the x-th element is located at
/// bit offset (x % 2) * 4, because one i8 holds two 4-bit elements.
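///
/// As a concrete, illustrative instance of the computation below: for
/// srcIdx = 5, sourceBits = 4, targetBits = 8, the affine expression
/// (s0 mod 2) * 4 yields (5 mod 2) * 4 = 4, which is then index_cast to i8
/// and used as the shift amount.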
static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
                                  int sourceBits, int targetBits,
                                  OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int scaleFactor = targetBits / sourceBits;
  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
  OpFoldResult offsetVal =
      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
  Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
  IntegerType dstType = builder.getIntegerType(targetBits);
  return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
}

/// When writing a subbyte size, masked bitwise operations are used to only
/// modify the relevant bits. This function returns an and mask for clearing
/// the destination bits in a subbyte write. E.g., when writing to the second
/// i4 in an i32, 0xFFFFFF0F is created.
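///
/// Illustrative derivation of that example from the code below: with
/// srcBits = 4 and dstBits = 32, maskRightAligned = (1 << 4) - 1 = 0xF;
/// shifting it left by the bitwidth offset of the second i4 (4) gives 0xF0,
/// and XOR-ing with -1 (0xFFFFFFFF) flips it into the final mask 0xFFFFFF0F.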
static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
                                 int64_t srcBits, int64_t dstBits,
                                 Value bitwidthOffset, OpBuilder &builder) {
  auto dstIntegerType = builder.getIntegerType(dstBits);
  auto maskRightAlignedAttr =
      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
  Value maskRightAligned = builder.create<arith::ConstantOp>(
      loc, dstIntegerType, maskRightAlignedAttr);
  Value writeMaskInverse =
      builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
  Value flipVal =
      builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr);
  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
}

/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
/// the returned index has the granularity of `dstBits`.
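///
/// For example (illustrative): with srcBits = 4, dstBits = 8, and a
/// linearized index of 5, the scaled index is 5 floordiv (8 / 4) = 2, i.e.
/// the sub-byte element lives in the third i8 of the linearized buffer.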
static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
                                      OpFoldResult linearizedIndex,
                                      int64_t srcBits, int64_t dstBits) {
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int64_t scaler = dstBits / srcBits;
  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
  return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
}

static OpFoldResult
getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
                        const SmallVector<OpFoldResult> &indices,
                        Value memref) {
  auto stridedMetadata =
      builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
  OpFoldResult linearizedIndices;
  std::tie(std::ignore, linearizedIndices) =
      memref::getLinearizedMemRefOffsetAndSize(
          builder, loc, srcBits, srcBits,
          stridedMetadata.getConstifiedMixedOffset(),
          stridedMetadata.getConstifiedMixedSizes(),
          stridedMetadata.getConstifiedMixedStrides(), indices);
  return linearizedIndices;
}

namespace {

//===----------------------------------------------------------------------===//
// ConvertMemRefAllocation
//===----------------------------------------------------------------------===//

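/// Converts memref.alloc / memref.alloca to allocate a linearized memref of
/// the wider element type. Illustrative example (assuming a load/store width
/// of 8): memref<3x5xi4> holds 15 four-bit elements, i.e. 60 bits, so the
/// emulated allocation is the flattened memref<8xi8>.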
template <typename OpTy>
struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    static_assert(std::is_same<OpTy, memref::AllocOp>() ||
                      std::is_same<OpTy, memref::AllocaOp>(),
                  "expected only memref::AllocOp or memref::AllocaOp");
    auto currentType = cast<MemRefType>(op.getMemref().getType());
    auto newResultType =
        this->getTypeConverter()->template convertType<MemRefType>(
            op.getType());
    if (!newResultType) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Special case zero-rank memrefs.
    if (currentType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
                                        adaptor.getSymbolOperands(),
                                        adaptor.getAlignmentAttr());
      return success();
    }

    Location loc = op.getLoc();
    OpFoldResult zero = rewriter.getIndexAttr(0);

    // Get linearized type.
    int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
    int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
    SmallVector<OpFoldResult> sizes = op.getMixedSizes();

    memref::LinearizedMemRefInfo linearizedMemRefInfo =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits, /*offset =*/zero, sizes);
    SmallVector<Value> dynamicLinearizedSize;
    if (!newResultType.hasStaticShape()) {
      dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp(
          rewriter, loc, linearizedMemRefInfo.linearizedSize));
    }

    rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
                                      adaptor.getSymbolOperands(),
                                      adaptor.getAlignmentAttr());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefAssumeAlignment
//===----------------------------------------------------------------------===//

struct ConvertMemRefAssumeAlignment final
    : OpConversionPattern<memref::AssumeAlignmentOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getMemref().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
        op, newTy, adaptor.getMemref(), adaptor.getAlignmentAttr());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefCopy
//===----------------------------------------------------------------------===//

struct ConvertMemRefCopy final : OpConversionPattern<memref::CopyOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto maybeRankedSource = dyn_cast<MemRefType>(op.getSource().getType());
    auto maybeRankedDest = dyn_cast<MemRefType>(op.getTarget().getType());
    if (maybeRankedSource && maybeRankedDest &&
        maybeRankedSource.getLayout() != maybeRankedDest.getLayout())
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("memref.copy emulation with distinct layouts ({0} "
                            "and {1}) is currently unimplemented",
                            maybeRankedSource.getLayout(),
                            maybeRankedDest.getLayout()));
    rewriter.replaceOpWithNewOp<memref::CopyOp>(op, adaptor.getSource(),
                                                adaptor.getTarget());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefDealloc
//===----------------------------------------------------------------------===//

struct ConvertMemRefDealloc final : OpConversionPattern<memref::DeallocOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<memref::DeallocOp>(op, adaptor.getMemref());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefLoad
//===----------------------------------------------------------------------===//

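/// Emulates a narrow-type memref.load by loading the containing wider element
/// and extracting the requested bits. As a rough, illustrative sketch of the
/// ops emitted below for an i4 load from a linearized i8 buffer (value names
/// are made up):
///
///   %word  = memref.load %mem[%scaledIdx] : memref<?xi8>
///   %shift = arith.shrsi %word, %bitOffset : i8
///   %low   = arith.andi %shift, %mask : i8   // or arith.trunci %shift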
struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    auto convertedElementType = convertedType.getElementType();
    auto oldElementType = op.getMemRefType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    // Special case 0-rank memref loads.
    Value bitsLoad;
    if (convertedType.getRank() == 0) {
      bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
                                                 ValueRange{});
    } else {
      // Linearize the indices of the original load instruction. Do not account
      // for the scaling yet. This will be accounted for later.
      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());

      Value newLoad = rewriter.create<memref::LoadOp>(
          loc, adaptor.getMemref(),
          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
                                   dstBits));

      // Get the offset and shift the bits to the rightmost.
      // Note, currently only big-endian is supported.
      Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
                                                  srcBits, dstBits, rewriter);
      bitsLoad = rewriter.create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
    }

    // Get the corresponding bits. If the arith computation bitwidth equals
    // the emulated bitwidth, we apply a mask to extract the low bits.
    // It is not clear if this case actually happens in practice, but we keep
    // the operations just in case. Otherwise, if the arith computation
    // bitwidth differs from the emulated bitwidth, we truncate the result.
    Operation *result;
    auto resultTy = getTypeConverter()->convertType(oldElementType);
    if (resultTy == convertedElementType) {
      auto mask = rewriter.create<arith::ConstantOp>(
          loc, convertedElementType,
          rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));

      result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
    } else {
      result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
    }

    rewriter.replaceOp(op, result->getResult(0));
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefMemorySpaceCast
//===----------------------------------------------------------------------===//

struct ConvertMemRefMemorySpaceCast final
    : OpConversionPattern<memref::MemorySpaceCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getDest().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getDest().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::MemorySpaceCastOp>(
        op, newTy, adaptor.getSource());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefReinterpretCast
//===----------------------------------------------------------------------===//

/// Output types should be at most one dimensional, so only the 0 or 1
/// dimensional cases are supported.
struct ConvertMemRefReinterpretCast final
    : OpConversionPattern<memref::ReinterpretCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        getTypeConverter()->convertType<MemRefType>(op.getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Only the 0 or 1 dimensional cases are supported.
    if (op.getType().getRank() > 1) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), "subview with rank > 1 is not supported");
    }

    return convertCastingOp(rewriter, adaptor, op, newTy);
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemrefStore
//===----------------------------------------------------------------------===//

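/// Emulates a narrow-type memref.store as an atomic read-modify-write on the
/// containing wider element. As a rough, illustrative sketch of the ops
/// emitted below for an i4 store into a linearized i8 buffer (value names are
/// made up):
///
///   %ext     = arith.extui %val : i4 to i8
///   %aligned = arith.shli %ext, %bitOffset : i8
///   memref.atomic_rmw andi %clearMask, %mem[%idx] : (i8, memref<?xi8>) -> i8
///   memref.atomic_rmw ori %aligned, %mem[%idx] : (i8, memref<?xi8>) -> i8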
struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    int srcBits = op.getMemRefType().getElementTypeBitWidth();
    int dstBits = convertedType.getElementTypeBitWidth();
    auto dstIntegerType = rewriter.getIntegerType(dstBits);
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
                                                          adaptor.getValue());

    // Special case 0-rank memref stores. No need for masking.
    if (convertedType.getRank() == 0) {
      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::assign,
                                           extendedInput, adaptor.getMemref(),
                                           ValueRange{});
      rewriter.eraseOp(op);
      return success();
    }

    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
    Value storeIndices = getIndicesForLoadOrStore(rewriter, loc,
                                                  linearizedIndices, srcBits,
                                                  dstBits);
    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
                                                dstBits, rewriter);
    Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
                                          dstBits, bitwidthOffset, rewriter);
    // Align the value to write with the destination bits.
    Value alignedVal =
        rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);

    // Clear the destination bits.
    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
                                         writeMask, adaptor.getMemref(),
                                         storeIndices);
    // Write the source bits to the destination.
    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
                                         alignedVal, adaptor.getMemref(),
                                         storeIndices);
    rewriter.eraseOp(op);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//

/// Emulating narrow ints on subview has limited support: only static offsets
/// and sizes and a stride of 1 are supported. Ideally, the subview should be
/// folded away before running narrow type emulation, and this pattern should
/// only run for cases that can't be folded.
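///
/// Illustrative example (assuming a load/store width of 8): a subview of
/// memref<32xi4> taking 16 elements at static offset 8 with stride 1 is
/// rewritten, after linearization, to a subview of the i8 buffer taking
/// 8 elements at offset 8 * 4 / 8 = 4.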
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        getTypeConverter()->convertType<MemRefType>(subViewOp.getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(),
          llvm::formatv("failed to convert memref type: {0}",
                        subViewOp.getType()));
    }

    Location loc = subViewOp.getLoc();
    Type convertedElementType = newTy.getElementType();
    Type oldElementType = subViewOp.getType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0)
      return rewriter.notifyMatchFailure(
          subViewOp, "only dstBits % srcBits == 0 supported");

    // Only support stride of 1.
    if (llvm::any_of(subViewOp.getStaticStrides(),
                     [](int64_t stride) { return stride != 1; })) {
      return rewriter.notifyMatchFailure(subViewOp->getLoc(),
                                         "stride != 1 is not supported");
    }

    if (!memref::isStaticShapeAndContiguousRowMajor(subViewOp.getType())) {
      return rewriter.notifyMatchFailure(
          subViewOp, "the result memref type is not contiguous");
    }

    auto sizes = subViewOp.getStaticSizes();
    int64_t lastOffset = subViewOp.getStaticOffsets().back();
    // Only support static sizes and offsets.
    if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
        lastOffset == ShapedType::kDynamic) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(), "dynamic size or offset is not supported");
    }

    // Transform the offsets, sizes and strides according to the emulation.
    auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
        loc, subViewOp.getViewSource());

    OpFoldResult linearizedIndices;
    auto strides = stridedMetadata.getConstifiedMixedStrides();
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            subViewOp.getMixedSizes(), strides,
            getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
                           rewriter));

    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subViewOp, newTy, adaptor.getSource(), linearizedIndices,
        linearizedInfo.linearizedSize, strides.back());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefCollapseShape
//===----------------------------------------------------------------------===//

/// A `memref.collapse_shape` becomes a no-op after emulation, given that we
/// flatten memrefs to a single dimension as part of the emulation and there
/// is no dimension left to collapse any further.
struct ConvertMemRefCollapseShape final
    : OpConversionPattern<memref::CollapseShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value srcVal = adaptor.getSrc();
    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
    if (!newTy)
      return failure();

    if (newTy.getRank() != 1)
      return failure();

    rewriter.replaceOp(collapseShapeOp, srcVal);
    return success();
  }
};

/// A `memref.expand_shape` becomes a no-op after emulation, given that we
/// flatten memrefs to a single dimension as part of the emulation and the
/// expansion would just have been undone.
struct ConvertMemRefExpandShape final
    : OpConversionPattern<memref::ExpandShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ExpandShapeOp expandShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value srcVal = adaptor.getSrc();
    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
    if (!newTy)
      return failure();

    if (newTy.getRank() != 1)
      return failure();

    rewriter.replaceOp(expandShapeOp, srcVal);
    return success();
  }
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
// Public Interface Definition
//===----------------------------------------------------------------------===//

void memref::populateMemRefNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {

  // Populate `memref.*` conversion patterns.
  patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
               ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCopy,
               ConvertMemRefDealloc, ConvertMemRefCollapseShape,
               ConvertMemRefExpandShape, ConvertMemRefLoad, ConvertMemrefStore,
               ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast,
               ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
      typeConverter, patterns.getContext());
  memref::populateResolveExtractStridedMetadataPatterns(patterns);
}

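/// Returns the linearized 1-D shape of `ty`, with the element count scaled
/// down by dstBits / srcBits (rounding up). Illustrative example:
/// memref<3x5xi4> with dstBits = 8 linearizes to 15 elements and, with a
/// scale of 8 / 4 = 2, yields the shape {8}.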
static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
                                               int dstBits) {
  if (ty.getRank() == 0)
    return {};

  int64_t linearizedShape = 1;
  for (auto shape : ty.getShape()) {
    if (shape == ShapedType::kDynamic)
      return {ShapedType::kDynamic};
    linearizedShape *= shape;
  }
  int scale = dstBits / srcBits;
  // Scale the size to ceilDiv(linearizedShape, scale) to accommodate all the
  // values.
  linearizedShape = (linearizedShape + scale - 1) / scale;
  return {linearizedShape};
}

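/// Registers the MemRefType conversion used for narrow type emulation.
/// Illustrative example (assuming a load/store width of 8):
/// memref<3x5xi4, strided<[5, 1], offset: 10>> converts to
/// memref<8xi8, strided<[1], offset: 5>>, since 15 i4 values fit in 8 bytes
/// and the 10-element i4 offset is 40 bits, i.e. 5 bytes.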
void memref::populateMemRefNarrowTypeEmulationConversions(
    arith::NarrowTypeEmulationConverter &typeConverter) {
  typeConverter.addConversion(
      [&typeConverter](MemRefType ty) -> std::optional<Type> {
        auto intTy = dyn_cast<IntegerType>(ty.getElementType());
        if (!intTy)
          return ty;

        unsigned width = intTy.getWidth();
        unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
        if (width >= loadStoreWidth)
          return ty;

        // Currently we only handle an innermost stride of 1; check it here.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(ty.getStridesAndOffset(strides, offset)))
          return nullptr;
        if (!strides.empty() && strides.back() != 1)
          return nullptr;

        auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
                                          intTy.getSignedness());
        if (!newElemTy)
          return nullptr;

        StridedLayoutAttr layoutAttr;
        // If the offset is 0, we do not need a strided layout as the stride is
        // 1, so we only use the strided layout if the offset is not 0.
        if (offset != 0) {
          if (offset == ShapedType::kDynamic) {
            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          } else {
            // Check that the offset in bits is a multiple of loadStoreWidth
            // and, if so, divide by loadStoreWidth to get the new offset.
            if ((offset * width) % loadStoreWidth != 0)
              return std::nullopt;
            offset = (offset * width) / loadStoreWidth;

            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          }
        }

        return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
                               newElemTy, layoutAttr, ty.getMemorySpace());
      });
}