//===- MemRefBuilder.cpp - Helper for LLVM MemRef equivalents -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
| 8 | |
| 9 | #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" |
| 10 | #include "MemRefDescriptor.h" |
| 11 | #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
| 12 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 13 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
| 14 | #include "mlir/IR/Builders.h" |
| 15 | #include "llvm/Support/MathExtras.h" |
| 16 | |
| 17 | using namespace mlir; |
| 18 | |
//===----------------------------------------------------------------------===//
// MemRefDescriptor implementation
//===----------------------------------------------------------------------===//
| 22 | |
| 23 | /// Construct a helper for the given descriptor value. |
| 24 | MemRefDescriptor::MemRefDescriptor(Value descriptor) |
| 25 | : StructBuilder(descriptor) { |
| 26 | assert(value != nullptr && "value cannot be null" ); |
| 27 | indexType = cast<LLVM::LLVMStructType>(value.getType()) |
| 28 | .getBody()[kOffsetPosInMemRefDescriptor]; |
| 29 | } |
| 30 | |
| 31 | /// Builds IR creating an `undef` value of the descriptor type. |
| 32 | MemRefDescriptor MemRefDescriptor::poison(OpBuilder &builder, Location loc, |
| 33 | Type descriptorType) { |
| 34 | |
| 35 | Value descriptor = builder.create<LLVM::PoisonOp>(loc, descriptorType); |
| 36 | return MemRefDescriptor(descriptor); |
| 37 | } |
| 38 | |
| 39 | /// Builds IR creating a MemRef descriptor that represents `type` and |
| 40 | /// populates it with static shape and stride information extracted from the |
| 41 | /// type. |
| 42 | MemRefDescriptor |
| 43 | MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc, |
| 44 | const LLVMTypeConverter &typeConverter, |
| 45 | MemRefType type, Value memory) { |
| 46 | return fromStaticShape(builder, loc, typeConverter, type, memory, memory); |
| 47 | } |
| 48 | |
| 49 | MemRefDescriptor MemRefDescriptor::fromStaticShape( |
| 50 | OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, |
| 51 | MemRefType type, Value memory, Value alignedMemory) { |
| 52 | assert(type.hasStaticShape() && "unexpected dynamic shape" ); |
| 53 | |
| 54 | // Extract all strides and offsets and verify they are static. |
| 55 | auto [strides, offset] = type.getStridesAndOffset(); |
| 56 | assert(!ShapedType::isDynamic(offset) && "expected static offset" ); |
| 57 | assert(!llvm::any_of(strides, ShapedType::isDynamic) && |
| 58 | "expected static strides" ); |
| 59 | |
| 60 | auto convertedType = typeConverter.convertType(type); |
| 61 | assert(convertedType && "unexpected failure in memref type conversion" ); |
| 62 | |
| 63 | auto descr = MemRefDescriptor::poison(builder, loc, descriptorType: convertedType); |
| 64 | descr.setAllocatedPtr(builder, loc, memory); |
| 65 | descr.setAlignedPtr(builder, loc, alignedMemory); |
| 66 | descr.setConstantOffset(builder, loc, offset); |
| 67 | |
| 68 | // Fill in sizes and strides |
| 69 | for (unsigned i = 0, e = type.getRank(); i != e; ++i) { |
| 70 | descr.setConstantSize(builder, loc, i, type.getDimSize(i)); |
| 71 | descr.setConstantStride(builder, loc, i, strides[i]); |
| 72 | } |
| 73 | return descr; |
| 74 | } |
| 75 | |
| 76 | /// Builds IR extracting the allocated pointer from the descriptor. |
| 77 | Value MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) { |
| 78 | return extractPtr(builder, loc, pos: kAllocatedPtrPosInMemRefDescriptor); |
| 79 | } |
| 80 | |
| 81 | /// Builds IR inserting the allocated pointer into the descriptor. |
| 82 | void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc, |
| 83 | Value ptr) { |
| 84 | setPtr(builder, loc, pos: kAllocatedPtrPosInMemRefDescriptor, ptr); |
| 85 | } |
| 86 | |
| 87 | /// Builds IR extracting the aligned pointer from the descriptor. |
| 88 | Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) { |
| 89 | return extractPtr(builder, loc, pos: kAlignedPtrPosInMemRefDescriptor); |
| 90 | } |
| 91 | |
| 92 | /// Builds IR inserting the aligned pointer into the descriptor. |
| 93 | void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, |
| 94 | Value ptr) { |
| 95 | setPtr(builder, loc, pos: kAlignedPtrPosInMemRefDescriptor, ptr); |
| 96 | } |
| 97 | |
| 98 | // Creates a constant Op producing a value of `resultType` from an index-typed |
| 99 | // integer attribute. |
| 100 | static Value createIndexAttrConstant(OpBuilder &builder, Location loc, |
| 101 | Type resultType, int64_t value) { |
| 102 | return builder.create<LLVM::ConstantOp>(loc, resultType, |
| 103 | builder.getIndexAttr(value)); |
| 104 | } |
| 105 | |
| 106 | /// Builds IR extracting the offset from the descriptor. |
| 107 | Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) { |
| 108 | return builder.create<LLVM::ExtractValueOp>(loc, value, |
| 109 | kOffsetPosInMemRefDescriptor); |
| 110 | } |
| 111 | |
| 112 | /// Builds IR inserting the offset into the descriptor. |
| 113 | void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc, |
| 114 | Value offset) { |
| 115 | value = builder.create<LLVM::InsertValueOp>(loc, value, offset, |
| 116 | kOffsetPosInMemRefDescriptor); |
| 117 | } |
| 118 | |
| 119 | /// Builds IR inserting the offset into the descriptor. |
| 120 | void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc, |
| 121 | uint64_t offset) { |
| 122 | setOffset(builder, loc, |
| 123 | offset: createIndexAttrConstant(builder, loc, resultType: indexType, value: offset)); |
| 124 | } |
| 125 | |
| 126 | /// Builds IR extracting the pos-th size from the descriptor. |
| 127 | Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) { |
| 128 | return builder.create<LLVM::ExtractValueOp>( |
| 129 | loc, value, ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos})); |
| 130 | } |
| 131 | |
| 132 | Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos, |
| 133 | int64_t rank) { |
| 134 | auto arrayTy = LLVM::LLVMArrayType::get(indexType, rank); |
| 135 | |
| 136 | auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext()); |
| 137 | |
| 138 | // Copy size values to stack-allocated memory. |
| 139 | auto one = createIndexAttrConstant(builder, loc, resultType: indexType, value: 1); |
| 140 | auto sizes = builder.create<LLVM::ExtractValueOp>( |
| 141 | loc, value, llvm::ArrayRef<int64_t>({kSizePosInMemRefDescriptor})); |
| 142 | auto sizesPtr = builder.create<LLVM::AllocaOp>(loc, ptrTy, arrayTy, one, |
| 143 | /*alignment=*/0); |
| 144 | builder.create<LLVM::StoreOp>(loc, sizes, sizesPtr); |
| 145 | |
| 146 | // Load an return size value of interest. |
| 147 | auto resultPtr = builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, sizesPtr, |
| 148 | ArrayRef<LLVM::GEPArg>{0, pos}); |
| 149 | return builder.create<LLVM::LoadOp>(loc, indexType, resultPtr); |
| 150 | } |
| 151 | |
| 152 | /// Builds IR inserting the pos-th size into the descriptor |
| 153 | void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos, |
| 154 | Value size) { |
| 155 | value = builder.create<LLVM::InsertValueOp>( |
| 156 | loc, value, size, ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos})); |
| 157 | } |
| 158 | |
| 159 | void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc, |
| 160 | unsigned pos, uint64_t size) { |
| 161 | setSize(builder, loc, pos, |
| 162 | size: createIndexAttrConstant(builder, loc, resultType: indexType, value: size)); |
| 163 | } |
| 164 | |
| 165 | /// Builds IR extracting the pos-th stride from the descriptor. |
| 166 | Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) { |
| 167 | return builder.create<LLVM::ExtractValueOp>( |
| 168 | loc, value, ArrayRef<int64_t>({kStridePosInMemRefDescriptor, pos})); |
| 169 | } |
| 170 | |
| 171 | /// Builds IR inserting the pos-th stride into the descriptor |
| 172 | void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos, |
| 173 | Value stride) { |
| 174 | value = builder.create<LLVM::InsertValueOp>( |
| 175 | loc, value, stride, |
| 176 | ArrayRef<int64_t>({kStridePosInMemRefDescriptor, pos})); |
| 177 | } |
| 178 | |
| 179 | void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc, |
| 180 | unsigned pos, uint64_t stride) { |
| 181 | setStride(builder, loc, pos, |
| 182 | stride: createIndexAttrConstant(builder, loc, resultType: indexType, value: stride)); |
| 183 | } |
| 184 | |
| 185 | LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() { |
| 186 | return cast<LLVM::LLVMPointerType>( |
| 187 | cast<LLVM::LLVMStructType>(value.getType()) |
| 188 | .getBody()[kAlignedPtrPosInMemRefDescriptor]); |
| 189 | } |
| 190 | |
| 191 | Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc, |
| 192 | const LLVMTypeConverter &converter, |
| 193 | MemRefType type) { |
| 194 | // When we convert to LLVM, the input memref must have been normalized |
| 195 | // beforehand. Hence, this call is guaranteed to work. |
| 196 | auto [strides, offsetCst] = type.getStridesAndOffset(); |
| 197 | |
| 198 | Value ptr = alignedPtr(builder, loc); |
| 199 | // For zero offsets, we already have the base pointer. |
| 200 | if (offsetCst == 0) |
| 201 | return ptr; |
| 202 | |
| 203 | // Otherwise add the offset to the aligned base. |
| 204 | Type indexType = converter.getIndexType(); |
| 205 | Value offsetVal = |
| 206 | ShapedType::isDynamic(offsetCst) |
| 207 | ? offset(builder, loc) |
| 208 | : createIndexAttrConstant(builder, loc, indexType, offsetCst); |
| 209 | Type elementType = converter.convertType(type.getElementType()); |
| 210 | ptr = builder.create<LLVM::GEPOp>(loc, ptr.getType(), elementType, ptr, |
| 211 | offsetVal); |
| 212 | return ptr; |
| 213 | } |
| 214 | |
| 215 | /// Creates a MemRef descriptor structure from a list of individual values |
| 216 | /// composing that descriptor, in the following order: |
| 217 | /// - allocated pointer; |
| 218 | /// - aligned pointer; |
| 219 | /// - offset; |
| 220 | /// - <rank> sizes; |
| 221 | /// - <rank> strides; |
| 222 | /// where <rank> is the MemRef rank as provided in `type`. |
| 223 | Value MemRefDescriptor::pack(OpBuilder &builder, Location loc, |
| 224 | const LLVMTypeConverter &converter, |
| 225 | MemRefType type, ValueRange values) { |
| 226 | Type llvmType = converter.convertType(type); |
| 227 | auto d = MemRefDescriptor::poison(builder, loc, descriptorType: llvmType); |
| 228 | |
| 229 | d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]); |
| 230 | d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]); |
| 231 | d.setOffset(builder, loc, values[kOffsetPosInMemRefDescriptor]); |
| 232 | |
| 233 | int64_t rank = type.getRank(); |
| 234 | for (unsigned i = 0; i < rank; ++i) { |
| 235 | d.setSize(builder, loc, i, values[kSizePosInMemRefDescriptor + i]); |
| 236 | d.setStride(builder, loc, i, values[kSizePosInMemRefDescriptor + rank + i]); |
| 237 | } |
| 238 | |
| 239 | return d; |
| 240 | } |
| 241 | |
| 242 | /// Builds IR extracting individual elements of a MemRef descriptor structure |
| 243 | /// and returning them as `results` list. |
| 244 | void MemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed, |
| 245 | MemRefType type, |
| 246 | SmallVectorImpl<Value> &results) { |
| 247 | int64_t rank = type.getRank(); |
| 248 | results.reserve(N: results.size() + getNumUnpackedValues(type: type)); |
| 249 | |
| 250 | MemRefDescriptor d(packed); |
| 251 | results.push_back(Elt: d.allocatedPtr(builder, loc)); |
| 252 | results.push_back(Elt: d.alignedPtr(builder, loc)); |
| 253 | results.push_back(Elt: d.offset(builder, loc)); |
| 254 | for (int64_t i = 0; i < rank; ++i) |
| 255 | results.push_back(Elt: d.size(builder, loc, pos: i)); |
| 256 | for (int64_t i = 0; i < rank; ++i) |
| 257 | results.push_back(Elt: d.stride(builder, loc, pos: i)); |
| 258 | } |
| 259 | |
| 260 | /// Returns the number of non-aggregate values that would be produced by |
| 261 | /// `unpack`. |
| 262 | unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) { |
| 263 | // Two pointers, offset, <rank> sizes, <rank> strides. |
| 264 | return 3 + 2 * type.getRank(); |
| 265 | } |
| 266 | |
//===----------------------------------------------------------------------===//
// MemRefDescriptorView implementation
//===----------------------------------------------------------------------===//
| 270 | |
| 271 | MemRefDescriptorView::MemRefDescriptorView(ValueRange range) |
| 272 | : rank((range.size() - kSizePosInMemRefDescriptor) / 2), elements(range) {} |
| 273 | |
| 274 | Value MemRefDescriptorView::allocatedPtr() { |
| 275 | return elements[kAllocatedPtrPosInMemRefDescriptor]; |
| 276 | } |
| 277 | |
| 278 | Value MemRefDescriptorView::alignedPtr() { |
| 279 | return elements[kAlignedPtrPosInMemRefDescriptor]; |
| 280 | } |
| 281 | |
| 282 | Value MemRefDescriptorView::offset() { |
| 283 | return elements[kOffsetPosInMemRefDescriptor]; |
| 284 | } |
| 285 | |
| 286 | Value MemRefDescriptorView::size(unsigned pos) { |
| 287 | return elements[kSizePosInMemRefDescriptor + pos]; |
| 288 | } |
| 289 | |
| 290 | Value MemRefDescriptorView::stride(unsigned pos) { |
| 291 | return elements[kSizePosInMemRefDescriptor + rank + pos]; |
| 292 | } |
| 293 | |
//===----------------------------------------------------------------------===//
// UnrankedMemRefDescriptor implementation
//===----------------------------------------------------------------------===//
| 297 | |
| 298 | /// Construct a helper for the given descriptor value. |
| 299 | UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor) |
| 300 | : StructBuilder(descriptor) {} |
| 301 | |
| 302 | /// Builds IR creating an `undef` value of the descriptor type. |
| 303 | UnrankedMemRefDescriptor UnrankedMemRefDescriptor::poison(OpBuilder &builder, |
| 304 | Location loc, |
| 305 | Type descriptorType) { |
| 306 | Value descriptor = builder.create<LLVM::PoisonOp>(loc, descriptorType); |
| 307 | return UnrankedMemRefDescriptor(descriptor); |
| 308 | } |
| 309 | Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) const { |
| 310 | return extractPtr(builder, loc, pos: kRankInUnrankedMemRefDescriptor); |
| 311 | } |
| 312 | void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc, |
| 313 | Value v) { |
| 314 | setPtr(builder, loc, pos: kRankInUnrankedMemRefDescriptor, ptr: v); |
| 315 | } |
| 316 | Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder, |
| 317 | Location loc) const { |
| 318 | return extractPtr(builder, loc, pos: kPtrInUnrankedMemRefDescriptor); |
| 319 | } |
| 320 | void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder, |
| 321 | Location loc, Value v) { |
| 322 | setPtr(builder, loc, pos: kPtrInUnrankedMemRefDescriptor, ptr: v); |
| 323 | } |
| 324 | |
| 325 | /// Builds IR populating an unranked MemRef descriptor structure from a list |
| 326 | /// of individual constituent values in the following order: |
| 327 | /// - rank of the memref; |
| 328 | /// - pointer to the memref descriptor. |
| 329 | Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc, |
| 330 | const LLVMTypeConverter &converter, |
| 331 | UnrankedMemRefType type, |
| 332 | ValueRange values) { |
| 333 | Type llvmType = converter.convertType(type); |
| 334 | auto d = UnrankedMemRefDescriptor::poison(builder, loc, descriptorType: llvmType); |
| 335 | |
| 336 | d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]); |
| 337 | d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]); |
| 338 | return d; |
| 339 | } |
| 340 | |
| 341 | /// Builds IR extracting individual elements that compose an unranked memref |
| 342 | /// descriptor and returns them as `results` list. |
| 343 | void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc, |
| 344 | Value packed, |
| 345 | SmallVectorImpl<Value> &results) { |
| 346 | UnrankedMemRefDescriptor d(packed); |
| 347 | results.reserve(N: results.size() + 2); |
| 348 | results.push_back(Elt: d.rank(builder, loc)); |
| 349 | results.push_back(Elt: d.memRefDescPtr(builder, loc)); |
| 350 | } |
| 351 | |
| 352 | void UnrankedMemRefDescriptor::computeSizes( |
| 353 | OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, |
| 354 | ArrayRef<UnrankedMemRefDescriptor> values, ArrayRef<unsigned> addressSpaces, |
| 355 | SmallVectorImpl<Value> &sizes) { |
| 356 | if (values.empty()) |
| 357 | return; |
| 358 | assert(values.size() == addressSpaces.size() && |
| 359 | "must provide address space for each descriptor" ); |
| 360 | // Cache the index type. |
| 361 | Type indexType = typeConverter.getIndexType(); |
| 362 | |
| 363 | // Initialize shared constants. |
| 364 | Value one = createIndexAttrConstant(builder, loc, resultType: indexType, value: 1); |
| 365 | Value two = createIndexAttrConstant(builder, loc, resultType: indexType, value: 2); |
| 366 | Value indexSize = createIndexAttrConstant( |
| 367 | builder, loc, resultType: indexType, |
| 368 | value: llvm::divideCeil(Numerator: typeConverter.getIndexTypeBitwidth(), Denominator: 8)); |
| 369 | |
| 370 | sizes.reserve(N: sizes.size() + values.size()); |
| 371 | for (auto [desc, addressSpace] : llvm::zip(t&: values, u&: addressSpaces)) { |
| 372 | // Emit IR computing the memory necessary to store the descriptor. This |
| 373 | // assumes the descriptor to be |
| 374 | // { type*, type*, index, index[rank], index[rank] } |
| 375 | // and densely packed, so the total size is |
| 376 | // 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index). |
| 377 | // TODO: consider including the actual size (including eventual padding due |
| 378 | // to data layout) into the unranked descriptor. |
| 379 | Value pointerSize = createIndexAttrConstant( |
| 380 | builder, loc, resultType: indexType, |
| 381 | value: llvm::divideCeil(Numerator: typeConverter.getPointerBitwidth(addressSpace), Denominator: 8)); |
| 382 | Value doublePointerSize = |
| 383 | builder.create<LLVM::MulOp>(loc, indexType, two, pointerSize); |
| 384 | |
| 385 | // (1 + 2 * rank) * sizeof(index) |
| 386 | Value rank = desc.rank(builder, loc); |
| 387 | Value doubleRank = builder.create<LLVM::MulOp>(loc, indexType, two, rank); |
| 388 | Value doubleRankIncremented = |
| 389 | builder.create<LLVM::AddOp>(loc, indexType, doubleRank, one); |
| 390 | Value rankIndexSize = builder.create<LLVM::MulOp>( |
| 391 | loc, indexType, doubleRankIncremented, indexSize); |
| 392 | |
| 393 | // Total allocation size. |
| 394 | Value allocationSize = builder.create<LLVM::AddOp>( |
| 395 | loc, indexType, doublePointerSize, rankIndexSize); |
| 396 | sizes.push_back(Elt: allocationSize); |
| 397 | } |
| 398 | } |
| 399 | |
| 400 | Value UnrankedMemRefDescriptor::allocatedPtr( |
| 401 | OpBuilder &builder, Location loc, Value memRefDescPtr, |
| 402 | LLVM::LLVMPointerType elemPtrType) { |
| 403 | return builder.create<LLVM::LoadOp>(loc, elemPtrType, memRefDescPtr); |
| 404 | } |
| 405 | |
| 406 | void UnrankedMemRefDescriptor::setAllocatedPtr( |
| 407 | OpBuilder &builder, Location loc, Value memRefDescPtr, |
| 408 | LLVM::LLVMPointerType elemPtrType, Value allocatedPtr) { |
| 409 | builder.create<LLVM::StoreOp>(loc, allocatedPtr, memRefDescPtr); |
| 410 | } |
| 411 | |
| 412 | static std::pair<Value, Type> |
| 413 | castToElemPtrPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, |
| 414 | LLVM::LLVMPointerType elemPtrType) { |
| 415 | auto elemPtrPtrType = LLVM::LLVMPointerType::get(builder.getContext()); |
| 416 | return {memRefDescPtr, elemPtrPtrType}; |
| 417 | } |
| 418 | |
| 419 | Value UnrankedMemRefDescriptor::alignedPtr( |
| 420 | OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, |
| 421 | Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) { |
| 422 | auto [elementPtrPtr, elemPtrPtrType] = |
| 423 | castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType); |
| 424 | |
| 425 | Value alignedGep = |
| 426 | builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType, |
| 427 | elementPtrPtr, ArrayRef<LLVM::GEPArg>{1}); |
| 428 | return builder.create<LLVM::LoadOp>(loc, elemPtrType, alignedGep); |
| 429 | } |
| 430 | |
| 431 | void UnrankedMemRefDescriptor::setAlignedPtr( |
| 432 | OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, |
| 433 | Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value alignedPtr) { |
| 434 | auto [elementPtrPtr, elemPtrPtrType] = |
| 435 | castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType); |
| 436 | |
| 437 | Value alignedGep = |
| 438 | builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType, |
| 439 | elementPtrPtr, ArrayRef<LLVM::GEPArg>{1}); |
| 440 | builder.create<LLVM::StoreOp>(loc, alignedPtr, alignedGep); |
| 441 | } |
| 442 | |
| 443 | Value UnrankedMemRefDescriptor::offsetBasePtr( |
| 444 | OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, |
| 445 | Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) { |
| 446 | auto [elementPtrPtr, elemPtrPtrType] = |
| 447 | castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType); |
| 448 | |
| 449 | return builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType, |
| 450 | elementPtrPtr, ArrayRef<LLVM::GEPArg>{2}); |
| 451 | } |
| 452 | |
| 453 | Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc, |
| 454 | const LLVMTypeConverter &typeConverter, |
| 455 | Value memRefDescPtr, |
| 456 | LLVM::LLVMPointerType elemPtrType) { |
| 457 | Value offsetPtr = |
| 458 | offsetBasePtr(builder, loc, typeConverter, memRefDescPtr, elemPtrType: elemPtrType); |
| 459 | return builder.create<LLVM::LoadOp>(loc, typeConverter.getIndexType(), |
| 460 | offsetPtr); |
| 461 | } |
| 462 | |
| 463 | void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc, |
| 464 | const LLVMTypeConverter &typeConverter, |
| 465 | Value memRefDescPtr, |
| 466 | LLVM::LLVMPointerType elemPtrType, |
| 467 | Value offset) { |
| 468 | Value offsetPtr = |
| 469 | offsetBasePtr(builder, loc, typeConverter, memRefDescPtr, elemPtrType: elemPtrType); |
| 470 | builder.create<LLVM::StoreOp>(loc, offset, offsetPtr); |
| 471 | } |
| 472 | |
| 473 | Value UnrankedMemRefDescriptor::sizeBasePtr( |
| 474 | OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, |
| 475 | Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) { |
| 476 | Type indexTy = typeConverter.getIndexType(); |
| 477 | Type structTy = LLVM::LLVMStructType::getLiteral( |
| 478 | indexTy.getContext(), {elemPtrType, elemPtrType, indexTy, indexTy}); |
| 479 | auto resultType = LLVM::LLVMPointerType::get(builder.getContext()); |
| 480 | return builder.create<LLVM::GEPOp>(loc, resultType, structTy, memRefDescPtr, |
| 481 | ArrayRef<LLVM::GEPArg>{0, 3}); |
| 482 | } |
| 483 | |
| 484 | Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc, |
| 485 | const LLVMTypeConverter &typeConverter, |
| 486 | Value sizeBasePtr, Value index) { |
| 487 | |
| 488 | Type indexTy = typeConverter.getIndexType(); |
| 489 | auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); |
| 490 | |
| 491 | Value sizeStoreGep = |
| 492 | builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, sizeBasePtr, index); |
| 493 | return builder.create<LLVM::LoadOp>(loc, indexTy, sizeStoreGep); |
| 494 | } |
| 495 | |
| 496 | void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc, |
| 497 | const LLVMTypeConverter &typeConverter, |
| 498 | Value sizeBasePtr, Value index, |
| 499 | Value size) { |
| 500 | Type indexTy = typeConverter.getIndexType(); |
| 501 | auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); |
| 502 | |
| 503 | Value sizeStoreGep = |
| 504 | builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, sizeBasePtr, index); |
| 505 | builder.create<LLVM::StoreOp>(loc, size, sizeStoreGep); |
| 506 | } |
| 507 | |
| 508 | Value UnrankedMemRefDescriptor::strideBasePtr( |
| 509 | OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, |
| 510 | Value sizeBasePtr, Value rank) { |
| 511 | Type indexTy = typeConverter.getIndexType(); |
| 512 | auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); |
| 513 | |
| 514 | return builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, sizeBasePtr, rank); |
| 515 | } |
| 516 | |
| 517 | Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc, |
| 518 | const LLVMTypeConverter &typeConverter, |
| 519 | Value strideBasePtr, Value index, |
| 520 | Value stride) { |
| 521 | Type indexTy = typeConverter.getIndexType(); |
| 522 | auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); |
| 523 | |
| 524 | Value strideStoreGep = |
| 525 | builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, strideBasePtr, index); |
| 526 | return builder.create<LLVM::LoadOp>(loc, indexTy, strideStoreGep); |
| 527 | } |
| 528 | |
| 529 | void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc, |
| 530 | const LLVMTypeConverter &typeConverter, |
| 531 | Value strideBasePtr, Value index, |
| 532 | Value stride) { |
| 533 | Type indexTy = typeConverter.getIndexType(); |
| 534 | auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); |
| 535 | |
| 536 | Value strideStoreGep = |
| 537 | builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, strideBasePtr, index); |
| 538 | builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep); |
| 539 | } |
| 540 | |