//===- EmulateNarrowType.cpp - Narrow type emulation ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <type_traits>

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Converts a memref::SubViewOp or memref::ReinterpretCastOp to the converted
/// type. The result MemRefType of the old op must have a rank and stride of 1,
/// with static offset and size. The offset, measured in bits, must be evenly
/// divisible by the bitwidth of the new converted element type.
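///
/// Illustrative example (assuming i4 emulated on i8, so two elements per
/// byte): a cast producing `memref<6xi4, strided<[1], offset: 2>>` is
/// rewritten to produce `memref<3xi8, strided<[1], offset: 1>>`; the size is
/// ceil-divided by 2 and the offset divided by 2.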
template <typename MemRefOpTy>
static LogicalResult convertCastingOp(ConversionPatternRewriter &rewriter,
                                      typename MemRefOpTy::Adaptor adaptor,
                                      MemRefOpTy op, MemRefType newTy) {
  static_assert(std::is_same<MemRefOpTy, memref::SubViewOp>() ||
                    std::is_same<MemRefOpTy, memref::ReinterpretCastOp>(),
                "Expected only memref::SubViewOp or memref::ReinterpretCastOp");

  auto convertedElementType = newTy.getElementType();
  auto oldElementType = op.getType().getElementType();
  int srcBits = oldElementType.getIntOrFloatBitWidth();
  int dstBits = convertedElementType.getIntOrFloatBitWidth();
  if (dstBits % srcBits != 0) {
    return rewriter.notifyMatchFailure(op,
                                       "only dstBits % srcBits == 0 supported");
  }

  // Only support stride of 1.
  if (llvm::any_of(op.getStaticStrides(),
                   [](int64_t stride) { return stride != 1; })) {
    return rewriter.notifyMatchFailure(op->getLoc(),
                                       "stride != 1 is not supported");
  }

  auto sizes = op.getStaticSizes();
  int64_t offset = op.getStaticOffset(0);
  // Only support static sizes and offsets.
  if (llvm::any_of(sizes,
                   [](int64_t size) { return size == ShapedType::kDynamic; }) ||
      offset == ShapedType::kDynamic) {
    return rewriter.notifyMatchFailure(
        op->getLoc(), "dynamic size or offset is not supported");
  }

  int elementsPerByte = dstBits / srcBits;
  if (offset % elementsPerByte != 0) {
    return rewriter.notifyMatchFailure(
        op->getLoc(), "offset not multiple of elementsPerByte is not "
                      "supported");
  }

  SmallVector<int64_t> size;
  if (sizes.size())
    size.push_back(ceilDiv(sizes[0], elementsPerByte));
  offset = offset / elementsPerByte;

  rewriter.replaceOpWithNewOp<MemRefOpTy>(op, newTy,
                                          *adaptor.getODSOperands(0).begin(),
                                          offset, size, op.getStaticStrides());
  return success();
}

/// When data is loaded/stored in `targetBits` granularity, but is used in
/// `sourceBits` granularity (`sourceBits` < `targetBits`), each `targetBits`
/// value is treated as an array of elements of width `sourceBits`.
/// Returns the bit offset of the value at position `srcIdx`. For example, if
/// `sourceBits` equals 4 and `targetBits` equals 8, the x-th element is
/// located at (x % 2) * 4, because one i8 holds two elements of 4 bits each.
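/// The offset is computed with a folded affine apply and returned as an
/// integer of width `targetBits`, so callers can use it directly as a shift
/// amount.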
static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
                                  int sourceBits, int targetBits,
                                  OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int scaleFactor = targetBits / sourceBits;
  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
  OpFoldResult offsetVal =
      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
  Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
  IntegerType dstType = builder.getIntegerType(targetBits);
  return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
}

/// When writing a sub-byte size, masked bitwise operations are used to only
/// modify the relevant bits. This function returns an AND mask for clearing
/// the destination bits in a sub-byte write. E.g., when writing to the second
/// i4 in an i32, 0xFFFFFF0F is created.
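/// The mask is built by shifting `srcBits` ones up to the write position and
/// then flipping every bit (XOR with an all-ones constant).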
static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
                                 int64_t srcBits, int64_t dstBits,
                                 Value bitwidthOffset, OpBuilder &builder) {
  auto dstIntegerType = builder.getIntegerType(dstBits);
  auto maskRightAlignedAttr =
      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
  Value maskRightAligned = builder.create<arith::ConstantOp>(
      loc, dstIntegerType, maskRightAlignedAttr);
  Value writeMaskInverse =
      builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
  Value flipVal =
      builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr);
  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
}

/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
/// the returned index has the granularity of `dstBits`.
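/// For example, with `srcBits` = 4 and `dstBits` = 8, source index 5 maps to
/// destination index 2 (floor division by the scale factor of 2).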
static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
                                      OpFoldResult linearizedIndex,
                                      int64_t srcBits, int64_t dstBits) {
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int64_t scaler = dstBits / srcBits;
  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
  return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
}

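/// Linearizes `indices` into a single offset in `srcBits` granularity, using
/// the offset, sizes, and strides recovered from `memref` via
/// memref.extract_strided_metadata.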
static OpFoldResult
getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
                        const SmallVector<OpFoldResult> &indices,
                        Value memref) {
  auto stridedMetadata =
      builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
  OpFoldResult linearizedIndices;
  std::tie(std::ignore, linearizedIndices) =
      memref::getLinearizedMemRefOffsetAndSize(
          builder, loc, srcBits, srcBits,
          stridedMetadata.getConstifiedMixedOffset(),
          stridedMetadata.getConstifiedMixedSizes(),
          stridedMetadata.getConstifiedMixedStrides(), indices);
  return linearizedIndices;
}

namespace {

//===----------------------------------------------------------------------===//
// ConvertMemRefAllocation
//===----------------------------------------------------------------------===//

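/// Converts memref.alloc/memref.alloca on a narrow element type into an
/// allocation of the wider load/store type with a linearized, rank-1 shape.
/// Illustrative example (assuming i4 emulated on i8): `memref.alloc() :
/// memref<3x4xi4>` becomes `memref.alloc() : memref<6xi8>`.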
template <typename OpTy>
struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    static_assert(std::is_same<OpTy, memref::AllocOp>() ||
                      std::is_same<OpTy, memref::AllocaOp>(),
                  "expected only memref::AllocOp or memref::AllocaOp");
    auto currentType = cast<MemRefType>(op.getMemref().getType());
    auto newResultType = dyn_cast<MemRefType>(
        this->getTypeConverter()->convertType(op.getType()));
    if (!newResultType) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Special case zero-rank memrefs.
    if (currentType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
                                        adaptor.getSymbolOperands(),
                                        adaptor.getAlignmentAttr());
      return success();
    }

    Location loc = op.getLoc();
    OpFoldResult zero = rewriter.getIndexAttr(0);
    SmallVector<OpFoldResult> indices(currentType.getRank(), zero);

    // Get linearized type.
    int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
    int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
    SmallVector<OpFoldResult> sizes = op.getMixedSizes();

    memref::LinearizedMemRefInfo linearizedMemRefInfo =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits, /*offset =*/zero, sizes);
    SmallVector<Value> dynamicLinearizedSize;
    if (!newResultType.hasStaticShape()) {
      dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp(
          rewriter, loc, linearizedMemRefInfo.linearizedSize));
    }

    rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
                                      adaptor.getSymbolOperands(),
                                      adaptor.getAlignmentAttr());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefAssumeAlignment
//===----------------------------------------------------------------------===//

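/// Rewrites memref.assume_alignment so that it operates on the converted
/// (wider element type) memref, keeping the original alignment attribute.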
struct ConvertMemRefAssumeAlignment final
    : OpConversionPattern<memref::AssumeAlignmentOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getMemref().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
        op, adaptor.getMemref(), adaptor.getAlignmentAttr());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefLoad
//===----------------------------------------------------------------------===//

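/// Converts memref.load on a narrow element type into a load of the wider
/// emulated type followed by a right shift and a mask (or a truncation) that
/// extracts the requested sub-element.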
struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    auto convertedElementType = convertedType.getElementType();
    auto oldElementType = op.getMemRefType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    // Special case 0-rank memref loads.
    Value bitsLoad;
    if (convertedType.getRank() == 0) {
      bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
                                                 ValueRange{});
    } else {
      // Linearize the indices of the original load instruction. Do not account
      // for the scaling yet. This will be accounted for later.
      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());

      Value newLoad = rewriter.create<memref::LoadOp>(
          loc, adaptor.getMemref(),
          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
                                   dstBits));

      // Get the offset and shift the bits to the rightmost position.
      // Note that currently only big-endian is supported.
      Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
                                                  srcBits, dstBits, rewriter);
      bitsLoad = rewriter.create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
    }

    // Get the corresponding bits. If the arith computation bitwidth equals
    // the emulated bitwidth, we apply a mask to extract the low bits.
    // It is not clear if this case actually happens in practice, but we keep
    // the operations just in case. Otherwise, if the arith computation
    // bitwidth is different from the emulated bitwidth, we truncate the
    // result.
    Operation *result;
    auto resultTy = getTypeConverter()->convertType(oldElementType);
    if (resultTy == convertedElementType) {
      auto mask = rewriter.create<arith::ConstantOp>(
          loc, convertedElementType,
          rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));

      result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
    } else {
      result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
    }

    rewriter.replaceOp(op, result->getResult(0));
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefReinterpretCast
//===----------------------------------------------------------------------===//

/// Output types should be at most one-dimensional, so only the 0-D and 1-D
/// cases are supported.
struct ConvertMemRefReinterpretCast final
    : OpConversionPattern<memref::ReinterpretCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Only support 0- or 1-dimensional cases.
    if (op.getType().getRank() > 1) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), "subview with rank > 1 is not supported");
    }

    return convertCastingOp(rewriter, adaptor, op, newTy);
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemrefStore
//===----------------------------------------------------------------------===//

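/// Converts memref.store on a narrow element type into a pair of atomic
/// read-modify-write operations on the wider emulated type: an `andi` that
/// clears the destination bits, followed by an `ori` that writes the shifted
/// source bits. Rank-0 memrefs are handled with a plain atomic assign.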
struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    int srcBits = op.getMemRefType().getElementTypeBitWidth();
    int dstBits = convertedType.getElementTypeBitWidth();
    auto dstIntegerType = rewriter.getIntegerType(dstBits);
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
                                                          adaptor.getValue());

    // Special case 0-rank memref stores. No need for masking.
    if (convertedType.getRank() == 0) {
      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::assign,
                                           extendedInput, adaptor.getMemref(),
                                           ValueRange{});
      rewriter.eraseOp(op);
      return success();
    }

    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
    Value storeIndices = getIndicesForLoadOrStore(
        rewriter, loc, linearizedIndices, srcBits, dstBits);
    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
                                                dstBits, rewriter);
    Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
                                          dstBits, bitwidthOffset, rewriter);
    // Align the value to write with the destination bits.
    Value alignedVal =
        rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);

    // Clear the destination bits.
    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
                                         writeMask, adaptor.getMemref(),
                                         storeIndices);
    // Write the source bits to the destination.
    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
                                         alignedVal, adaptor.getMemref(),
                                         storeIndices);
    rewriter.eraseOp(op);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//

/// Emulating narrow ints on subviews has limited support: only static offsets
/// and sizes and a stride of 1 are supported. Ideally, the subview should be
/// folded away before running narrow type emulation, in which case this
/// pattern would never run. It is mostly used for testing purposes.
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Only support offset for 1-D subview.
    if (op.getType().getRank() != 1) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), "subview with rank > 1 is not supported");
    }

    return convertCastingOp(rewriter, adaptor, op, newTy);
  }
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
// Public Interface Definition
//===----------------------------------------------------------------------===//

void memref::populateMemRefNarrowTypeEmulationPatterns(
    arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {

  // Populate `memref.*` conversion patterns.
  patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
               ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
               ConvertMemrefStore, ConvertMemRefAssumeAlignment,
               ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
      typeConverter, patterns.getContext());
  memref::populateResolveExtractStridedMetadataPatterns(patterns);
}

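/// Returns the linearized (rank-1) shape of `ty` expressed in units of the
/// wider `dstBits` element type, i.e. the product of the static sizes divided
/// by `dstBits / srcBits`, rounded up. Any dynamic dimension makes the whole
/// linearized shape dynamic.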
static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
                                               int dstBits) {
  if (ty.getRank() == 0)
    return {};

  int64_t linearizedShape = 1;
  for (auto shape : ty.getShape()) {
    if (shape == ShapedType::kDynamic)
      return {ShapedType::kDynamic};
    linearizedShape *= shape;
  }
  int scale = dstBits / srcBits;
  // Scale the size to ceilDiv(linearizedShape, scale) so that all the values
  // can be accommodated.
  linearizedShape = (linearizedShape + scale - 1) / scale;
  return {linearizedShape};
}

void memref::populateMemRefNarrowTypeEmulationConversions(
    arith::NarrowTypeEmulationConverter &typeConverter) {
  typeConverter.addConversion(
      [&typeConverter](MemRefType ty) -> std::optional<Type> {
        auto intTy = dyn_cast<IntegerType>(ty.getElementType());
        if (!intTy)
          return ty;

        unsigned width = intTy.getWidth();
        unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
        if (width >= loadStoreWidth)
          return ty;

        // Currently only handle the case where the innermost stride is 1.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(ty, strides, offset)))
          return std::nullopt;
        if (!strides.empty() && strides.back() != 1)
          return std::nullopt;

        auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
                                          intTy.getSignedness());
        if (!newElemTy)
          return std::nullopt;

        StridedLayoutAttr layoutAttr;
        // If the offset is 0 we do not need a strided layout (the stride is
        // 1), so a strided layout is only used when the offset is non-zero.
        if (offset != 0) {
          if (offset == ShapedType::kDynamic) {
            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          } else {
            // Check that the offset, in bits, is a multiple of the
            // loadStoreWidth, and if so divide it by the loadStoreWidth to get
            // the new offset.
            if ((offset * width) % loadStoreWidth != 0)
              return std::nullopt;
            offset = (offset * width) / loadStoreWidth;

            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          }
        }

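        // Illustrative example: `memref<3x4xi4, strided<[4, 1], offset: 8>>`
        // with a load/store width of 8 becomes
        // `memref<6xi8, strided<[1], offset: 4>>`.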
        return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
                               newElemTy, layoutAttr, ty.getMemorySpace());
      });
}

source code of mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp