1 | //===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
10 | #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" |
11 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
12 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
13 | #include "mlir/IR/AffineMap.h" |
14 | #include "mlir/IR/BuiltinAttributes.h" |
15 | |
16 | using namespace mlir; |
17 | |
18 | //===----------------------------------------------------------------------===// |
19 | // ConvertToLLVMPattern |
20 | //===----------------------------------------------------------------------===// |
21 | |
22 | ConvertToLLVMPattern::ConvertToLLVMPattern( |
23 | StringRef rootOpName, MLIRContext *context, |
24 | const LLVMTypeConverter &typeConverter, PatternBenefit benefit) |
25 | : ConversionPattern(typeConverter, rootOpName, benefit, context) {} |
26 | |
27 | const LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const { |
28 | return static_cast<const LLVMTypeConverter *>( |
29 | ConversionPattern::getTypeConverter()); |
30 | } |
31 | |
32 | LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const { |
33 | return *getTypeConverter()->getDialect(); |
34 | } |
35 | |
36 | Type ConvertToLLVMPattern::getIndexType() const { |
37 | return getTypeConverter()->getIndexType(); |
38 | } |
39 | |
40 | Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const { |
41 | return IntegerType::get(&getTypeConverter()->getContext(), |
42 | getTypeConverter()->getPointerBitwidth(addressSpace)); |
43 | } |
44 | |
45 | Type ConvertToLLVMPattern::getVoidType() const { |
46 | return LLVM::LLVMVoidType::get(ctx: &getTypeConverter()->getContext()); |
47 | } |
48 | |
49 | Type ConvertToLLVMPattern::getVoidPtrType() const { |
50 | return LLVM::LLVMPointerType::get(&getTypeConverter()->getContext()); |
51 | } |
52 | |
53 | Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder, |
54 | Location loc, |
55 | Type resultType, |
56 | int64_t value) { |
57 | return builder.create<LLVM::ConstantOp>(loc, resultType, |
58 | builder.getIndexAttr(value)); |
59 | } |
60 | |
61 | Value ConvertToLLVMPattern::getStridedElementPtr( |
62 | ConversionPatternRewriter &rewriter, Location loc, MemRefType type, |
63 | Value memRefDesc, ValueRange indices, |
64 | LLVM::GEPNoWrapFlags noWrapFlags) const { |
65 | return LLVM::getStridedElementPtr(rewriter, loc, *getTypeConverter(), type, |
66 | memRefDesc, indices, noWrapFlags); |
67 | } |
68 | |
69 | // Check if the MemRefType `type` is supported by the lowering. We currently |
70 | // only support memrefs with identity maps. |
71 | bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps( |
72 | MemRefType type) const { |
73 | if (!type.getLayout().isIdentity()) |
74 | return false; |
75 | return static_cast<bool>(typeConverter->convertType(type)); |
76 | } |
77 | |
78 | Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const { |
79 | auto addressSpace = getTypeConverter()->getMemRefAddressSpace(type: type); |
80 | if (failed(addressSpace)) |
81 | return {}; |
82 | return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace); |
83 | } |
84 | |
85 | void ConvertToLLVMPattern::getMemRefDescriptorSizes( |
86 | Location loc, MemRefType memRefType, ValueRange dynamicSizes, |
87 | ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes, |
88 | SmallVectorImpl<Value> &strides, Value &size, bool sizeInBytes) const { |
89 | assert(isConvertibleAndHasIdentityMaps(memRefType) && |
90 | "layout maps must have been normalized away"); |
91 | assert(count(memRefType.getShape(), ShapedType::kDynamic) == |
92 | static_cast<ssize_t>(dynamicSizes.size()) && |
93 | "dynamicSizes size doesn't match dynamic sizes count in memref shape"); |
94 | |
95 | sizes.reserve(N: memRefType.getRank()); |
96 | unsigned dynamicIndex = 0; |
97 | Type indexType = getIndexType(); |
98 | for (int64_t size : memRefType.getShape()) { |
99 | sizes.push_back( |
100 | size == ShapedType::kDynamic |
101 | ? dynamicSizes[dynamicIndex++] |
102 | : createIndexAttrConstant(rewriter, loc, indexType, size)); |
103 | } |
104 | |
105 | // Strides: iterate sizes in reverse order and multiply. |
106 | int64_t stride = 1; |
107 | Value runningStride = createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: 1); |
108 | strides.resize(memRefType.getRank()); |
109 | for (auto i = memRefType.getRank(); i-- > 0;) { |
110 | strides[i] = runningStride; |
111 | |
112 | int64_t staticSize = memRefType.getShape()[i]; |
113 | bool useSizeAsStride = stride == 1; |
114 | if (staticSize == ShapedType::kDynamic) |
115 | stride = ShapedType::kDynamic; |
116 | if (stride != ShapedType::kDynamic) |
117 | stride *= staticSize; |
118 | |
119 | if (useSizeAsStride) |
120 | runningStride = sizes[i]; |
121 | else if (stride == ShapedType::kDynamic) |
122 | runningStride = |
123 | rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]); |
124 | else |
125 | runningStride = createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: stride); |
126 | } |
127 | if (sizeInBytes) { |
128 | // Buffer size in bytes. |
129 | Type elementType = typeConverter->convertType(memRefType.getElementType()); |
130 | auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext()); |
131 | Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType); |
132 | Value gepPtr = rewriter.create<LLVM::GEPOp>( |
133 | loc, elementPtrType, elementType, nullPtr, runningStride); |
134 | size = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr); |
135 | } else { |
136 | size = runningStride; |
137 | } |
138 | } |
139 | |
140 | Value ConvertToLLVMPattern::getSizeInBytes( |
141 | Location loc, Type type, ConversionPatternRewriter &rewriter) const { |
142 | // Compute the size of an individual element. This emits the MLIR equivalent |
143 | // of the following sizeof(...) implementation in LLVM IR: |
144 | // %0 = getelementptr %elementType* null, %indexType 1 |
145 | // %1 = ptrtoint %elementType* %0 to %indexType |
146 | // which is a common pattern of getting the size of a type in bytes. |
147 | Type llvmType = typeConverter->convertType(t: type); |
148 | auto convertedPtrType = LLVM::LLVMPointerType::get(rewriter.getContext()); |
149 | auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, convertedPtrType); |
150 | auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType, llvmType, |
151 | nullPtr, ArrayRef<LLVM::GEPArg>{1}); |
152 | return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep); |
153 | } |
154 | |
155 | Value ConvertToLLVMPattern::getNumElements( |
156 | Location loc, MemRefType memRefType, ValueRange dynamicSizes, |
157 | ConversionPatternRewriter &rewriter) const { |
158 | assert(count(memRefType.getShape(), ShapedType::kDynamic) == |
159 | static_cast<ssize_t>(dynamicSizes.size()) && |
160 | "dynamicSizes size doesn't match dynamic sizes count in memref shape"); |
161 | |
162 | Type indexType = getIndexType(); |
163 | Value numElements = memRefType.getRank() == 0 |
164 | ? createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: 1) |
165 | : nullptr; |
166 | unsigned dynamicIndex = 0; |
167 | |
168 | // Compute the total number of memref elements. |
169 | for (int64_t staticSize : memRefType.getShape()) { |
170 | if (numElements) { |
171 | Value size = |
172 | staticSize == ShapedType::kDynamic |
173 | ? dynamicSizes[dynamicIndex++] |
174 | : createIndexAttrConstant(rewriter, loc, indexType, staticSize); |
175 | numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size); |
176 | } else { |
177 | numElements = |
178 | staticSize == ShapedType::kDynamic |
179 | ? dynamicSizes[dynamicIndex++] |
180 | : createIndexAttrConstant(rewriter, loc, indexType, staticSize); |
181 | } |
182 | } |
183 | return numElements; |
184 | } |
185 | |
186 | /// Creates and populates the memref descriptor struct given all its fields. |
187 | MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor( |
188 | Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, |
189 | ArrayRef<Value> sizes, ArrayRef<Value> strides, |
190 | ConversionPatternRewriter &rewriter) const { |
191 | auto structType = typeConverter->convertType(memRefType); |
192 | auto memRefDescriptor = MemRefDescriptor::poison(builder&: rewriter, loc, descriptorType: structType); |
193 | |
194 | // Field 1: Allocated pointer, used for malloc/free. |
195 | memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr); |
196 | |
197 | // Field 2: Actual aligned pointer to payload. |
198 | memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr); |
199 | |
200 | // Field 3: Offset in aligned pointer. |
201 | Type indexType = getIndexType(); |
202 | memRefDescriptor.setOffset( |
203 | rewriter, loc, createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: 0)); |
204 | |
205 | // Fields 4: Sizes. |
206 | for (const auto &en : llvm::enumerate(First&: sizes)) |
207 | memRefDescriptor.setSize(rewriter, loc, en.index(), en.value()); |
208 | |
209 | // Field 5: Strides. |
210 | for (const auto &en : llvm::enumerate(First&: strides)) |
211 | memRefDescriptor.setStride(rewriter, loc, en.index(), en.value()); |
212 | |
213 | return memRefDescriptor; |
214 | } |
215 | |
216 | LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( |
217 | OpBuilder &builder, Location loc, TypeRange origTypes, |
218 | SmallVectorImpl<Value> &operands, bool toDynamic) const { |
219 | assert(origTypes.size() == operands.size() && |
220 | "expected as may original types as operands"); |
221 | |
222 | // Find operands of unranked memref type and store them. |
223 | SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs; |
224 | SmallVector<unsigned> unrankedAddressSpaces; |
225 | for (unsigned i = 0, e = operands.size(); i < e; ++i) { |
226 | if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) { |
227 | unrankedMemrefs.emplace_back(Args&: operands[i]); |
228 | FailureOr<unsigned> addressSpace = |
229 | getTypeConverter()->getMemRefAddressSpace(type: memRefType); |
230 | if (failed(Result: addressSpace)) |
231 | return failure(); |
232 | unrankedAddressSpaces.emplace_back(Args&: *addressSpace); |
233 | } |
234 | } |
235 | |
236 | if (unrankedMemrefs.empty()) |
237 | return success(); |
238 | |
239 | // Compute allocation sizes. |
240 | SmallVector<Value> sizes; |
241 | UnrankedMemRefDescriptor::computeSizes(builder, loc, typeConverter: *getTypeConverter(), |
242 | values: unrankedMemrefs, addressSpaces: unrankedAddressSpaces, |
243 | sizes); |
244 | |
245 | // Get frequently used types. |
246 | Type indexType = getTypeConverter()->getIndexType(); |
247 | |
248 | // Find the malloc and free, or declare them if necessary. |
249 | auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>(); |
250 | FailureOr<LLVM::LLVMFuncOp> freeFunc, mallocFunc; |
251 | if (toDynamic) { |
252 | mallocFunc = LLVM::lookupOrCreateMallocFn(b&: builder, moduleOp: module, indexType); |
253 | if (failed(Result: mallocFunc)) |
254 | return failure(); |
255 | } |
256 | if (!toDynamic) { |
257 | freeFunc = LLVM::lookupOrCreateFreeFn(b&: builder, moduleOp: module); |
258 | if (failed(freeFunc)) |
259 | return failure(); |
260 | } |
261 | |
262 | unsigned unrankedMemrefPos = 0; |
263 | for (unsigned i = 0, e = operands.size(); i < e; ++i) { |
264 | Type type = origTypes[i]; |
265 | if (!isa<UnrankedMemRefType>(Val: type)) |
266 | continue; |
267 | Value allocationSize = sizes[unrankedMemrefPos++]; |
268 | UnrankedMemRefDescriptor desc(operands[i]); |
269 | |
270 | // Allocate memory, copy, and free the source if necessary. |
271 | Value memory = |
272 | toDynamic |
273 | ? builder |
274 | .create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize) |
275 | .getResult() |
276 | : builder.create<LLVM::AllocaOp>(loc, getVoidPtrType(), |
277 | IntegerType::get(getContext(), 8), |
278 | allocationSize, |
279 | /*alignment=*/0); |
280 | Value source = desc.memRefDescPtr(builder, loc); |
281 | builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, false); |
282 | if (!toDynamic) |
283 | builder.create<LLVM::CallOp>(loc, freeFunc.value(), source); |
284 | |
285 | // Create a new descriptor. The same descriptor can be returned multiple |
286 | // times, attempting to modify its pointer can lead to memory leaks |
287 | // (allocated twice and overwritten) or double frees (the caller does not |
288 | // know if the descriptor points to the same memory). |
289 | Type descriptorType = getTypeConverter()->convertType(t: type); |
290 | if (!descriptorType) |
291 | return failure(); |
292 | auto updatedDesc = |
293 | UnrankedMemRefDescriptor::poison(builder, loc, descriptorType); |
294 | Value rank = desc.rank(builder, loc); |
295 | updatedDesc.setRank(builder, loc, value: rank); |
296 | updatedDesc.setMemRefDescPtr(builder, loc, value: memory); |
297 | |
298 | operands[i] = updatedDesc; |
299 | } |
300 | |
301 | return success(); |
302 | } |
303 | |
304 | //===----------------------------------------------------------------------===// |
305 | // Detail methods |
306 | //===----------------------------------------------------------------------===// |
307 | |
308 | void LLVM::detail::setNativeProperties(Operation *op, |
309 | IntegerOverflowFlags overflowFlags) { |
310 | if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op)) |
311 | iface.setOverflowFlags(overflowFlags); |
312 | } |
313 | |
314 | /// Replaces the given operation "op" with a new operation of type "targetOp" |
315 | /// and given operands. |
316 | LogicalResult LLVM::detail::oneToOneRewrite( |
317 | Operation *op, StringRef targetOp, ValueRange operands, |
318 | ArrayRef<NamedAttribute> targetAttrs, |
319 | const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, |
320 | IntegerOverflowFlags overflowFlags) { |
321 | unsigned numResults = op->getNumResults(); |
322 | |
323 | SmallVector<Type> resultTypes; |
324 | if (numResults != 0) { |
325 | resultTypes.push_back( |
326 | Elt: typeConverter.packOperationResults(types: op->getResultTypes())); |
327 | if (!resultTypes.back()) |
328 | return failure(); |
329 | } |
330 | |
331 | // Create the operation through state since we don't know its C++ type. |
332 | Operation *newOp = |
333 | rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands, |
334 | resultTypes, targetAttrs); |
335 | |
336 | setNativeProperties(newOp, overflowFlags); |
337 | |
338 | // If the operation produced 0 or 1 result, return them immediately. |
339 | if (numResults == 0) |
340 | return rewriter.eraseOp(op), success(); |
341 | if (numResults == 1) |
342 | return rewriter.replaceOp(op, newValues: newOp->getResult(idx: 0)), success(); |
343 | |
344 | // Otherwise, it had been converted to an operation producing a structure. |
345 | // Extract individual results from the structure and return them as list. |
346 | SmallVector<Value, 4> results; |
347 | results.reserve(N: numResults); |
348 | for (unsigned i = 0; i < numResults; ++i) { |
349 | results.push_back(rewriter.create<LLVM::ExtractValueOp>( |
350 | op->getLoc(), newOp->getResult(0), i)); |
351 | } |
352 | rewriter.replaceOp(op, newValues: results); |
353 | return success(); |
354 | } |
355 | |
356 | LogicalResult LLVM::detail::intrinsicRewrite( |
357 | Operation *op, StringRef intrinsic, ValueRange operands, |
358 | const LLVMTypeConverter &typeConverter, RewriterBase &rewriter) { |
359 | auto loc = op->getLoc(); |
360 | |
361 | if (!llvm::all_of(Range&: operands, P: [](Value value) { |
362 | return LLVM::isCompatibleType(type: value.getType()); |
363 | })) |
364 | return failure(); |
365 | |
366 | unsigned numResults = op->getNumResults(); |
367 | Type resType; |
368 | if (numResults != 0) |
369 | resType = typeConverter.packOperationResults(types: op->getResultTypes()); |
370 | |
371 | auto callIntrOp = rewriter.create<LLVM::CallIntrinsicOp>( |
372 | loc, resType, rewriter.getStringAttr(intrinsic), operands); |
373 | // Propagate attributes. |
374 | callIntrOp->setAttrs(op->getAttrDictionary()); |
375 | |
376 | if (numResults <= 1) { |
377 | // Directly replace the original op. |
378 | rewriter.replaceOp(op, callIntrOp); |
379 | return success(); |
380 | } |
381 | |
382 | // Extract individual results from packed structure and use them as |
383 | // replacements. |
384 | SmallVector<Value, 4> results; |
385 | results.reserve(N: numResults); |
386 | Value intrRes = callIntrOp.getResults(); |
387 | for (unsigned i = 0; i < numResults; ++i) |
388 | results.push_back(rewriter.create<LLVM::ExtractValueOp>(loc, intrRes, i)); |
389 | rewriter.replaceOp(op, newValues: results); |
390 | |
391 | return success(); |
392 | } |
393 | |
394 | static unsigned getBitWidth(Type type) { |
395 | if (type.isIntOrFloat()) |
396 | return type.getIntOrFloatBitWidth(); |
397 | |
398 | auto vec = cast<VectorType>(type); |
399 | assert(!vec.isScalable() && "scalable vectors are not supported"); |
400 | return vec.getNumElements() * getBitWidth(vec.getElementType()); |
401 | } |
402 | |
403 | static Value createI32Constant(OpBuilder &builder, Location loc, |
404 | int32_t value) { |
405 | Type i32 = builder.getI32Type(); |
406 | return builder.create<LLVM::ConstantOp>(loc, i32, value); |
407 | } |
408 | |
409 | SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc, |
410 | Value src, Type dstType) { |
411 | Type srcType = src.getType(); |
412 | if (srcType == dstType) |
413 | return {src}; |
414 | |
415 | unsigned srcBitWidth = getBitWidth(type: srcType); |
416 | unsigned dstBitWidth = getBitWidth(type: dstType); |
417 | if (srcBitWidth == dstBitWidth) { |
418 | Value cast = builder.create<LLVM::BitcastOp>(loc, dstType, src); |
419 | return {cast}; |
420 | } |
421 | |
422 | if (dstBitWidth > srcBitWidth) { |
423 | auto smallerInt = builder.getIntegerType(srcBitWidth); |
424 | if (srcType != smallerInt) |
425 | src = builder.create<LLVM::BitcastOp>(loc, smallerInt, src); |
426 | |
427 | auto largerInt = builder.getIntegerType(dstBitWidth); |
428 | Value res = builder.create<LLVM::ZExtOp>(loc, largerInt, src); |
429 | return {res}; |
430 | } |
431 | assert(srcBitWidth % dstBitWidth == 0 && |
432 | "src bit width must be a multiple of dst bit width"); |
433 | int64_t numElements = srcBitWidth / dstBitWidth; |
434 | auto vecType = VectorType::get(numElements, dstType); |
435 | |
436 | src = builder.create<LLVM::BitcastOp>(loc, vecType, src); |
437 | |
438 | SmallVector<Value> res; |
439 | for (auto i : llvm::seq(Size: numElements)) { |
440 | Value idx = createI32Constant(builder, loc, value: i); |
441 | Value elem = builder.create<LLVM::ExtractElementOp>(loc, src, idx); |
442 | res.emplace_back(Args&: elem); |
443 | } |
444 | |
445 | return res; |
446 | } |
447 | |
448 | Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src, |
449 | Type dstType) { |
450 | assert(!src.empty() && "src range must not be empty"); |
451 | if (src.size() == 1) { |
452 | Value res = src.front(); |
453 | if (res.getType() == dstType) |
454 | return res; |
455 | |
456 | unsigned srcBitWidth = getBitWidth(type: res.getType()); |
457 | unsigned dstBitWidth = getBitWidth(type: dstType); |
458 | if (dstBitWidth < srcBitWidth) { |
459 | auto largerInt = builder.getIntegerType(srcBitWidth); |
460 | if (res.getType() != largerInt) |
461 | res = builder.create<LLVM::BitcastOp>(loc, largerInt, res); |
462 | |
463 | auto smallerInt = builder.getIntegerType(dstBitWidth); |
464 | res = builder.create<LLVM::TruncOp>(loc, smallerInt, res); |
465 | } |
466 | |
467 | if (res.getType() != dstType) |
468 | res = builder.create<LLVM::BitcastOp>(loc, dstType, res); |
469 | |
470 | return res; |
471 | } |
472 | |
473 | int64_t numElements = src.size(); |
474 | auto srcType = VectorType::get(numElements, src.front().getType()); |
475 | Value res = builder.create<LLVM::PoisonOp>(loc, srcType); |
476 | for (auto &&[i, elem] : llvm::enumerate(First&: src)) { |
477 | Value idx = createI32Constant(builder, loc, value: i); |
478 | res = builder.create<LLVM::InsertElementOp>(loc, srcType, res, elem, idx); |
479 | } |
480 | |
481 | if (res.getType() != dstType) |
482 | res = builder.create<LLVM::BitcastOp>(loc, dstType, res); |
483 | |
484 | return res; |
485 | } |
486 | |
487 | Value mlir::LLVM::getStridedElementPtr(OpBuilder &builder, Location loc, |
488 | const LLVMTypeConverter &converter, |
489 | MemRefType type, Value memRefDesc, |
490 | ValueRange indices, |
491 | LLVM::GEPNoWrapFlags noWrapFlags) { |
492 | auto [strides, offset] = type.getStridesAndOffset(); |
493 | |
494 | MemRefDescriptor memRefDescriptor(memRefDesc); |
495 | // Use a canonical representation of the start address so that later |
496 | // optimizations have a longer sequence of instructions to CSE. |
497 | // If we don't do that we would sprinkle the memref.offset in various |
498 | // position of the different address computations. |
499 | Value base = memRefDescriptor.bufferPtr(builder, loc, converter, type: type); |
500 | |
501 | LLVM::IntegerOverflowFlags intOverflowFlags = |
502 | LLVM::IntegerOverflowFlags::none; |
503 | if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nusw)) { |
504 | intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nsw; |
505 | } |
506 | if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nuw)) { |
507 | intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nuw; |
508 | } |
509 | |
510 | Type indexType = converter.getIndexType(); |
511 | Value index; |
512 | for (int i = 0, e = indices.size(); i < e; ++i) { |
513 | Value increment = indices[i]; |
514 | if (strides[i] != 1) { // Skip if stride is 1. |
515 | Value stride = |
516 | ShapedType::isDynamic(strides[i]) |
517 | ? memRefDescriptor.stride(builder, loc, i) |
518 | : builder.create<LLVM::ConstantOp>( |
519 | loc, indexType, builder.getIndexAttr(strides[i])); |
520 | increment = |
521 | builder.create<LLVM::MulOp>(loc, increment, stride, intOverflowFlags); |
522 | } |
523 | index = index ? builder.create<LLVM::AddOp>(loc, index, increment, |
524 | intOverflowFlags) |
525 | : increment; |
526 | } |
527 | |
528 | Type elementPtrType = memRefDescriptor.getElementPtrType(); |
529 | return index ? builder.create<LLVM::GEPOp>( |
530 | loc, elementPtrType, |
531 | converter.convertType(type.getElementType()), base, index, |
532 | noWrapFlags) |
533 | : base; |
534 | } |
535 |
Definitions
- ConvertToLLVMPattern
- getTypeConverter
- getDialect
- getIndexType
- getIntPtrType
- getVoidType
- getVoidPtrType
- createIndexAttrConstant
- getStridedElementPtr
- isConvertibleAndHasIdentityMaps
- getElementPtrType
- getMemRefDescriptorSizes
- getSizeInBytes
- getNumElements
- createMemRefDescriptor
- copyUnrankedDescriptors
- setNativeProperties
- oneToOneRewrite
- intrinsicRewrite
- getBitWidth
- createI32Constant
- decomposeValue
- composeValue
Improve your Profiling and Debugging skills
Find out more