//===- EmulateNarrowType.cpp - Narrow type emulation ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <type_traits>

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Converts a memref::ReinterpretCastOp to the converted type. The result
/// MemRefType of the old op must have a rank and stride of 1, with static
/// offset and size. The number of bits in the offset must evenly divide the
/// bitwidth of the new converted type.
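/// For example, with srcBits = 4 and dstBits = 8 (two i4 elements per byte), a
/// cast producing `memref<6xi4>` at offset 4 is rewritten to produce
/// `memref<3xi8>` at offset 2.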
static LogicalResult
convertCastingOp(ConversionPatternRewriter &rewriter,
                 memref::ReinterpretCastOp::Adaptor adaptor,
                 memref::ReinterpretCastOp op, MemRefType newTy) {
  auto convertedElementType = newTy.getElementType();
  auto oldElementType = op.getType().getElementType();
  int srcBits = oldElementType.getIntOrFloatBitWidth();
  int dstBits = convertedElementType.getIntOrFloatBitWidth();
  if (dstBits % srcBits != 0) {
    return rewriter.notifyMatchFailure(op,
                                       "only dstBits % srcBits == 0 supported");
  }

  // Only support stride of 1.
  if (llvm::any_of(op.getStaticStrides(),
                   [](int64_t stride) { return stride != 1; })) {
    return rewriter.notifyMatchFailure(op->getLoc(),
                                       "stride != 1 is not supported");
  }

  auto sizes = op.getStaticSizes();
  int64_t offset = op.getStaticOffset(0);
  // Only support static sizes and offsets.
  if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
      offset == ShapedType::kDynamic) {
    return rewriter.notifyMatchFailure(
        op, "dynamic size or offset is not supported");
  }

  int elementsPerByte = dstBits / srcBits;
  if (offset % elementsPerByte != 0) {
    return rewriter.notifyMatchFailure(
        op, "offset not multiple of elementsPerByte is not supported");
  }

  SmallVector<int64_t> size;
  if (sizes.size())
    size.push_back(llvm::divideCeilSigned(sizes[0], elementsPerByte));
  offset = offset / elementsPerByte;

  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
      op, newTy, adaptor.getSource(), offset, size, op.getStaticStrides());
  return success();
}

/// When data is loaded/stored in `targetBits` granularity, but is used in
/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits`
/// word is treated as an array of elements of width `sourceBits`.
/// Returns the bit offset of the value at position `srcIdx`. For example, if
/// `sourceBits` equals 4 and `targetBits` equals 8, the x-th element is
/// located at (x % 2) * 4, because one i8 holds two elements of 4 bits each.
static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
                                  int sourceBits, int targetBits,
                                  OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int scaleFactor = targetBits / sourceBits;
  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
  OpFoldResult offsetVal =
      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
  Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
  IntegerType dstType = builder.getIntegerType(targetBits);
  return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
}

/// When writing a subbyte size, masked bitwise operations are used to only
/// modify the relevant bits. This function returns an and mask for clearing
/// the destination bits in a subbyte write. E.g., when writing to the second
/// i4 in an i32, 0xFFFFFF0F is created.
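/// The mask is built by shifting a right-aligned mask of `srcBits` ones to the
/// element's bit offset and then flipping every bit, e.g. for srcBits = 4,
/// dstBits = 32 and a bit offset of 4: 0x0000000F << 4 = 0x000000F0, and
/// 0x000000F0 ^ 0xFFFFFFFF = 0xFFFFFF0F.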
static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
                                 int64_t srcBits, int64_t dstBits,
                                 Value bitwidthOffset, OpBuilder &builder) {
  auto dstIntegerType = builder.getIntegerType(dstBits);
  auto maskRightAlignedAttr =
      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
  Value maskRightAligned = builder.create<arith::ConstantOp>(
      loc, dstIntegerType, maskRightAlignedAttr);
  Value writeMaskInverse =
      builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
  Value flipVal =
      builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr);
  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
}

/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
/// the returned index has the granularity of `dstBits`.
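/// E.g. with srcBits = 4 and dstBits = 8, element index 5 maps to load/store
/// index 5 floordiv 2 = 2.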
static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
                                      OpFoldResult linearizedIndex,
                                      int64_t srcBits, int64_t dstBits) {
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int64_t scaler = dstBits / srcBits;
  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
  return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
}

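/// Flattens the multi-dimensional `indices` of the original access into a
/// single linear index in `srcBits` granularity, using the strided metadata of
/// `memref`. Scaling to the wider load/store granularity is done separately by
/// `getIndicesForLoadOrStore`.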
static OpFoldResult
getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
                        const SmallVector<OpFoldResult> &indices,
                        Value memref) {
  auto stridedMetadata =
      builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
  OpFoldResult linearizedIndices;
  std::tie(std::ignore, linearizedIndices) =
      memref::getLinearizedMemRefOffsetAndSize(
          builder, loc, srcBits, srcBits,
          stridedMetadata.getConstifiedMixedOffset(),
          stridedMetadata.getConstifiedMixedSizes(),
          stridedMetadata.getConstifiedMixedStrides(), indices);
  return linearizedIndices;
}

namespace {

//===----------------------------------------------------------------------===//
// ConvertMemRefAllocation
//===----------------------------------------------------------------------===//

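/// Converts `memref.alloc`/`memref.alloca` of a sub-byte element type into an
/// allocation of the flattened, wider type. E.g. with a load/store width of 8,
/// an allocation of `memref<3x5xi4>` becomes an allocation of `memref<8xi8>`
/// (15 i4 elements rounded up to 8 bytes).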
template <typename OpTy>
struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    static_assert(std::is_same<OpTy, memref::AllocOp>() ||
                      std::is_same<OpTy, memref::AllocaOp>(),
                  "expected only memref::AllocOp or memref::AllocaOp");
    auto currentType = cast<MemRefType>(op.getMemref().getType());
    auto newResultType =
        this->getTypeConverter()->template convertType<MemRefType>(
            op.getType());
    if (!newResultType) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Special case zero-rank memrefs.
    if (currentType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
                                        adaptor.getSymbolOperands(),
                                        adaptor.getAlignmentAttr());
      return success();
    }

    Location loc = op.getLoc();
    OpFoldResult zero = rewriter.getIndexAttr(0);

    // Get linearized type.
    int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
    int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
    SmallVector<OpFoldResult> sizes = op.getMixedSizes();

    memref::LinearizedMemRefInfo linearizedMemRefInfo =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits, /*offset =*/zero, sizes);
    SmallVector<Value> dynamicLinearizedSize;
    if (!newResultType.hasStaticShape()) {
      dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp(
          rewriter, loc, linearizedMemRefInfo.linearizedSize));
    }

    rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
                                      adaptor.getSymbolOperands(),
                                      adaptor.getAlignmentAttr());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefAssumeAlignment
//===----------------------------------------------------------------------===//

struct ConvertMemRefAssumeAlignment final
    : OpConversionPattern<memref::AssumeAlignmentOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getMemref().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
        op, newTy, adaptor.getMemref(), adaptor.getAlignmentAttr());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefCopy
//===----------------------------------------------------------------------===//

struct ConvertMemRefCopy final : OpConversionPattern<memref::CopyOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto maybeRankedSource = dyn_cast<MemRefType>(op.getSource().getType());
    auto maybeRankedDest = dyn_cast<MemRefType>(op.getTarget().getType());
    if (maybeRankedSource && maybeRankedDest &&
        maybeRankedSource.getLayout() != maybeRankedDest.getLayout())
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("memref.copy emulation with distinct layouts ({0} "
                            "and {1}) is currently unimplemented",
                            maybeRankedSource.getLayout(),
                            maybeRankedDest.getLayout()));
    rewriter.replaceOpWithNewOp<memref::CopyOp>(op, adaptor.getSource(),
                                                adaptor.getTarget());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefDealloc
//===----------------------------------------------------------------------===//

struct ConvertMemRefDealloc final : OpConversionPattern<memref::DeallocOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<memref::DeallocOp>(op, adaptor.getMemref());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefLoad
//===----------------------------------------------------------------------===//

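/// Converts `memref.load` of a sub-byte element type into a load of the
/// containing wider element, followed by a right shift and a mask or
/// truncation to extract the requested bits. E.g. loading i4 element 3 from a
/// memref emulated at i8 granularity loads byte 1 and shifts right by 4 bits.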
struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    auto convertedElementType = convertedType.getElementType();
    auto oldElementType = op.getMemRefType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    // Special case 0-rank memref loads.
    Value bitsLoad;
    if (convertedType.getRank() == 0) {
      bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
                                                 ValueRange{});
    } else {
      // Linearize the indices of the original load instruction. Do not account
      // for the scaling yet. This will be accounted for later.
      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());

      Value newLoad = rewriter.create<memref::LoadOp>(
          loc, adaptor.getMemref(),
          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
                                   dstBits));

      // Get the offset and shift the bits to the rightmost.
      // Note, currently only big-endian is supported.
      Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
                                                  srcBits, dstBits, rewriter);
      bitsLoad = rewriter.create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
    }

    // Get the corresponding bits. If the arith computation bitwidth equals
    // the emulated bitwidth, we apply a mask to extract the low bits. It is
    // not clear if this case actually happens in practice, but we keep the
    // operations just in case. Otherwise, if the arith computation bitwidth
    // is different from the emulated bitwidth, we truncate the result.
    Operation *result;
    auto resultTy = getTypeConverter()->convertType(oldElementType);
    if (resultTy == convertedElementType) {
      auto mask = rewriter.create<arith::ConstantOp>(
          loc, convertedElementType,
          rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));

      result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
    } else {
      result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
    }

    rewriter.replaceOp(op, result->getResult(0));
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefMemorySpaceCast
//===----------------------------------------------------------------------===//

struct ConvertMemRefMemorySpaceCast final
    : OpConversionPattern<memref::MemorySpaceCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getDest().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getDest().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::MemorySpaceCastOp>(op, newTy,
                                                           adaptor.getSource());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefReinterpretCast
//===----------------------------------------------------------------------===//

/// Output types should be at most one dimensional, so only the 0- or
/// 1-dimensional cases are supported.
struct ConvertMemRefReinterpretCast final
    : OpConversionPattern<memref::ReinterpretCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        getTypeConverter()->convertType<MemRefType>(op.getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Only support 0- or 1-dimensional cases.
    if (op.getType().getRank() > 1) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), "subview with rank > 1 is not supported");
    }

    return convertCastingOp(rewriter, adaptor, op, newTy);
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemrefStore
//===----------------------------------------------------------------------===//

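/// Converts `memref.store` of a sub-byte element into a read-modify-write on
/// the wider emulated element: the destination bits are cleared with an atomic
/// `andi` using the mask from `getSubByteWriteMask`, and the zero-extended,
/// shifted value is then merged in with an atomic `ori`. Atomics are used
/// because stores to different sub-byte elements may touch the same emulated
/// word.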
struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    int srcBits = op.getMemRefType().getElementTypeBitWidth();
    int dstBits = convertedType.getElementTypeBitWidth();
    auto dstIntegerType = rewriter.getIntegerType(dstBits);
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
                                                          adaptor.getValue());

    // Special case 0-rank memref stores. No need for masking.
    if (convertedType.getRank() == 0) {
      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::assign,
                                           extendedInput, adaptor.getMemref(),
                                           ValueRange{});
      rewriter.eraseOp(op);
      return success();
    }

    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
    Value storeIndices = getIndicesForLoadOrStore(
        rewriter, loc, linearizedIndices, srcBits, dstBits);
    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
                                                dstBits, rewriter);
    Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
                                          dstBits, bitwidthOffset, rewriter);
    // Align the value to write with the destination bits.
    Value alignedVal =
        rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);

    // Clear the destination bits.
    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
                                         writeMask, adaptor.getMemref(),
                                         storeIndices);
    // Write the source bits to the destination.
    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
                                         alignedVal, adaptor.getMemref(),
                                         storeIndices);
    rewriter.eraseOp(op);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//

/// Emulating narrow ints on subview has limited support: only static offsets
/// and sizes and a stride of 1 are supported. Ideally, the subview should be
/// folded away before running narrow type emulation, and this pattern should
/// only run for cases that can't be folded.
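/// The subview is rewritten on the flattened emulated memref, with its offset
/// and size rescaled from `srcBits` to `dstBits` granularity (for example, a
/// subview of 8 i4 elements at offset 6 becomes 4 i8 elements at offset 3).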
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        getTypeConverter()->convertType<MemRefType>(subViewOp.getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(),
          llvm::formatv("failed to convert memref type: {0}",
                        subViewOp.getType()));
    }

    Location loc = subViewOp.getLoc();
    Type convertedElementType = newTy.getElementType();
    Type oldElementType = subViewOp.getType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0)
      return rewriter.notifyMatchFailure(
          subViewOp, "only dstBits % srcBits == 0 supported");

    // Only support stride of 1.
    if (llvm::any_of(subViewOp.getStaticStrides(),
                     [](int64_t stride) { return stride != 1; })) {
      return rewriter.notifyMatchFailure(subViewOp->getLoc(),
                                         "stride != 1 is not supported");
    }

    if (!memref::isStaticShapeAndContiguousRowMajor(subViewOp.getType())) {
      return rewriter.notifyMatchFailure(
          subViewOp, "the result memref type is not contiguous");
    }

    auto sizes = subViewOp.getStaticSizes();
    int64_t lastOffset = subViewOp.getStaticOffsets().back();
    // Only support static sizes and offsets.
    if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
        lastOffset == ShapedType::kDynamic) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(), "dynamic size or offset is not supported");
    }

    // Transform the offsets, sizes and strides according to the emulation.
    auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
        loc, subViewOp.getViewSource());

    OpFoldResult linearizedIndices;
    auto strides = stridedMetadata.getConstifiedMixedStrides();
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            subViewOp.getMixedSizes(), strides,
            getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
                           rewriter));

    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subViewOp, newTy, adaptor.getSource(), linearizedIndices,
        linearizedInfo.linearizedSize, strides.back());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefCollapseShape
//===----------------------------------------------------------------------===//

/// Emulating a `memref.collapse_shape` becomes a no-op after emulation given
/// that we flatten memrefs to a single dimension as part of the emulation and
/// there is no dimension to collapse any further.
struct ConvertMemRefCollapseShape final
    : OpConversionPattern<memref::CollapseShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value srcVal = adaptor.getSrc();
    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
    if (!newTy)
      return failure();

    if (newTy.getRank() != 1)
      return failure();

    rewriter.replaceOp(collapseShapeOp, srcVal);
    return success();
  }
};

/// Emulating a `memref.expand_shape` becomes a no-op after emulation given
/// that we flatten memrefs to a single dimension as part of the emulation and
/// the expansion would just have been undone.
struct ConvertMemRefExpandShape final
    : OpConversionPattern<memref::ExpandShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ExpandShapeOp expandShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value srcVal = adaptor.getSrc();
    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
    if (!newTy)
      return failure();

    if (newTy.getRank() != 1)
      return failure();

    rewriter.replaceOp(expandShapeOp, srcVal);
    return success();
  }
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
// Public Interface Definition
//===----------------------------------------------------------------------===//

void memref::populateMemRefNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {

  // Populate `memref.*` conversion patterns.
  patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
               ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCopy,
               ConvertMemRefDealloc, ConvertMemRefCollapseShape,
               ConvertMemRefExpandShape, ConvertMemRefLoad, ConvertMemrefStore,
               ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast,
               ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
      typeConverter, patterns.getContext());
  memref::populateResolveExtractStridedMetadataPatterns(patterns);
}

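/// Returns the shape of the flattened (1-D) emulated memref for `ty`: the
/// product of the static dimensions, rounded up to a multiple of the scale
/// factor `dstBits / srcBits` and divided by it. Any dynamic dimension yields
/// a single dynamic dimension; a rank-0 type stays rank-0. E.g. memref<3x5xi4>
/// with dstBits = 8 linearizes to a shape of {8}.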
static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
                                               int dstBits) {
  if (ty.getRank() == 0)
    return {};

  int64_t linearizedShape = 1;
  for (auto shape : ty.getShape()) {
    if (shape == ShapedType::kDynamic)
      return {ShapedType::kDynamic};
    linearizedShape *= shape;
  }
  int scale = dstBits / srcBits;
  // Scale the size to ceilDiv(linearizedShape, scale) to accommodate all the
  // values.
  linearizedShape = (linearizedShape + scale - 1) / scale;
  return {linearizedShape};
}

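/// Registers the memref type conversion: a memref of sub-byte integers becomes
/// a flattened memref whose element type has the load/store bitwidth. E.g.
/// with a load/store width of 8, `memref<8xi4, strided<[1], offset: 2>>`
/// converts to `memref<4xi8, strided<[1], offset: 1>>`.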
void memref::populateMemRefNarrowTypeEmulationConversions(
    arith::NarrowTypeEmulationConverter &typeConverter) {
  typeConverter.addConversion(
      [&typeConverter](MemRefType ty) -> std::optional<Type> {
        auto intTy = dyn_cast<IntegerType>(ty.getElementType());
        if (!intTy)
          return ty;

        unsigned width = intTy.getWidth();
        unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
        if (width >= loadStoreWidth)
          return ty;

        // Currently only handle the case where the innermost stride is 1.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(ty.getStridesAndOffset(strides, offset)))
          return nullptr;
        if (!strides.empty() && strides.back() != 1)
          return nullptr;

        auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
                                          intTy.getSignedness());
        if (!newElemTy)
          return nullptr;

        StridedLayoutAttr layoutAttr;
        // If the offset is 0, we do not need a strided layout as the stride is
        // 1, so we only use the strided layout if the offset is not 0.
        if (offset != 0) {
          if (offset == ShapedType::kDynamic) {
            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          } else {
            // Check that the offset in bits is a multiple of loadStoreWidth
            // and, if so, divide it by loadStoreWidth to get the new offset.
            if ((offset * width) % loadStoreWidth != 0)
              return std::nullopt;
            offset = (offset * width) / loadStoreWidth;

            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          }
        }

        return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
                               newElemTy, layoutAttr, ty.getMemorySpace());
      });
}