//===- EmulateNarrowType.cpp - Narrow type emulation ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <type_traits>

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Converts a memref::SubViewOp or memref::ReinterpretCastOp to the converted
/// type. The result MemRefType of the old op must have rank 1, a stride of 1,
/// and a static offset and size. The offset, expressed in bits of the source
/// element type, must be a multiple of the converted element type's bitwidth.
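/// For example, when emulating i4 with i8 (so dstBits / srcBits == 2), an op
/// with offset 8 and size 6, both counted in i4 elements, is rewritten to use
/// offset 4 and size 3 in i8 elements.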
template <typename MemRefOpTy>
static LogicalResult convertCastingOp(ConversionPatternRewriter &rewriter,
                                      typename MemRefOpTy::Adaptor adaptor,
                                      MemRefOpTy op, MemRefType newTy) {
  static_assert(std::is_same<MemRefOpTy, memref::SubViewOp>() ||
                    std::is_same<MemRefOpTy, memref::ReinterpretCastOp>(),
                "Expected only memref::SubViewOp or memref::ReinterpretCastOp");

  auto convertedElementType = newTy.getElementType();
  auto oldElementType = op.getType().getElementType();
  int srcBits = oldElementType.getIntOrFloatBitWidth();
  int dstBits = convertedElementType.getIntOrFloatBitWidth();
  if (dstBits % srcBits != 0) {
    return rewriter.notifyMatchFailure(op,
                                       "only dstBits % srcBits == 0 supported");
  }

  // Only support stride of 1.
  if (llvm::any_of(op.getStaticStrides(),
                   [](int64_t stride) { return stride != 1; })) {
    return rewriter.notifyMatchFailure(op->getLoc(),
                                       "stride != 1 is not supported");
  }

  auto sizes = op.getStaticSizes();
  int64_t offset = op.getStaticOffset(0);
  // Only support static sizes and offsets.
  if (llvm::any_of(sizes,
                   [](int64_t size) { return size == ShapedType::kDynamic; }) ||
      offset == ShapedType::kDynamic) {
    return rewriter.notifyMatchFailure(
        op->getLoc(), "dynamic size or offset is not supported");
  }

  int elementsPerByte = dstBits / srcBits;
  if (offset % elementsPerByte != 0) {
    return rewriter.notifyMatchFailure(
        op->getLoc(), "offset not multiple of elementsPerByte is not "
                      "supported");
  }

  SmallVector<int64_t> size;
  if (sizes.size())
    size.push_back(ceilDiv(sizes[0], elementsPerByte));
  offset = offset / elementsPerByte;

  rewriter.replaceOpWithNewOp<MemRefOpTy>(op, newTy,
                                          *adaptor.getODSOperands(0).begin(),
                                          offset, size, op.getStaticStrides());
  return success();
}

/// When data is loaded/stored in `targetBits` granularity but used in
/// `sourceBits` granularity (`sourceBits` < `targetBits`), each `targetBits`
/// unit is treated as an array of elements of width `sourceBits`.
/// Returns the bit offset of the value at position `srcIdx`. For example, if
/// `sourceBits` is 4 and `targetBits` is 8, the x-th element starts at bit
/// offset (x % 2) * 4, because one i8 holds two 4-bit elements.
static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
                                  int sourceBits, int targetBits,
                                  OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int scaleFactor = targetBits / sourceBits;
  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
  OpFoldResult offsetVal =
      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
  Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
  IntegerType dstType = builder.getIntegerType(targetBits);
  return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
}

/// When writing a sub-byte value, masked bitwise operations are used so that
/// only the relevant bits are modified. This function returns the AND mask
/// that clears the destination bits in a sub-byte write. E.g., when writing to
/// the second i4 in an i32, the mask 0xFFFFFF0F is created.
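/// The mask is built in three steps: a right-aligned mask of srcBits ones,
/// e.g. 0xF for i4, a left shift by the bit offset (0xF0 for the second i4),
/// and an XOR with -1 that inverts it into 0xFFFFFF0F.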
static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
                                 int64_t srcBits, int64_t dstBits,
                                 Value bitwidthOffset, OpBuilder &builder) {
  auto dstIntegerType = builder.getIntegerType(dstBits);
  auto maskRightAlignedAttr =
      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
  Value maskRightAligned = builder.create<arith::ConstantOp>(
      loc, dstIntegerType, maskRightAlignedAttr);
  Value writeMaskInverse =
      builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
  Value flipVal =
      builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr);
  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
}

/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
/// the returned index has the granularity of `dstBits`.
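/// For example, with srcBits == 4 and dstBits == 8, linearized source index 5
/// maps to destination index 5 floordiv 2 == 2.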
static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
                                      OpFoldResult linearizedIndex,
                                      int64_t srcBits, int64_t dstBits) {
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int64_t scaler = dstBits / srcBits;
  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
  return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
}

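/// Linearizes `indices` into a single index in `srcBits` granularity, using
/// the offset, sizes, and strides taken from `memref`'s strided metadata.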
static OpFoldResult
getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
                        const SmallVector<OpFoldResult> &indices,
                        Value memref) {
  auto stridedMetadata =
      builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
  OpFoldResult linearizedIndices;
  std::tie(std::ignore, linearizedIndices) =
      memref::getLinearizedMemRefOffsetAndSize(
          builder, loc, srcBits, srcBits,
          stridedMetadata.getConstifiedMixedOffset(),
          stridedMetadata.getConstifiedMixedSizes(),
          stridedMetadata.getConstifiedMixedStrides(), indices);
  return linearizedIndices;
}

namespace {

//===----------------------------------------------------------------------===//
// ConvertMemRefAllocation
//===----------------------------------------------------------------------===//

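/// Converts memref.alloc/memref.alloca to an allocation of a flat buffer of
/// the wider element type; e.g., when emulating i4 with i8, memref<8x8xi4>
/// becomes memref<32xi8>.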
template <typename OpTy>
struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    static_assert(std::is_same<OpTy, memref::AllocOp>() ||
                      std::is_same<OpTy, memref::AllocaOp>(),
                  "expected only memref::AllocOp or memref::AllocaOp");
    auto currentType = cast<MemRefType>(op.getMemref().getType());
    auto newResultType = dyn_cast<MemRefType>(
        this->getTypeConverter()->convertType(op.getType()));
    if (!newResultType) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Special case zero-rank memrefs.
    if (currentType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
                                        adaptor.getSymbolOperands(),
                                        adaptor.getAlignmentAttr());
      return success();
    }

    Location loc = op.getLoc();
    OpFoldResult zero = rewriter.getIndexAttr(0);
    SmallVector<OpFoldResult> indices(currentType.getRank(), zero);

    // Get linearized type.
    int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
    int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
    SmallVector<OpFoldResult> sizes = op.getMixedSizes();

    memref::LinearizedMemRefInfo linearizedMemRefInfo =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits, /*offset=*/zero, sizes);
    SmallVector<Value> dynamicLinearizedSize;
    if (!newResultType.hasStaticShape()) {
      dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp(
          rewriter, loc, linearizedMemRefInfo.linearizedSize));
    }

    rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
                                      adaptor.getSymbolOperands(),
                                      adaptor.getAlignmentAttr());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefAssumeAlignment
//===----------------------------------------------------------------------===//

struct ConvertMemRefAssumeAlignment final
    : OpConversionPattern<memref::AssumeAlignmentOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getMemref().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
        op, adaptor.getMemref(), adaptor.getAlignmentAttr());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefLoad
//===----------------------------------------------------------------------===//

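/// Emulates a sub-byte load by loading the containing wider element, shifting
/// the requested bits down to the least significant position, and masking or
/// truncating the result back to the original element type.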
struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    auto convertedElementType = convertedType.getElementType();
    auto oldElementType = op.getMemRefType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    // Special case 0-rank memref loads.
    Value bitsLoad;
    if (convertedType.getRank() == 0) {
      bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
                                                 ValueRange{});
    } else {
      // Linearize the indices of the original load instruction. Do not account
      // for the scaling yet; this is done below.
      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());

      Value newLoad = rewriter.create<memref::LoadOp>(
          loc, adaptor.getMemref(),
          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
                                   dstBits));

      // Get the offset and shift the bits to the rightmost position.
      // Note that currently only big-endian is supported.
      Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
                                                  srcBits, dstBits, rewriter);
      bitsLoad = rewriter.create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
    }

    // Get the corresponding bits. If the arith computation bitwidth equals
    // the emulated bitwidth, apply a mask to extract the low bits. It is not
    // clear whether this case actually happens in practice, but the operations
    // are kept just in case. Otherwise, if the arith computation bitwidth
    // differs from the emulated bitwidth, truncate the result.
    Operation *result;
    auto resultTy = getTypeConverter()->convertType(oldElementType);
    if (resultTy == convertedElementType) {
      auto mask = rewriter.create<arith::ConstantOp>(
          loc, convertedElementType,
          rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));

      result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
    } else {
      result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
    }

    rewriter.replaceOp(op, result->getResult(0));
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefReinterpretCast
//===----------------------------------------------------------------------===//

/// Converted output types must be at most one-dimensional, so only the 0-D
/// and 1-D cases are supported.
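/// For example, when emulating i4 with i8, a reinterpret_cast producing
/// memref<6xi4> is rewritten to one producing memref<3xi8> over the converted
/// source buffer.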
struct ConvertMemRefReinterpretCast final
    : OpConversionPattern<memref::ReinterpretCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Only support 0- or 1-dimensional cases.
    if (op.getType().getRank() > 1) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), "reinterpret_cast with rank > 1 is not supported");
    }

    return convertCastingOp(rewriter, adaptor, op, newTy);
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemrefStore
//===----------------------------------------------------------------------===//

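/// Emulates a sub-byte store with two read-modify-write operations on the
/// wider type: an atomic `andi` with the write mask first clears the
/// destination bits, then an atomic `ori` merges in the shifted value, so that
/// stores to neighboring sub-byte elements in the same word do not clobber
/// each other.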
struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    int srcBits = op.getMemRefType().getElementTypeBitWidth();
    int dstBits = convertedType.getElementTypeBitWidth();
    auto dstIntegerType = rewriter.getIntegerType(dstBits);
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
                                                          adaptor.getValue());

    // Special case 0-rank memref stores. No need for masking.
    if (convertedType.getRank() == 0) {
      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::assign,
                                           extendedInput, adaptor.getMemref(),
                                           ValueRange{});
      rewriter.eraseOp(op);
      return success();
    }

    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
    Value storeIndices = getIndicesForLoadOrStore(rewriter, loc,
                                                  linearizedIndices, srcBits,
                                                  dstBits);
    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
                                                dstBits, rewriter);
    Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
                                          dstBits, bitwidthOffset, rewriter);
    // Align the value to write with the destination bits.
    Value alignedVal =
        rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);

    // Clear the destination bits.
    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
                                         writeMask, adaptor.getMemref(),
                                         storeIndices);
    // Write the source bits to the destination.
    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
                                         alignedVal, adaptor.getMemref(),
                                         storeIndices);
    rewriter.eraseOp(op);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//

/// Narrow type emulation for subviews has limited support: only static offsets
/// and sizes and a stride of 1 are handled. Ideally, the subview should be
/// folded away before running narrow type emulation, and this pattern should
/// never run. It is mostly used for testing purposes.
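/// For example, when emulating i4 with i8, a subview of memref<32xi4> with
/// offset 8 and size 16 becomes a subview of memref<16xi8> with offset 4 and
/// size 8.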
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Only 1-D subviews are supported.
    if (op.getType().getRank() != 1) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), "only rank-1 subviews are supported");
    }

    return convertCastingOp(rewriter, adaptor, op, newTy);
  }
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
// Public Interface Definition
//===----------------------------------------------------------------------===//

void memref::populateMemRefNarrowTypeEmulationPatterns(
    arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {

  // Populate `memref.*` conversion patterns.
  patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
               ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
               ConvertMemrefStore, ConvertMemRefAssumeAlignment,
               ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
      typeConverter, patterns.getContext());
  memref::populateResolveExtractStridedMetadataPatterns(patterns);
}

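/// Returns the flattened shape of the emulated memref: the product of the
/// static sizes scaled down by dstBits / srcBits, rounded up. A rank-0 type
/// yields an empty shape, and any dynamic dimension collapses the result to a
/// single dynamic dimension.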
static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
                                               int dstBits) {
  if (ty.getRank() == 0)
    return {};

  int64_t linearizedShape = 1;
  for (auto shape : ty.getShape()) {
    if (shape == ShapedType::kDynamic)
      return {ShapedType::kDynamic};
    linearizedShape *= shape;
  }
  int scale = dstBits / srcBits;
  // Scale the size to ceilDiv(linearizedShape, scale) so that all the values
  // fit.
  linearizedShape = (linearizedShape + scale - 1) / scale;
  return {linearizedShape};
}

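/// With a load/store bitwidth of 8, the conversion registered below turns,
/// for example, memref<3x4xi4, strided<[4, 1], offset: 16>> into
/// memref<6xi8, strided<[1], offset: 8>>.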
void memref::populateMemRefNarrowTypeEmulationConversions(
    arith::NarrowTypeEmulationConverter &typeConverter) {
  typeConverter.addConversion(
      [&typeConverter](MemRefType ty) -> std::optional<Type> {
        auto intTy = dyn_cast<IntegerType>(ty.getElementType());
        if (!intTy)
          return ty;

        unsigned width = intTy.getWidth();
        unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
        if (width >= loadStoreWidth)
          return ty;

        // Currently only handle memrefs whose innermost stride is 1.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(getStridesAndOffset(ty, strides, offset)))
          return std::nullopt;
        if (!strides.empty() && strides.back() != 1)
          return std::nullopt;

        auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
                                          intTy.getSignedness());
        if (!newElemTy)
          return std::nullopt;

        StridedLayoutAttr layoutAttr;
        // If the offset is 0, no strided layout is needed since the stride is
        // 1, so only use a strided layout when the offset is non-zero.
        if (offset != 0) {
          if (offset == ShapedType::kDynamic) {
            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          } else {
            // Check that the offset in bits is a multiple of loadStoreWidth,
            // and if so, rescale it to loadStoreWidth granularity.
            if ((offset * width) % loadStoreWidth != 0)
              return std::nullopt;
            offset = (offset * width) / loadStoreWidth;

            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          }
        }

        return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
                               newElemTy, layoutAttr, ty.getMemorySpace());
      });
}