1 | //===- MemRefBuilder.cpp - Helper for LLVM MemRef equivalents -------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" |
10 | #include "MemRefDescriptor.h" |
11 | #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
12 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
13 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
14 | #include "mlir/IR/Builders.h" |
15 | #include "mlir/Support/MathExtras.h" |
16 | |
17 | using namespace mlir; |
18 | |
19 | //===----------------------------------------------------------------------===// |
20 | // MemRefDescriptor implementation |
21 | //===----------------------------------------------------------------------===// |
22 | |
23 | /// Construct a helper for the given descriptor value. |
24 | MemRefDescriptor::MemRefDescriptor(Value descriptor) |
25 | : StructBuilder(descriptor) { |
26 | assert(value != nullptr && "value cannot be null" ); |
27 | indexType = cast<LLVM::LLVMStructType>(Val: value.getType()) |
28 | .getBody()[kOffsetPosInMemRefDescriptor]; |
29 | } |
30 | |
31 | /// Builds IR creating an `undef` value of the descriptor type. |
32 | MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc, |
33 | Type descriptorType) { |
34 | |
35 | Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType); |
36 | return MemRefDescriptor(descriptor); |
37 | } |
38 | |
39 | /// Builds IR creating a MemRef descriptor that represents `type` and |
40 | /// populates it with static shape and stride information extracted from the |
41 | /// type. |
42 | MemRefDescriptor |
43 | MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc, |
44 | const LLVMTypeConverter &typeConverter, |
45 | MemRefType type, Value memory) { |
46 | return fromStaticShape(builder, loc, typeConverter, type, memory, memory); |
47 | } |
48 | |
49 | MemRefDescriptor MemRefDescriptor::fromStaticShape( |
50 | OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, |
51 | MemRefType type, Value memory, Value alignedMemory) { |
52 | assert(type.hasStaticShape() && "unexpected dynamic shape" ); |
53 | |
54 | // Extract all strides and offsets and verify they are static. |
55 | auto [strides, offset] = getStridesAndOffset(type); |
56 | assert(!ShapedType::isDynamic(offset) && "expected static offset" ); |
57 | assert(!llvm::any_of(strides, ShapedType::isDynamic) && |
58 | "expected static strides" ); |
59 | |
60 | auto convertedType = typeConverter.convertType(type); |
61 | assert(convertedType && "unexpected failure in memref type conversion" ); |
62 | |
63 | auto descr = MemRefDescriptor::undef(builder, loc, descriptorType: convertedType); |
64 | descr.setAllocatedPtr(builder, loc, memory); |
65 | descr.setAlignedPtr(builder, loc, alignedMemory); |
66 | descr.setConstantOffset(builder, loc, offset); |
67 | |
68 | // Fill in sizes and strides |
69 | for (unsigned i = 0, e = type.getRank(); i != e; ++i) { |
70 | descr.setConstantSize(builder, loc, i, type.getDimSize(i)); |
71 | descr.setConstantStride(builder, loc, i, strides[i]); |
72 | } |
73 | return descr; |
74 | } |
75 | |
76 | /// Builds IR extracting the allocated pointer from the descriptor. |
77 | Value MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) { |
78 | return extractPtr(builder, loc, pos: kAllocatedPtrPosInMemRefDescriptor); |
79 | } |
80 | |
81 | /// Builds IR inserting the allocated pointer into the descriptor. |
82 | void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc, |
83 | Value ptr) { |
84 | setPtr(builder, loc, pos: kAllocatedPtrPosInMemRefDescriptor, ptr); |
85 | } |
86 | |
87 | /// Builds IR extracting the aligned pointer from the descriptor. |
88 | Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) { |
89 | return extractPtr(builder, loc, pos: kAlignedPtrPosInMemRefDescriptor); |
90 | } |
91 | |
92 | /// Builds IR inserting the aligned pointer into the descriptor. |
93 | void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, |
94 | Value ptr) { |
95 | setPtr(builder, loc, pos: kAlignedPtrPosInMemRefDescriptor, ptr); |
96 | } |
97 | |
98 | // Creates a constant Op producing a value of `resultType` from an index-typed |
99 | // integer attribute. |
100 | static Value createIndexAttrConstant(OpBuilder &builder, Location loc, |
101 | Type resultType, int64_t value) { |
102 | return builder.create<LLVM::ConstantOp>(loc, resultType, |
103 | builder.getIndexAttr(value)); |
104 | } |
105 | |
106 | /// Builds IR extracting the offset from the descriptor. |
107 | Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) { |
108 | return builder.create<LLVM::ExtractValueOp>(loc, value, |
109 | kOffsetPosInMemRefDescriptor); |
110 | } |
111 | |
112 | /// Builds IR inserting the offset into the descriptor. |
113 | void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc, |
114 | Value offset) { |
115 | value = builder.create<LLVM::InsertValueOp>(loc, value, offset, |
116 | kOffsetPosInMemRefDescriptor); |
117 | } |
118 | |
119 | /// Builds IR inserting the offset into the descriptor. |
120 | void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc, |
121 | uint64_t offset) { |
122 | setOffset(builder, loc, |
123 | offset: createIndexAttrConstant(builder, loc, resultType: indexType, value: offset)); |
124 | } |
125 | |
126 | /// Builds IR extracting the pos-th size from the descriptor. |
127 | Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) { |
128 | return builder.create<LLVM::ExtractValueOp>( |
129 | loc, value, ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos})); |
130 | } |
131 | |
132 | Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos, |
133 | int64_t rank) { |
134 | auto arrayTy = LLVM::LLVMArrayType::get(indexType, rank); |
135 | |
136 | auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext()); |
137 | |
138 | // Copy size values to stack-allocated memory. |
139 | auto one = createIndexAttrConstant(builder, loc, resultType: indexType, value: 1); |
140 | auto sizes = builder.create<LLVM::ExtractValueOp>( |
141 | loc, value, llvm::ArrayRef<int64_t>({kSizePosInMemRefDescriptor})); |
142 | auto sizesPtr = builder.create<LLVM::AllocaOp>(loc, ptrTy, arrayTy, one, |
143 | /*alignment=*/0); |
144 | builder.create<LLVM::StoreOp>(loc, sizes, sizesPtr); |
145 | |
146 | // Load an return size value of interest. |
147 | auto resultPtr = builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, sizesPtr, |
148 | ArrayRef<LLVM::GEPArg>{0, pos}); |
149 | return builder.create<LLVM::LoadOp>(loc, indexType, resultPtr); |
150 | } |
151 | |
152 | /// Builds IR inserting the pos-th size into the descriptor |
153 | void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos, |
154 | Value size) { |
155 | value = builder.create<LLVM::InsertValueOp>( |
156 | loc, value, size, ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos})); |
157 | } |
158 | |
159 | void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc, |
160 | unsigned pos, uint64_t size) { |
161 | setSize(builder, loc, pos, |
162 | size: createIndexAttrConstant(builder, loc, resultType: indexType, value: size)); |
163 | } |
164 | |
165 | /// Builds IR extracting the pos-th stride from the descriptor. |
166 | Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) { |
167 | return builder.create<LLVM::ExtractValueOp>( |
168 | loc, value, ArrayRef<int64_t>({kStridePosInMemRefDescriptor, pos})); |
169 | } |
170 | |
171 | /// Builds IR inserting the pos-th stride into the descriptor |
172 | void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos, |
173 | Value stride) { |
174 | value = builder.create<LLVM::InsertValueOp>( |
175 | loc, value, stride, |
176 | ArrayRef<int64_t>({kStridePosInMemRefDescriptor, pos})); |
177 | } |
178 | |
179 | void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc, |
180 | unsigned pos, uint64_t stride) { |
181 | setStride(builder, loc, pos, |
182 | stride: createIndexAttrConstant(builder, loc, resultType: indexType, value: stride)); |
183 | } |
184 | |
185 | LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() { |
186 | return cast<LLVM::LLVMPointerType>( |
187 | cast<LLVM::LLVMStructType>(Val: value.getType()) |
188 | .getBody()[kAlignedPtrPosInMemRefDescriptor]); |
189 | } |
190 | |
191 | Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc, |
192 | const LLVMTypeConverter &converter, |
193 | MemRefType type) { |
194 | // When we convert to LLVM, the input memref must have been normalized |
195 | // beforehand. Hence, this call is guaranteed to work. |
196 | auto [strides, offsetCst] = getStridesAndOffset(type); |
197 | |
198 | Value ptr = alignedPtr(builder, loc); |
199 | // For zero offsets, we already have the base pointer. |
200 | if (offsetCst == 0) |
201 | return ptr; |
202 | |
203 | // Otherwise add the offset to the aligned base. |
204 | Type indexType = converter.getIndexType(); |
205 | Value offsetVal = |
206 | ShapedType::isDynamic(offsetCst) |
207 | ? offset(builder, loc) |
208 | : createIndexAttrConstant(builder, loc, indexType, offsetCst); |
209 | Type elementType = converter.convertType(type.getElementType()); |
210 | ptr = builder.create<LLVM::GEPOp>(loc, ptr.getType(), elementType, ptr, |
211 | offsetVal); |
212 | return ptr; |
213 | } |
214 | |
215 | /// Creates a MemRef descriptor structure from a list of individual values |
216 | /// composing that descriptor, in the following order: |
217 | /// - allocated pointer; |
218 | /// - aligned pointer; |
219 | /// - offset; |
220 | /// - <rank> sizes; |
221 | /// - <rank> shapes; |
222 | /// where <rank> is the MemRef rank as provided in `type`. |
223 | Value MemRefDescriptor::pack(OpBuilder &builder, Location loc, |
224 | const LLVMTypeConverter &converter, |
225 | MemRefType type, ValueRange values) { |
226 | Type llvmType = converter.convertType(type); |
227 | auto d = MemRefDescriptor::undef(builder, loc, descriptorType: llvmType); |
228 | |
229 | d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]); |
230 | d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]); |
231 | d.setOffset(builder, loc, values[kOffsetPosInMemRefDescriptor]); |
232 | |
233 | int64_t rank = type.getRank(); |
234 | for (unsigned i = 0; i < rank; ++i) { |
235 | d.setSize(builder, loc, i, values[kSizePosInMemRefDescriptor + i]); |
236 | d.setStride(builder, loc, i, values[kSizePosInMemRefDescriptor + rank + i]); |
237 | } |
238 | |
239 | return d; |
240 | } |
241 | |
242 | /// Builds IR extracting individual elements of a MemRef descriptor structure |
243 | /// and returning them as `results` list. |
244 | void MemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed, |
245 | MemRefType type, |
246 | SmallVectorImpl<Value> &results) { |
247 | int64_t rank = type.getRank(); |
248 | results.reserve(N: results.size() + getNumUnpackedValues(type: type)); |
249 | |
250 | MemRefDescriptor d(packed); |
251 | results.push_back(Elt: d.allocatedPtr(builder, loc)); |
252 | results.push_back(Elt: d.alignedPtr(builder, loc)); |
253 | results.push_back(Elt: d.offset(builder, loc)); |
254 | for (int64_t i = 0; i < rank; ++i) |
255 | results.push_back(Elt: d.size(builder, loc, pos: i)); |
256 | for (int64_t i = 0; i < rank; ++i) |
257 | results.push_back(Elt: d.stride(builder, loc, pos: i)); |
258 | } |
259 | |
260 | /// Returns the number of non-aggregate values that would be produced by |
261 | /// `unpack`. |
262 | unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) { |
263 | // Two pointers, offset, <rank> sizes, <rank> shapes. |
264 | return 3 + 2 * type.getRank(); |
265 | } |
266 | |
267 | //===----------------------------------------------------------------------===// |
268 | // MemRefDescriptorView implementation. |
269 | //===----------------------------------------------------------------------===// |
270 | |
271 | MemRefDescriptorView::MemRefDescriptorView(ValueRange range) |
272 | : rank((range.size() - kSizePosInMemRefDescriptor) / 2), elements(range) {} |
273 | |
274 | Value MemRefDescriptorView::allocatedPtr() { |
275 | return elements[kAllocatedPtrPosInMemRefDescriptor]; |
276 | } |
277 | |
278 | Value MemRefDescriptorView::alignedPtr() { |
279 | return elements[kAlignedPtrPosInMemRefDescriptor]; |
280 | } |
281 | |
282 | Value MemRefDescriptorView::offset() { |
283 | return elements[kOffsetPosInMemRefDescriptor]; |
284 | } |
285 | |
286 | Value MemRefDescriptorView::size(unsigned pos) { |
287 | return elements[kSizePosInMemRefDescriptor + pos]; |
288 | } |
289 | |
290 | Value MemRefDescriptorView::stride(unsigned pos) { |
291 | return elements[kSizePosInMemRefDescriptor + rank + pos]; |
292 | } |
293 | |
294 | //===----------------------------------------------------------------------===// |
295 | // UnrankedMemRefDescriptor implementation |
296 | //===----------------------------------------------------------------------===// |
297 | |
298 | /// Construct a helper for the given descriptor value. |
299 | UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor) |
300 | : StructBuilder(descriptor) {} |
301 | |
302 | /// Builds IR creating an `undef` value of the descriptor type. |
303 | UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder, |
304 | Location loc, |
305 | Type descriptorType) { |
306 | Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType); |
307 | return UnrankedMemRefDescriptor(descriptor); |
308 | } |
309 | Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) const { |
310 | return extractPtr(builder, loc, pos: kRankInUnrankedMemRefDescriptor); |
311 | } |
312 | void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc, |
313 | Value v) { |
314 | setPtr(builder, loc, pos: kRankInUnrankedMemRefDescriptor, ptr: v); |
315 | } |
316 | Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder, |
317 | Location loc) const { |
318 | return extractPtr(builder, loc, pos: kPtrInUnrankedMemRefDescriptor); |
319 | } |
320 | void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder, |
321 | Location loc, Value v) { |
322 | setPtr(builder, loc, pos: kPtrInUnrankedMemRefDescriptor, ptr: v); |
323 | } |
324 | |
325 | /// Builds IR populating an unranked MemRef descriptor structure from a list |
326 | /// of individual constituent values in the following order: |
327 | /// - rank of the memref; |
328 | /// - pointer to the memref descriptor. |
329 | Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc, |
330 | const LLVMTypeConverter &converter, |
331 | UnrankedMemRefType type, |
332 | ValueRange values) { |
333 | Type llvmType = converter.convertType(type); |
334 | auto d = UnrankedMemRefDescriptor::undef(builder, loc, descriptorType: llvmType); |
335 | |
336 | d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]); |
337 | d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]); |
338 | return d; |
339 | } |
340 | |
341 | /// Builds IR extracting individual elements that compose an unranked memref |
342 | /// descriptor and returns them as `results` list. |
343 | void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc, |
344 | Value packed, |
345 | SmallVectorImpl<Value> &results) { |
346 | UnrankedMemRefDescriptor d(packed); |
347 | results.reserve(N: results.size() + 2); |
348 | results.push_back(Elt: d.rank(builder, loc)); |
349 | results.push_back(Elt: d.memRefDescPtr(builder, loc)); |
350 | } |
351 | |
352 | void UnrankedMemRefDescriptor::computeSizes( |
353 | OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, |
354 | ArrayRef<UnrankedMemRefDescriptor> values, ArrayRef<unsigned> addressSpaces, |
355 | SmallVectorImpl<Value> &sizes) { |
356 | if (values.empty()) |
357 | return; |
358 | assert(values.size() == addressSpaces.size() && |
359 | "must provide address space for each descriptor" ); |
360 | // Cache the index type. |
361 | Type indexType = typeConverter.getIndexType(); |
362 | |
363 | // Initialize shared constants. |
364 | Value one = createIndexAttrConstant(builder, loc, resultType: indexType, value: 1); |
365 | Value two = createIndexAttrConstant(builder, loc, resultType: indexType, value: 2); |
366 | Value indexSize = |
367 | createIndexAttrConstant(builder, loc, resultType: indexType, |
368 | value: ceilDiv(lhs: typeConverter.getIndexTypeBitwidth(), rhs: 8)); |
369 | |
370 | sizes.reserve(N: sizes.size() + values.size()); |
371 | for (auto [desc, addressSpace] : llvm::zip(t&: values, u&: addressSpaces)) { |
372 | // Emit IR computing the memory necessary to store the descriptor. This |
373 | // assumes the descriptor to be |
374 | // { type*, type*, index, index[rank], index[rank] } |
375 | // and densely packed, so the total size is |
376 | // 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index). |
377 | // TODO: consider including the actual size (including eventual padding due |
378 | // to data layout) into the unranked descriptor. |
379 | Value pointerSize = createIndexAttrConstant( |
380 | builder, loc, resultType: indexType, |
381 | value: ceilDiv(lhs: typeConverter.getPointerBitwidth(addressSpace), rhs: 8)); |
382 | Value doublePointerSize = |
383 | builder.create<LLVM::MulOp>(loc, indexType, two, pointerSize); |
384 | |
385 | // (1 + 2 * rank) * sizeof(index) |
386 | Value rank = desc.rank(builder, loc); |
387 | Value doubleRank = builder.create<LLVM::MulOp>(loc, indexType, two, rank); |
388 | Value doubleRankIncremented = |
389 | builder.create<LLVM::AddOp>(loc, indexType, doubleRank, one); |
390 | Value rankIndexSize = builder.create<LLVM::MulOp>( |
391 | loc, indexType, doubleRankIncremented, indexSize); |
392 | |
393 | // Total allocation size. |
394 | Value allocationSize = builder.create<LLVM::AddOp>( |
395 | loc, indexType, doublePointerSize, rankIndexSize); |
396 | sizes.push_back(Elt: allocationSize); |
397 | } |
398 | } |
399 | |
400 | Value UnrankedMemRefDescriptor::allocatedPtr( |
401 | OpBuilder &builder, Location loc, Value memRefDescPtr, |
402 | LLVM::LLVMPointerType elemPtrType) { |
403 | return builder.create<LLVM::LoadOp>(loc, elemPtrType, memRefDescPtr); |
404 | } |
405 | |
406 | void UnrankedMemRefDescriptor::setAllocatedPtr( |
407 | OpBuilder &builder, Location loc, Value memRefDescPtr, |
408 | LLVM::LLVMPointerType elemPtrType, Value allocatedPtr) { |
409 | builder.create<LLVM::StoreOp>(loc, allocatedPtr, memRefDescPtr); |
410 | } |
411 | |
412 | static std::pair<Value, Type> |
413 | castToElemPtrPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, |
414 | LLVM::LLVMPointerType elemPtrType) { |
415 | auto elemPtrPtrType = LLVM::LLVMPointerType::get(builder.getContext()); |
416 | return {memRefDescPtr, elemPtrPtrType}; |
417 | } |
418 | |
419 | Value UnrankedMemRefDescriptor::alignedPtr( |
420 | OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, |
421 | Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) { |
422 | auto [elementPtrPtr, elemPtrPtrType] = |
423 | castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType); |
424 | |
425 | Value alignedGep = |
426 | builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType, |
427 | elementPtrPtr, ArrayRef<LLVM::GEPArg>{1}); |
428 | return builder.create<LLVM::LoadOp>(loc, elemPtrType, alignedGep); |
429 | } |
430 | |
431 | void UnrankedMemRefDescriptor::setAlignedPtr( |
432 | OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, |
433 | Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value alignedPtr) { |
434 | auto [elementPtrPtr, elemPtrPtrType] = |
435 | castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType); |
436 | |
437 | Value alignedGep = |
438 | builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType, |
439 | elementPtrPtr, ArrayRef<LLVM::GEPArg>{1}); |
440 | builder.create<LLVM::StoreOp>(loc, alignedPtr, alignedGep); |
441 | } |
442 | |
443 | Value UnrankedMemRefDescriptor::offsetBasePtr( |
444 | OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, |
445 | Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) { |
446 | auto [elementPtrPtr, elemPtrPtrType] = |
447 | castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType); |
448 | |
449 | return builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType, |
450 | elementPtrPtr, ArrayRef<LLVM::GEPArg>{2}); |
451 | } |
452 | |
453 | Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc, |
454 | const LLVMTypeConverter &typeConverter, |
455 | Value memRefDescPtr, |
456 | LLVM::LLVMPointerType elemPtrType) { |
457 | Value offsetPtr = |
458 | offsetBasePtr(builder, loc, typeConverter, memRefDescPtr, elemPtrType: elemPtrType); |
459 | return builder.create<LLVM::LoadOp>(loc, typeConverter.getIndexType(), |
460 | offsetPtr); |
461 | } |
462 | |
463 | void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc, |
464 | const LLVMTypeConverter &typeConverter, |
465 | Value memRefDescPtr, |
466 | LLVM::LLVMPointerType elemPtrType, |
467 | Value offset) { |
468 | Value offsetPtr = |
469 | offsetBasePtr(builder, loc, typeConverter, memRefDescPtr, elemPtrType: elemPtrType); |
470 | builder.create<LLVM::StoreOp>(loc, offset, offsetPtr); |
471 | } |
472 | |
473 | Value UnrankedMemRefDescriptor::sizeBasePtr( |
474 | OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, |
475 | Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) { |
476 | Type indexTy = typeConverter.getIndexType(); |
477 | Type structTy = LLVM::LLVMStructType::getLiteral( |
478 | context: indexTy.getContext(), types: {elemPtrType, elemPtrType, indexTy, indexTy}); |
479 | auto resultType = LLVM::LLVMPointerType::get(builder.getContext()); |
480 | return builder.create<LLVM::GEPOp>(loc, resultType, structTy, memRefDescPtr, |
481 | ArrayRef<LLVM::GEPArg>{0, 3}); |
482 | } |
483 | |
484 | Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc, |
485 | const LLVMTypeConverter &typeConverter, |
486 | Value sizeBasePtr, Value index) { |
487 | |
488 | Type indexTy = typeConverter.getIndexType(); |
489 | auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); |
490 | |
491 | Value sizeStoreGep = |
492 | builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, sizeBasePtr, index); |
493 | return builder.create<LLVM::LoadOp>(loc, indexTy, sizeStoreGep); |
494 | } |
495 | |
496 | void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc, |
497 | const LLVMTypeConverter &typeConverter, |
498 | Value sizeBasePtr, Value index, |
499 | Value size) { |
500 | Type indexTy = typeConverter.getIndexType(); |
501 | auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); |
502 | |
503 | Value sizeStoreGep = |
504 | builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, sizeBasePtr, index); |
505 | builder.create<LLVM::StoreOp>(loc, size, sizeStoreGep); |
506 | } |
507 | |
508 | Value UnrankedMemRefDescriptor::strideBasePtr( |
509 | OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, |
510 | Value sizeBasePtr, Value rank) { |
511 | Type indexTy = typeConverter.getIndexType(); |
512 | auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); |
513 | |
514 | return builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, sizeBasePtr, rank); |
515 | } |
516 | |
517 | Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc, |
518 | const LLVMTypeConverter &typeConverter, |
519 | Value strideBasePtr, Value index, |
520 | Value stride) { |
521 | Type indexTy = typeConverter.getIndexType(); |
522 | auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); |
523 | |
524 | Value strideStoreGep = |
525 | builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, strideBasePtr, index); |
526 | return builder.create<LLVM::LoadOp>(loc, indexTy, strideStoreGep); |
527 | } |
528 | |
529 | void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc, |
530 | const LLVMTypeConverter &typeConverter, |
531 | Value strideBasePtr, Value index, |
532 | Value stride) { |
533 | Type indexTy = typeConverter.getIndexType(); |
534 | auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); |
535 | |
536 | Value strideStoreGep = |
537 | builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, strideBasePtr, index); |
538 | builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep); |
539 | } |
540 | |