//===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts sparse tensor types and primitives to actual compiler
// visible buffers and actual compiler IR that implements these primitives on
// the selected sparse tensor storage schemes. This pass provides an alternative
// to the SparseTensorConversion pass, eliminating the dependence on a runtime
// support library (other than for file I/O), and providing many more
// opportunities for subsequent compiler optimization of the generated code.
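//
// Roughly speaking, each sparse tensor value is replaced by a flat list of
// fields: a per-level positions buffer and coordinates buffer where the level
// type requires them, a single values buffer, and a storage specifier that
// tracks the in-use sizes of these buffers (the buffers themselves only
// reflect capacity). The exact field list depends on the level types in the
// encoding; see SparseTensorDescriptor.h for details.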
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/SparseTensorDescriptor.h"

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

#include <optional>

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Flatten the given value ranges into a single vector of values.
static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
  SmallVector<Value> result;
  for (const auto &vals : values)
    llvm::append_range(result, vals);
  return result;
}

/// Generates a load with proper `index` typing.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
  idx = genCast(builder, loc, idx, builder.getIndexType());
  return builder.create<memref::LoadOp>(loc, mem, idx);
}

/// Generates a store with proper `index` typing and proper value.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
                     Value idx) {
  idx = genCast(builder, loc, idx, builder.getIndexType());
  val = genCast(builder, loc, val,
                cast<ShapedType>(mem.getType()).getElementType());
  builder.create<memref::StoreOp>(loc, val, mem, idx);
}

/// Creates a straightforward counting for-loop.
static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
                            MutableArrayRef<Value> fields,
                            Value lower = Value()) {
  Type indexType = builder.getIndexType();
  if (!lower)
    lower = constantZero(builder, loc, indexType);
  Value one = constantOne(builder, loc, indexType);
  scf::ForOp forOp = builder.create<scf::ForOp>(loc, lower, upper, one, fields);
  for (unsigned i = 0, e = fields.size(); i < e; i++)
    fields[i] = forOp.getRegionIterArg(i);
  builder.setInsertionPointToStart(forOp.getBody());
  return forOp;
}

/// Creates a push back operation.
static void createPushback(OpBuilder &builder, Location loc,
                           MutSparseTensorDescriptor desc,
                           SparseTensorFieldKind kind, std::optional<Level> lvl,
                           Value value, Value repeat = Value()) {
  Type etp = desc.getMemRefElementType(kind, lvl);
  Value field = desc.getMemRefField(kind, lvl);
  StorageSpecifierKind specFieldKind = toSpecifierKind(kind);

  auto pushBackOp = builder.create<PushBackOp>(
      loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl), field,
      genCast(builder, loc, value, etp), repeat);

  desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer());
  desc.setSpecifierField(builder, loc, specFieldKind, lvl,
                         pushBackOp.getNewSize());
}

/// Generates code that allocates a sparse storage scheme for given rank.
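/// For example (illustrative): for a 2-D CSR-like tensor with level types
/// (dense, compressed) and `startLvl` 0, the dense level 0 multiplies the
/// running `linear` count by the level-0 size, and the compressed level 1
/// then appends that many zero entries to its positions array (which already
/// holds one initial zero), yielding the usual "number of rows + 1" length.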
static void allocSchemeForRank(OpBuilder &builder, Location loc,
                               MutSparseTensorDescriptor desc, Level startLvl) {
  const SparseTensorType stt(desc.getRankedTensorType());
  Value linear = constantIndex(builder, loc, 1);
  const Level lvlRank = stt.getLvlRank();
  for (Level lvl = startLvl; lvl < lvlRank; lvl++) {
    const auto lt = stt.getLvlType(lvl);
    if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
      // Append linear x positions, initialized to zero. Since each compressed
      // dimension initially already has a single zero entry, this maintains
      // the desired "linear + 1" length property at all times. For loose
      // compression, we multiply linear by two in order to append both the
      // lo/hi positions.
      Value posZero = constantZero(builder, loc, stt.getPosType());
      if (isLooseCompressedLT(lt)) {
        Value two = constantIndex(builder, loc, 2);
        linear = builder.create<arith::MulIOp>(loc, linear, two);
      }
      createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
                     /*value=*/posZero, /*repeat=*/linear);
      return;
    } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
      return; // nothing to do
    }
    // Keep compounding the size, but nothing needs to be initialized
    // at this level. We will eventually reach a compressed level or
    // otherwise the values array for the from-here "all-dense" case.
    assert(isDenseLT(lt));
    Value size = desc.getLvlSize(builder, loc, lvl);
    linear = builder.create<arith::MulIOp>(loc, linear, size);
  }
  // Reached values array so prepare for an insertion.
  Value valZero = constantZero(builder, loc, stt.getElementType());
  createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
                 std::nullopt, /*value=*/valZero, /*repeat=*/linear);
}

/// Creates allocation operation.
static Value createAllocation(OpBuilder &builder, Location loc,
                              MemRefType memRefType, Value sz,
                              bool enableInit) {
  Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz);
  Type elemType = memRefType.getElementType();
  if (enableInit) {
    Value fillValue = constantZero(builder, loc, elemType);
    builder.create<linalg::FillOp>(loc, fillValue, buffer);
  }
  return buffer;
}

/// Creates the dim sizes array, filling in from dynamic sizes.
static void createDimSizes(OpBuilder &builder, Location loc,
                           SparseTensorType stt, ValueRange dynSizes,
                           /*out*/ SmallVectorImpl<Value> &dimSizesValues) {
  const Dimension dimRank = stt.getDimRank();
  dimSizesValues.clear();
  dimSizesValues.reserve(dimRank);
  unsigned i = 0;
  for (const Size sz : stt.getDimShape())
    dimSizesValues.push_back(ShapedType::isDynamic(sz)
                                 ? dynSizes[i++]
                                 : constantIndex(builder, loc, sz));
}

/// Creates allocation for each field in sparse tensor type. Note that
/// for all dynamic memrefs in the sparse tensor storage layout, the
/// memory size is really the capacity of the "vector", while the actual
/// size resides in the sizes array.
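///
/// For example (illustrative), a 2-D CSR-like tensor with level types
/// (dense, compressed) is allocated as roughly four fields: a positions
/// memref and a coordinates memref for the compressed level, a values
/// memref, and the storage specifier that records the in-use sizes.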
static void createAllocFields(OpBuilder &builder, Location loc,
                              SparseTensorType stt, bool enableInit,
                              Value sizeHint,
                              SmallVectorImpl<Value> &lvlSizesValues,
                              /*out*/ SmallVectorImpl<Value> &fields) {
  Level lvlRank = stt.getLvlRank();
  // Set up some heuristic sizes. We try to set the initial
  // size based on available information. Otherwise we just
  // initialize a few elements to start the reallocation chain.
  // TODO: refine this
  Value posHeuristic, crdHeuristic, valHeuristic;
  if (stt.isAllDense()) {
    valHeuristic = lvlSizesValues[0];
    for (Level lvl = 1; lvl < lvlRank; lvl++)
      valHeuristic =
          builder.create<arith::MulIOp>(loc, valHeuristic, lvlSizesValues[lvl]);
  } else if (sizeHint) {
    if (stt.getAoSCOOStart() == 0) {
      posHeuristic = constantIndex(builder, loc, 2);
      crdHeuristic = builder.create<arith::MulIOp>(
          loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS
    } else if (lvlRank == 2 && stt.isDenseLvl(0) && stt.isCompressedLvl(1)) {
      posHeuristic = builder.create<arith::AddIOp>(
          loc, sizeHint, constantIndex(builder, loc, 1));
      crdHeuristic = sizeHint;
    } else {
      posHeuristic = crdHeuristic = constantIndex(builder, loc, 16);
    }
    valHeuristic = sizeHint;
  } else {
    posHeuristic = crdHeuristic = valHeuristic =
        constantIndex(builder, loc, 16);
  }
  // Initializes all fields. An initial storage specifier and allocated
  // positions/coordinates/values memrefs (with heuristic capacity).
  foreachFieldAndTypeInSparseTensor(
      stt,
      [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
       enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
                   Level /*lvl*/, LevelType /*lt*/) -> bool {
        assert(fields.size() == fIdx);
        Value field;
        switch (fKind) {
        case SparseTensorFieldKind::StorageSpec:
          field = SparseTensorSpecifier::getInitValue(builder, loc, stt);
          break;
        case SparseTensorFieldKind::PosMemRef:
          field = createAllocation(builder, loc, cast<MemRefType>(fType),
                                   posHeuristic, enableInit);
          break;
        case SparseTensorFieldKind::CrdMemRef:
          field = createAllocation(builder, loc, cast<MemRefType>(fType),
                                   crdHeuristic, enableInit);
          break;
        case SparseTensorFieldKind::ValMemRef:
          field = createAllocation(builder, loc, cast<MemRefType>(fType),
                                   valHeuristic, enableInit);
          break;
        }
        assert(field);
        fields.push_back(field);
        // Returns true to continue the iteration.
        return true;
      });
  // Initialize the storage scheme to an empty tensor. Sets the lvlSizes
  // and gives all position fields an initial zero entry, so that it is
  // easier to maintain the "linear + 1" length property.
  MutSparseTensorDescriptor desc(stt, fields);
  Value posZero = constantZero(builder, loc, stt.getPosType());
  for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
    desc.setLvlSize(builder, loc, lvl, lvlSizesValues[lvl]);
    const auto lt = stt.getLvlType(lvl);
    if (isCompressedLT(lt) || isLooseCompressedLT(lt))
      createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
                     /*value=*/posZero);
  }
  allocSchemeForRank(builder, loc, desc, /*rank=*/0);
}

/// Helper method that generates block specific to compressed case:
///
///   // given: parentPos = posCursor[lvl-1]
///   pstart = desc.positions[lvl][parentPos]
///   pstop = desc.positions[lvl][parentPos+1]
///   plast = pstop - 1
///   msz = desc.coordinates[lvl].size()
///   if (pstart < pstop) {
///     isPresent = (desc.coordinates[lvl][plast] == lvlCoords[lvl])
///   } else { // first insertion
///     isPresent = false
///     desc.positions[lvl][parentPos] = msz
///   }
///   if (isPresent) { // coordinate is already present
///     pnext = plast
///   } else {
///     desc.coordinates[lvl].push_back(lvlCoords[lvl])
///     desc.positions[lvl][parentPos+1] = msz+1
///     pnext = msz
///     <prepare level lvl+1>
///   }
///   posCursor[lvl] = pnext
static Value genCompressed(OpBuilder &builder, Location loc,
                           MutSparseTensorDescriptor desc, ValueRange lvlCoords,
                           Value /*unused*/, Value parentPos, Level lvl) {
  const SparseTensorType stt(desc.getRankedTensorType());
  const Level lvlRank = stt.getLvlRank();
  assert(lvl < lvlRank && "Level is out of bounds");
  assert(lvlCoords.size() == static_cast<size_t>(lvlRank) &&
         "Level-rank mismatch");
  SmallVector<Type> types;
  Type indexType = builder.getIndexType();
  Type boolType = builder.getIntegerType(1);
  unsigned crdFidx;
  unsigned crdStride;
  std::tie(crdFidx, crdStride) = desc.getCrdMemRefIndexAndStride(lvl);
  const Value one = constantIndex(builder, loc, 1);
  const Value pp1 = builder.create<arith::AddIOp>(loc, parentPos, one);
  const Value positionsAtLvl = desc.getPosMemRef(lvl);
  const Value pstart = genLoad(builder, loc, positionsAtLvl, parentPos);
  const Value pstop = genLoad(builder, loc, positionsAtLvl, pp1);
  const Value crdMsz = desc.getCrdMemSize(builder, loc, lvl);
  const Value crdStrideC =
      crdStride > 1 ? constantIndex(builder, loc, crdStride) : Value();
  const Value msz =
      crdStrideC ? builder.create<arith::DivUIOp>(loc, crdMsz, crdStrideC)
                 : crdMsz;
  const Value plast = builder.create<arith::SubIOp>(
      loc, genCast(builder, loc, pstop, indexType), one);
  // Conditional expression.
  Value lt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                           pstart, pstop);
  types.push_back(boolType);
  scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt, /*else*/ true);
  types.pop_back();
  builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
  Value crd =
      genLoad(builder, loc, desc.getMemRefField(crdFidx),
              crdStrideC ? builder.create<arith::MulIOp>(loc, plast, crdStrideC)
                         : plast);
  Value eq = builder.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, genCast(builder, loc, crd, indexType),
      lvlCoords[lvl]);
  builder.create<scf::YieldOp>(loc, eq);
  builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
  if (lvl > 0)
    genStore(builder, loc, msz, positionsAtLvl, parentPos);
  builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false));
  builder.setInsertionPointAfter(ifOp1);
  // If present construct. Note that for a non-unique dimension level, we
  // simply set the condition to false and rely on CSE/DCE to clean up the IR.
  //
  // TODO: generate less temporary IR?
  //
  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
    types.push_back(desc.getField(i).getType());
  types.push_back(indexType);
  const Value p = stt.isUniqueLvl(lvl) ? ifOp1.getResult(0)
                                       : constantI1(builder, loc, false);
  scf::IfOp ifOp2 = builder.create<scf::IfOp>(loc, types, p, /*else*/ true);
  // If present (fields unaffected, update pnext to plast).
  builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());

  // FIXME: This does not look like a clean way, but it is probably the most
  // efficient way.
  desc.getFields().push_back(plast);
  builder.create<scf::YieldOp>(loc, desc.getFields());
  desc.getFields().pop_back();

  // If !present (changes fields, update pnext).
  builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
  Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
  genStore(builder, loc, mszp1, positionsAtLvl, pp1);
  createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, lvl,
                 /*value=*/lvlCoords[lvl]);
  // Prepare the next level "as needed".
  if ((lvl + 1) < lvlRank)
    allocSchemeForRank(builder, loc, desc, lvl + 1);

  desc.getFields().push_back(msz);
  builder.create<scf::YieldOp>(loc, desc.getFields());
  desc.getFields().pop_back();

  // Update fields and return next pos.
  builder.setInsertionPointAfter(ifOp2);
  unsigned o = 0;
  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
    desc.setField(i, ifOp2.getResult(o++));
  return ifOp2.getResult(o);
}

/// Generates insertion finalization code.
static void genEndInsert(OpBuilder &builder, Location loc,
                         SparseTensorDescriptor desc) {
  const SparseTensorType stt(desc.getRankedTensorType());
  const Level lvlRank = stt.getLvlRank();
  for (Level lvl = 0; lvl < lvlRank; lvl++) {
    const auto lt = stt.getLvlType(lvl);
    if (isCompressedLT(lt)) {
      // Compressed dimensions need a position cleanup for all entries
      // that were not visited during the insertion pass.
      //
      // TODO: avoid cleanup and keep compressed scheme consistent at all
      //       times?
      //
      if (lvl > 0) {
        Type posType = stt.getPosType();
        Value posMemRef = desc.getPosMemRef(lvl);
        Value hi = desc.getPosMemSize(builder, loc, lvl);
        Value zero = constantIndex(builder, loc, 0);
        Value one = constantIndex(builder, loc, 1);
        // Vector of only one, but needed by createFor's prototype.
        SmallVector<Value, 1> inits{genLoad(builder, loc, posMemRef, zero)};
        scf::ForOp loop = createFor(builder, loc, hi, inits, one);
        Value i = loop.getInductionVar();
        Value oldv = loop.getRegionIterArg(0);
        Value newv = genLoad(builder, loc, posMemRef, i);
        Value posZero = constantZero(builder, loc, posType);
        Value cond = builder.create<arith::CmpIOp>(
            loc, arith::CmpIPredicate::eq, newv, posZero);
        scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(posType),
                                                   cond, /*else*/ true);
        builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
        genStore(builder, loc, oldv, posMemRef, i);
        builder.create<scf::YieldOp>(loc, oldv);
        builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
        builder.create<scf::YieldOp>(loc, newv);
        builder.setInsertionPointAfter(ifOp);
        builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
        builder.setInsertionPointAfter(loop);
      }
    } else {
      assert(isDenseLT(lt) || isLooseCompressedLT(lt) || isSingletonLT(lt) ||
             isNOutOfMLT(lt));
    }
  }
}

/// Generates a subview into the sizes.
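///
/// For example (illustrative), a values buffer that was allocated with a
/// capacity of 16 but only holds `sz` entries is exposed to clients as
///   memref.subview %mem[0] [%sz] [1]
/// so that consumers observe the in-use size rather than the capacity.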
static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
                            Value sz) {
  auto memTp = llvm::cast<MemRefType>(mem.getType());
  // For higher-dimensional memrefs, we assume that the innermost
  // dimension is always of the right size.
  // TODO: generate complex truncating view here too?
  if (memTp.getRank() > 1)
    return mem;
  // Truncate linear memrefs to given size.
  return builder
      .create<memref::SubViewOp>(
          loc, MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
          mem, ValueRange{}, ValueRange{sz}, ValueRange{},
          ArrayRef<int64_t>{0},                    // static offset
          ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
          ArrayRef<int64_t>{1})                    // static stride
      .getResult();
}

/// Creates the reassociation array.
static SmallVector<ReassociationIndices>
getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls) {
  SmallVector<ReassociationIndices> ret(batchLvls + 1, {});
  // Create reassociation in the form:
  //   {0}, {1}, ..., {batchLvl - 1}, {batchLvl, ..., rank}
  for (unsigned i = 0; i < batchLvls; i++)
    ret[i].push_back(i);

  for (int i = batchLvls, e = srcTp.getRank(); i < e; i++)
    ret.back().push_back(i);

  return ret;
}

//===----------------------------------------------------------------------===//
// Codegen rules.
//===----------------------------------------------------------------------===//

namespace {

/// Helper class to help lowering sparse_tensor.insert operation.
class SparseInsertGenerator
    : public FuncCallOrInlineGenerator<SparseInsertGenerator> {
public:
  SparseInsertGenerator(TensorType rtp, TypeRange retTypes, ValueRange params,
                        bool genCall)
      : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp){};

  /// Generates code along an insertion path without the need for a "cursor".
  /// This current insertion strategy comes at the expense of some testing
  /// overhead for each insertion. The strategy will be optimized later for
  /// common insertion patterns. The current insertion strategy also assumes
  /// insertions occur in "a reasonable order" that enables building the
  /// storage scheme in an appending/inserting kind of fashion (i.e. no
  /// in-between insertions that need data movement). The implementation
  /// relies on CSE/DCE to clean up all bookkeeping that is not needed.
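  ///
  /// For example (illustrative), for a CSR-like tensor this means insertions
  /// are expected to arrive in lexicographic coordinate order, so that every
  /// insertion only appends to the coordinates/values arrays.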
  ///
  /// TODO: better unord/not-unique; also generalize, optimize, specialize!
  SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
                                       OpBuilder &builder, Location loc) {
    const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
    const Level lvlRank = stt.getLvlRank();
    // Extract fields and coordinates from args.
    SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
    MutSparseTensorDescriptor desc(stt, fields);
    const SmallVector<Value> coords =
        llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
    Value value = args.back();
    Value parentPos = constantZero(builder, loc, builder.getIndexType());
    // Generate code for every level.
    for (Level lvl = 0; lvl < lvlRank; lvl++) {
      const auto lt = stt.getLvlType(lvl);
      if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
        // Create:
        //   if (!present) {
        //     coordinates[lvl].push_back(coords[lvl])
        //     <update positions and prepare level lvl + 1>
        //   }
        //   positions[lvl] = coordinates.size() - 1
        //   <insert @ positions[lvl] at next level lvl + 1>
        if (isLooseCompressedLT(lt)) {
          Value two = constantIndex(builder, loc, 2);
          parentPos = builder.create<arith::MulIOp>(loc, parentPos, two);
        }
        parentPos =
            genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
      } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
        // Create:
        //   coordinates[lvl].push_back(coords[lvl])
        //   positions[lvl] = positions[lvl-1]
        //   <insert @ positions[lvl] at next level lvl + 1>
        createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef,
                       lvl, /*value=*/coords[lvl]);
      } else {
        assert(isDenseLT(lt));
        // Construct the new position as:
        //   positions[lvl] = size * positions[lvl-1] + coords[lvl]
        //   <insert @ positions[lvl] at next level lvl + 1>
        Value size = desc.getLvlSize(builder, loc, lvl);
        Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
        parentPos = builder.create<arith::AddIOp>(loc, mult, coords[lvl]);
      }
    }
    // Reached the actual value append/insert.
    if (!stt.isDenseLvl(lvlRank - 1))
      createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
                     std::nullopt, value);
    else
      genStore(builder, loc, value, desc.getValMemRef(), parentPos);
    return fields;
  }

  std::string getMangledFuncName() {
    // The mangled name of the function has this format:
    //   <namePrefix>_<LT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
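    // For example (illustrative only), an 8x8 CSR-like tensor of f64 with
    // default position/coordinate widths might mangle to something like
    // "_insert_dense_compressed_8_8_f64_0_0".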
    constexpr const char kInsertFuncNamePrefix[] = "_insert_";
    const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
    SmallString<32> nameBuffer;
    llvm::raw_svector_ostream nameOstream(nameBuffer);
    nameOstream << kInsertFuncNamePrefix;
    const Level lvlRank = stt.getLvlRank();
    for (Level l = 0; l < lvlRank; l++) {
      std::string lvlType = toMLIRString(stt.getLvlType(l));
      // Replace/remove punctuation in level properties.
      std::replace_if(
          lvlType.begin(), lvlType.end(),
          [](char c) { return c == '(' || c == ','; }, '_');
      llvm::erase_if(lvlType, [](char c) { return c == ')' || c == ' '; });
      nameOstream << lvlType << "_";
    }
    // Static dim sizes are used in the generated code while dynamic sizes are
    // loaded from the dimSizes buffer. This is the reason for adding the shape
    // to the function name.
    for (const auto sz : stt.getDimShape())
      nameOstream << sz << "_";
    // Permutation information is also used in generating insertion.
    if (!stt.isIdentity())
      nameOstream << stt.getDimToLvl() << "_";
    nameOstream << stt.getElementType() << "_";
    nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
    return nameOstream.str().str();
  }

private:
  TensorType rtp;
};

/// Sparse tensor storage conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Create a return with the flattened value extracted from sparse tensors.
    rewriter.replaceOpWithNewOp<func::ReturnOp>(
        op, flattenValues(adaptor.getOperands()));
    return success();
  }
};

/// Sparse tensor storage conversion rule for calls.
class SparseCallConverter : public OpConversionPattern<func::CallOp> {
public:
  // The default CallOp converter cannot handle 1:N type conversion.
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::CallOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // In case of:
    //   sparse_tensor, f, sparse_tensor = call @foo(...)
    // ==>
    //   memref..., f, memref = call @foo(...) replace with
    //   cast(memref...)->sparse_tensor, f, cast(memref...)->sparse_tensor
    SmallVector<Type> finalRetTy;
    if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
      return failure();

    // (1) Generates new call with flattened return value.
    auto newCall = rewriter.create<func::CallOp>(
        loc, op.getCallee(), finalRetTy, flattenValues(adaptor.getOperands()));
    // (2) Gather sparse tensor returns.
    SmallVector<SmallVector<Value>> packedResultVals;
    // Tracks the offset of the current return value (of the original call)
    // relative to the new call (after sparse tensor flattening).
    unsigned retOffset = 0;
    // Temporary buffer to hold the flattened list of types for
    // a sparse tensor.
    SmallVector<Type> sparseFlat;
    for (auto ret : op.getResults()) {
      assert(retOffset < newCall.getNumResults());
      auto retType = ret.getType();
      if (failed(typeConverter->convertType(retType, sparseFlat)))
        llvm_unreachable("Failed to convert type in sparse tensor codegen");

      // Converted types cannot be empty when the type conversion succeeds.
      assert(!sparseFlat.empty());
      if (sparseFlat.size() > 1) {
        auto flatSize = sparseFlat.size();
        packedResultVals.emplace_back();
        llvm::append_range(packedResultVals.back(),
                           newCall.getResults().slice(retOffset, flatSize));
        retOffset += flatSize;
      } else {
        // If this is a 1:1 conversion, no need for casting.
        packedResultVals.emplace_back();
        packedResultVals.back().push_back(newCall.getResult(retOffset));
        retOffset++;
      }
      sparseFlat.clear();
    }

    assert(packedResultVals.size() == op.getNumResults());
    rewriter.replaceOpWithMultiple(op, std::move(packedResultVals));
    return success();
  }
};

/// Sparse codegen rule for level accesses.
class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LvlOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    std::optional<int64_t> lvl = op.getConstantLvlIndex();
    RankedTensorType srcType = op.getSource().getType();
    if (!lvl || !getSparseTensorEncoding(srcType))
      return failure();

    auto desc = getDescriptorFromTensorTuple(adaptor.getSource(), srcType);
    auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);

    rewriter.replaceOp(op, sz);
    return success();
  }
};

// TODO: use a new SortCOO operation here instead of reusing convert op.
struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReorderCOOOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *ctx = op.getContext();

    SparseTensorType srcStt = getSparseTensorType(op.getInputCoo());
    SparseTensorType dstStt = getSparseTensorType(op.getResultCoo());

    // Should have been verified.
    assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
           dstStt.isCOOType() && srcStt.isCOOType());
    assert(dstStt.hasSameDimToLvl(srcStt));

    // We don't need a mutable descriptor here as we perform sorting in-place.
    auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo(),
                                             op.getInputCoo().getType());
    auto nnz = desc.getValMemSize(rewriter, op.getLoc());
    auto crd = desc.getAOSMemRef();
    auto val = desc.getValMemRef();

    // Otherwise we need another data shuffle and a non-identity map.
    assert(dstStt.hasSameDimToLvl(srcStt));
    (void)dstStt; // to silence warning when assertion is disabled

    auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);

    rewriter.create<SortOp>(loc, nnz, crd, ValueRange{val}, id,
                            rewriter.getIndexAttr(0), op.getAlgorithm());

    // Since we do in-place sorting, the destination tensor will have the same
    // set of memrefs as the source tensor.
    rewriter.replaceOpWithMultiple(op, {adaptor.getInputCoo()});
    return success();
  }
};

template <typename Op, StorageSpecifierKind kind>
class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
public:
  using OpConversionPattern<Op>::OpConversionPattern;
  using typename OpConversionPattern<Op>::OneToNOpAdaptor;

  LogicalResult
  matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply lowers to a specifier.get <field> operation.
    auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
                                             op.getSlice().getType());
    auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
                                    op.getDim().getZExtValue());

    rewriter.replaceOp(op, v);
    return success();
  }
};

/// Sparse codegen rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
    return success();
  }
};

class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReinterpretMapOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply fold the operation.
    rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
    return success();
  }
};

/// Sparse codegen rule for the alloc operator.
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  SparseTensorAllocConverter(const TypeConverter &typeConverter,
                             MLIRContext *context, bool enableInit)
      : OpConversionPattern(typeConverter, context),
        enableBufferInitialization(enableInit) {}

  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto resType = getSparseTensorType(op);
    if (!resType.hasEncoding())
      return failure();

    Location loc = op.getLoc();
    // Deal with copy.
    if (op.getCopy()) {
      auto desc = getDescriptorFromTensorTuple(
          adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
      SmallVector<Value> fields;
      fields.reserve(desc.getNumFields());
      // Memcpy on memref fields.
      for (auto field : desc.getMemRefFields()) {
        auto memrefTp = cast<MemRefType>(field.getType());
        auto size = rewriter.create<memref::DimOp>(loc, field, 0);
        auto copied =
            rewriter.create<memref::AllocOp>(loc, memrefTp, ValueRange{size});
        rewriter.create<memref::CopyOp>(loc, field, copied);
        fields.push_back(copied);
      }
      // Reuses specifier.
      fields.push_back(desc.getSpecifier());
      assert(fields.size() == desc.getNumFields());
      rewriter.replaceOpWithMultiple(op, {fields});
      return success();
    }

    if (!resType.isIdentity()) {
      return rewriter.notifyMatchFailure(
          op, "try run --sparse-reinterpret-map before codegen");
    }
    // Level size equals dimension size since the lvl2dim map is an identity
    // map.
    SmallVector<Value> lvlSizesValues;
    createDimSizes(rewriter, loc, resType,
                   flattenValues(adaptor.getDynamicSizes()),
                   /*dimSizesValues=*/lvlSizesValues);

    // Construct allocation for each field.
    Value sizeHint = op.getSizeHint();
    SmallVector<Value> fields;
    createAllocFields(rewriter, loc, resType, enableBufferInitialization,
                      sizeHint, lvlSizesValues, fields);

    // Replace operation with resulting memrefs.
    rewriter.replaceOpWithMultiple(op, {fields});
    return success();
  }

private:
  bool enableBufferInitialization;
};
| 798 | /// Sparse codegen rule for the empty tensor operator. |
| 799 | class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> { |
| 800 | public: |
| 801 | using OpConversionPattern::OpConversionPattern; |
| 802 | SparseTensorEmptyConverter(const TypeConverter &typeConverter, |
| 803 | MLIRContext *context, bool enableInit) |
| 804 | : OpConversionPattern(typeConverter, context), |
| 805 | enableBufferInitialization(enableInit) {} |
| 806 | |
| 807 | LogicalResult |
| 808 | matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor, |
| 809 | ConversionPatternRewriter &rewriter) const override { |
| 810 | const auto resType = getSparseTensorType(op); |
| 811 | if (!resType.hasEncoding()) |
| 812 | return failure(); |
| 813 | |
| 814 | if (!resType.isIdentity()) { |
| 815 | return rewriter.notifyMatchFailure( |
| 816 | op, "try run --sparse-reinterpret-map before codegen" ); |
| 817 | } |
| 818 | |
| 819 | Location loc = op.getLoc(); |
| 820 | // Level size equals to dimension size since lvl2dim map is an identity map. |
| 821 | SmallVector<Value> lvlSizesValues; |
| 822 | createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(), |
| 823 | /*dimSizesValues=*/lvlSizesValues); |
| 824 | // Construct allocation for each field. |
| 825 | Value sizeHint; // none |
| 826 | SmallVector<Value> fields; |
| 827 | createAllocFields(rewriter, loc, resType, enableBufferInitialization, |
| 828 | sizeHint, lvlSizesValues, fields); |
| 829 | |
| 830 | // Replace operation with resulting memrefs. |
| 831 | rewriter.replaceOpWithMultiple(op, {fields}); |
| 832 | return success(); |
| 833 | } |
| 834 | |
| 835 | private: |
| 836 | bool enableBufferInitialization; |
| 837 | }; |
| 838 | |
| 839 | /// Sparse codegen rule for the dealloc operator. |
| 840 | class SparseTensorDeallocConverter |
| 841 | : public OpConversionPattern<bufferization::DeallocTensorOp> { |
| 842 | public: |
| 843 | using OpConversionPattern::OpConversionPattern; |
| 844 | SparseTensorDeallocConverter(const TypeConverter &typeConverter, |
| 845 | MLIRContext *context, bool createDeallocs) |
| 846 | : OpConversionPattern(typeConverter, context), |
| 847 | createDeallocs(createDeallocs) {} |
| 848 | |
| 849 | LogicalResult |
| 850 | matchAndRewrite(bufferization::DeallocTensorOp op, OneToNOpAdaptor adaptor, |
| 851 | ConversionPatternRewriter &rewriter) const override { |
| 852 | auto enc = getSparseTensorEncoding(op.getTensor().getType()); |
| 853 | if (!enc) |
| 854 | return failure(); |
| 855 | |
| 856 | // If user requests not to deallocate sparse tensors, simply erase the |
| 857 | // operation. |
| 858 | if (createDeallocs) { |
| 859 | // Replace the sparse tensor deallocation with field deallocations. |
| 860 | Location loc = op.getLoc(); |
| 861 | auto desc = getDescriptorFromTensorTuple( |
| 862 | adaptor.getTensor(), |
| 863 | cast<RankedTensorType>(op.getTensor().getType())); |
| 864 | for (auto input : desc.getMemRefFields()) |
| 865 | // Deallocate every buffer used to store the sparse tensor handler. |
| 866 | rewriter.create<memref::DeallocOp>(loc, input); |
| 867 | } |
| 868 | rewriter.eraseOp(op: op); |
| 869 | return success(); |
| 870 | } |
| 871 | |
| 872 | private: |
| 873 | const bool createDeallocs; |
| 874 | }; |

/// Sparse codegen rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Prepare descriptor.
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    // Generate optional insertion finalization code.
    if (op.getHasInserts())
      genEndInsert(rewriter, op.getLoc(), desc);
    // Replace operation with resulting memrefs.
    rewriter.replaceOpWithMultiple(op, {desc.getFields()});
    return success();
  }
};

/// Sparse codegen rule for the expand op.
class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!getSparseTensorEncoding(op.getTensor().getType()))
      return failure();
    Location loc = op->getLoc();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    const auto srcType = getSparseTensorType(op.getTensor());
    Type eltType = srcType.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());

    // Determine the size for access expansion (always the innermost stored
    // level size).
    const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
    // Generate a memref for `sz` elements of type `t`.
    const auto genAlloc = [&](Type t) {
      const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
      return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
    };
    // Allocate temporary buffers for values/filled-switch and added.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(eltType);
    Value filled = genAlloc(boolType);
    Value added = genAlloc(idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial coordinate.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, added, zero});
    return success();
  }
};

/// Sparse codegen rule for the compress operator.
class SparseCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    SmallVector<Value> fields;
    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
                                                op.getTensor().getType());
    Value values = llvm::getSingleElement(adaptor.getValues());
    Value filled = llvm::getSingleElement(adaptor.getFilled());
    Value added = llvm::getSingleElement(adaptor.getAdded());
    Value count = llvm::getSingleElement(adaptor.getCount());
    const SparseTensorType dstType(desc.getRankedTensorType());
    Type eltType = dstType.getElementType();

    // If the innermost level is ordered, we need to sort the coordinates
    // in the "added" array prior to applying the compression.
    if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
      rewriter.create<SortOp>(
          loc, count, added, ValueRange{}, rewriter.getMultiDimIdentityMap(1),
          rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
    // While performing the insertions, we also need to reset the elements
    // of the values/filled-switch by only iterating over the set elements,
    // to ensure that the runtime complexity remains proportional to the
    // sparsity of the expanded access pattern.
    //
    // Generate
    //   out_memrefs = for (i = 0; i < count; i++)(in_memrefs) {
    //     crd = added[i];
    //     value = values[crd];
    //     insert({lvlCoords, crd}, value);
    //     new_memrefs = insert(in_memrefs, {lvlCoords, crd}, value);
    //     values[crd] = 0;
    //     filled[crd] = false;
    //     yield new_memrefs
    //   }
    scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields());
    Value i = loop.getInductionVar();

    Value crd = genLoad(rewriter, loc, added, i);
    Value value = genLoad(rewriter, loc, values, crd);
    SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end());
    SmallVector<Type> flatSpTensorTps = llvm::to_vector(
        llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
    SmallVector<Value> flatLvlCoords = flattenValues(adaptor.getLvlCoords());
    params.append(flatLvlCoords.begin(), flatLvlCoords.end());
    params.push_back(crd);
    params.push_back(value);
    SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
                                    params, /*genCall=*/true);
    SmallVector<Value> insertRet = insertGen.genCallOrInline(rewriter, loc);
    genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values, crd);
    genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, crd);
    rewriter.create<scf::YieldOp>(loc, insertRet);

    rewriter.setInsertionPointAfter(loop);
    // Deallocate the buffers on exit of the full loop nest.
    Operation *parent = getTop(op);
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, values);
    rewriter.create<memref::DeallocOp>(loc, filled);
    rewriter.create<memref::DeallocOp>(loc, added);
    // Replace operation with resulting memrefs.
    rewriter.replaceOpWithMultiple(op, {loop->getResults()});
    return success();
  }
};
| 1018 | /// Sparse codegen rule for the insert operator. |
| 1019 | class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> { |
| 1020 | public: |
| 1021 | using OpConversionPattern::OpConversionPattern; |
| 1022 | LogicalResult |
| 1023 | matchAndRewrite(tensor::InsertOp op, OneToNOpAdaptor adaptor, |
| 1024 | ConversionPatternRewriter &rewriter) const override { |
| 1025 | auto stt = getSparseTensorType(op.getDest()); |
| 1026 | if (!stt.hasEncoding()) |
| 1027 | return failure(); |
| 1028 | assert(stt.isIdentity() && "Run reinterpret-map before conversion." ); |
| 1029 | |
| 1030 | Location loc = op.getLoc(); |
| 1031 | auto desc = |
| 1032 | getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType()); |
| 1033 | TypeRange flatSpTensorTps = desc.getFields().getTypes(); |
| 1034 | SmallVector<Value> params = llvm::to_vector(desc.getFields()); |
| 1035 | SmallVector<Value> flatIndices = flattenValues(adaptor.getIndices()); |
| 1036 | params.append(in_start: flatIndices.begin(), in_end: flatIndices.end()); |
| 1037 | params.push_back(Elt: llvm::getSingleElement(adaptor.getScalar())); |
| 1038 | SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps, |
| 1039 | params, /*genCall=*/true); |
| 1040 | SmallVector<Value> ret = insertGen.genCallOrInline(builder&: rewriter, loc); |
| 1041 | // Replace operation with resulting memrefs. |
| 1042 | rewriter.replaceOpWithMultiple(op, {ret}); |
| 1043 | return success(); |
| 1044 | } |
| 1045 | }; |

/// Sparse codegen rule for position accesses.
class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
public:
  using OpAdaptor = typename ToPositionsOp::Adaptor;
  using OpConversionPattern<ToPositionsOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the requested position access with corresponding field.
    // The view is restricted to the actual size to ensure clients
    // of this operation truly observe size, not capacity!
    Location loc = op.getLoc();
    Level lvl = op.getLevel();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    auto mem = desc.getPosMemRef(lvl);
    auto size = desc.getPosMemSize(rewriter, loc, lvl);
    rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
    return success();
  }
};

/// Sparse codegen rule for accessing the coordinates arrays.
class SparseToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {
public:
  using OpAdaptor = typename ToCoordinatesOp::Adaptor;
  using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the requested coordinates access with corresponding field.
    // The view is restricted to the actual size to ensure clients
    // of this operation truly observe size, not capacity!
    Location loc = op.getLoc();
    Level lvl = op.getLevel();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
    if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) {
      auto size = desc.getCrdMemSize(rewriter, loc, lvl);
      mem = genSliceToSize(rewriter, loc, mem, size);
    }
    rewriter.replaceOp(op, mem);
    return success();
  }
};

/// Sparse codegen rule for accessing the linear coordinates buffer.
class SparseToCoordinatesBufferConverter
    : public OpConversionPattern<ToCoordinatesBufferOp> {
public:
  using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor;
  using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the requested coordinates access with corresponding field.
    // The view is restricted to the actual size to ensure clients
    // of this operation truly observe size, not capacity!
    Location loc = op.getLoc();
    Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    auto mem = desc.getAOSMemRef();
    auto size = desc.getCrdMemSize(rewriter, loc, lvl);
    rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
    return success();
  }
};

/// Sparse codegen rule for value accesses.
class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpAdaptor = typename ToValuesOp::Adaptor;
  using OpConversionPattern<ToValuesOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the requested values access with corresponding field.
    // The view is restricted to the actual size to ensure clients
    // of this operation truly observe size, not capacity!
    Location loc = op.getLoc();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                             op.getTensor().getType());
    auto mem = desc.getValMemRef();
    auto size = desc.getValMemSize(rewriter, loc);
    rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
    return success();
  }
};
| 1138 | |
| 1139 | /// Sparse codegen rule for the convert operator. |
| 1140 | class SparseConvertConverter : public OpConversionPattern<ConvertOp> { |
| 1141 | public: |
| 1142 | using OpConversionPattern::OpConversionPattern; |
| 1143 | LogicalResult |
| 1144 | matchAndRewrite(ConvertOp op, OneToNOpAdaptor adaptor, |
| 1145 | ConversionPatternRewriter &rewriter) const override { |
| 1146 | SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType()); |
| 1147 | SparseTensorEncodingAttr encSrc = |
| 1148 | getSparseTensorEncoding(op.getSource().getType()); |
| 1149 | // The output tensor can not be a slice and those cases should have been |
| 1150 | // rejected by ConvertOp::verify() already. |
| 1151 | assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slices." ); |
| 1152 | // Different encoding (except for different bitwidth) should be handled by |
| 1153 | // rewriting. |
| 1154 | // We need further rewrites if the input tensor is a slice too. |
| 1155 | if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() || |
| 1156 | encSrc.isSlice()) { |
| 1157 | return failure(); |
| 1158 | } |
| 1159 | |
| 1160 | Type retElemTp = op.getResult().getType().getElementType(); |
| 1161 | Type srcElemTp = op.getSource().getType().getElementType(); |
| 1162 | // Fold the trivial cases. |
| 1163 | if (retElemTp == srcElemTp && encDst == encSrc) { |
| 1164 | rewriter.replaceOpWithMultiple(op, {adaptor.getSource()}); |
| 1165 | return success(); |
| 1166 | } |
| 1167 | // |
| 1168 | // Do element-wise type conversion without using InsertOp. |
| 1169 | // |
| 1170 | // for each memref in srcTensor: |
| 1171 | // dst = memref.alloc |
| 1172 | // if srcMemRefType != dstMemRefType: |
| 1173 | // for every dst[i] = cast(src[i]) |
| 1174 | // else: |
| 1175 | // dst = memref.copy(src) |
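    //
    // For example (illustrative only): converting a CSR tensor with 64-bit
    // positions/coordinates into an otherwise identical CSR tensor with
    // 32-bit positions/coordinates takes the element-wise cast path for the
    // position and coordinate memrefs, while the values memref (same element
    // type) is copied verbatim.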
| 1176 | Location loc = op.getLoc(); |
| 1177 | auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(), |
| 1178 | op.getSource().getType()); |
| 1179 | SmallVector<Value> fields; |
| 1180 | foreachFieldAndTypeInSparseTensor( |
| 1181 | SparseTensorType(cast<RankedTensorType>(op.getResult().getType())), |
| 1182 | [&rewriter, &fields, srcDesc, |
| 1183 | loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl, |
| 1184 | LevelType /*lt*/) -> bool { |
| 1185 | // Simply reuses the storage specifier as it is an SSA value. |
| 1186 | if (fKind == SparseTensorFieldKind::StorageSpec) { |
            fields.push_back(srcDesc.getSpecifier());
          } else {
            // Allocates a new memref.
            Value srcMem = srcDesc.getMemRefField(fIdx);
            // TODO: We could instead use the actual memSize in the specifier,
            // but that would require a subViewOp to avoid overflow when
            // copying the values.
            Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
| 1195 | auto dstMem = rewriter.create<memref::AllocOp>( |
| 1196 | loc, cast<MemRefType>(fTp), sz); |
| 1197 | if (fTp != srcMem.getType()) { |
| 1198 | // Converts elements type. |
              scf::buildLoopNest(
                  rewriter, loc, constantIndex(rewriter, loc, 0), sz,
                  constantIndex(rewriter, loc, 1),
                  [srcMem, &dstMem](OpBuilder &builder, Location loc,
                                    ValueRange ivs) {
| 1204 | Value v = builder.create<memref::LoadOp>(loc, srcMem, ivs); |
| 1205 | Value casted = genCast(builder, loc, v, |
| 1206 | dstMem.getType().getElementType()); |
| 1207 | builder.create<memref::StoreOp>(loc, casted, dstMem, ivs); |
| 1208 | }); |
| 1209 | } else { |
| 1210 | // TODO: We can even reuse the same memref for the new tensor, |
| 1211 | // but that requires a `ref-counting` based memory management |
| 1212 | // for shared memrefs between multiple sparse tensors. |
| 1213 | rewriter.create<memref::CopyOp>(loc, srcMem, dstMem); |
| 1214 | } |
            fields.push_back(dstMem);
| 1216 | } |
| 1217 | return true; |
| 1218 | }); |
| 1219 | |
| 1220 | rewriter.replaceOpWithMultiple(op, {fields}); |
| 1221 | return success(); |
| 1222 | } |
| 1223 | }; |
| 1224 | |
/// Sparse codegen rule for the extract.slice operator.
class SparseExtractSliceConverter
    : public OpConversionPattern<tensor::ExtractSliceOp> {
| 1227 | public: |
| 1228 | using OpConversionPattern::OpConversionPattern; |
| 1229 | LogicalResult |
| 1230 | matchAndRewrite(tensor::ExtractSliceOp op, OneToNOpAdaptor adaptor, |
| 1231 | ConversionPatternRewriter &rewriter) const override { |
| 1232 | Location loc = op.getLoc(); |
| 1233 | MLIRContext *ctx = op.getContext(); |
| 1234 | auto srcEnc = getSparseTensorEncoding(op.getSourceType()); |
| 1235 | auto dstEnc = getSparseTensorEncoding(op.getResult().getType()); |
| 1236 | // TODO: We should check these in ExtractSliceOp::verify. |
| 1237 | if (!srcEnc || !dstEnc || !dstEnc.isSlice()) |
| 1238 | return failure(); |
| 1239 | assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices()); |
| 1240 | |
| 1241 | SmallVector<Value> fields; |
| 1242 | auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields, |
| 1243 | op.getSource().getType()); |
| 1244 | |
| 1245 | auto newSpec = rewriter.create<StorageSpecifierInitOp>( |
| 1246 | loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier()); |
| 1247 | desc.setSpecifier(newSpec); |
| 1248 | |
| 1249 | // Fills in slice information. |
| 1250 | for (auto [idx, offset, size, stride] : llvm::enumerate( |
| 1251 | op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) { |
| 1252 | Dimension dim = idx; |
| 1253 | |
| 1254 | Value offsetV = getValueOrCreateConstantIndexOp(rewriter, loc, offset); |
| 1255 | Value sizeV = getValueOrCreateConstantIndexOp(rewriter, loc, size); |
| 1256 | Value strideV = getValueOrCreateConstantIndexOp(rewriter, loc, stride); |
      // TODO: We could probably set only the dynamic values here, but that
      // would require us to fill in the holes when casting a static slice to
      // a dynamic slice.
| 1260 | desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimOffset, |
| 1261 | dim, offsetV); |
| 1262 | |
      // FIXME: we need to distinguish level sizes from dimension sizes for
      // slices here. Maybe we should store slice level sizes in a different
      // array instead of reusing it.
| 1266 | assert(srcEnc.isIdentity()); |
| 1267 | desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::LvlSize, dim, |
| 1268 | sizeV); |
| 1269 | desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimStride, |
| 1270 | dim, strideV); |
| 1271 | } |
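    // For example (illustrative), for
    //   tensor.extract_slice %t[1, 2] [10, 20] [1, 1]
    // the loop above records DimOffset = (1, 2), LvlSize = (10, 20), and
    // DimStride = (1, 1) in the new specifier (as index constants, since all
    // slice parameters are static in this example).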
| 1272 | |
    // NOTE: we cannot generate the tuple directly from the descriptor here,
    // as the descriptor holds the original type while we want the slice type
    // (they share every memref but carry an updated specifier).
| 1276 | rewriter.replaceOpWithMultiple(op, {desc.getFields()}); |
| 1277 | return success(); |
| 1278 | } |
| 1279 | }; |
| 1280 | |
| 1281 | /// Sparse codegen rule for number of entries operator. |
| 1282 | class SparseNumberOfEntriesConverter |
| 1283 | : public OpConversionPattern<NumberOfEntriesOp> { |
| 1284 | public: |
| 1285 | using OpConversionPattern::OpConversionPattern; |
| 1286 | LogicalResult |
| 1287 | matchAndRewrite(NumberOfEntriesOp op, OneToNOpAdaptor adaptor, |
| 1288 | ConversionPatternRewriter &rewriter) const override { |
| 1289 | // Query memSizes for the actually stored values. |
| 1290 | // FIXME: the nse value computed in this way might be wrong when there is |
| 1291 | // any "loose_compressed" level. |
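    // Illustrative example: for a CSR tensor currently holding 7 stored
    // values, this lowers to reading the value-memref size (7) from the
    // storage specifier; no positions or coordinates are inspected.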
| 1292 | auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(), |
| 1293 | op.getTensor().getType()); |
| 1294 | rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc())); |
| 1295 | return success(); |
| 1296 | } |
| 1297 | }; |
| 1298 | |
| 1299 | struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> { |
| 1300 | using OpConversionPattern::OpConversionPattern; |
| 1301 | LogicalResult |
| 1302 | matchAndRewrite(AssembleOp op, OpAdaptor adaptor, |
| 1303 | ConversionPatternRewriter &rewriter) const override { |
| 1304 | Location loc = op.getLoc(); |
| 1305 | const auto stt = getSparseTensorType(op.getResult()); |
| 1306 | |
| 1307 | SmallVector<Value> fields; |
| 1308 | |
| 1309 | foreachFieldAndTypeInSparseTensor( |
| 1310 | stt, |
| 1311 | [&rewriter, &fields, &op, &stt, |
| 1312 | loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind, |
| 1313 | Level /*lvl*/, LevelType lt) -> bool { |
| 1314 | assert(fields.size() == fIdx); |
| 1315 | if (fKind == SparseTensorFieldKind::StorageSpec) { |
            fields.push_back(
                SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
| 1318 | } else { |
            // Otherwise, simply takes the inputs.
            Value tensor = fKind == SparseTensorFieldKind::ValMemRef
                               ? op.getValues()
                               : op.getLevels()[fIdx];
            // TODO: handle batch.
            TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
| 1325 | if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) { |
| 1326 | // Flattens the buffer to batchLvlRank. |
| 1327 | auto reassoc = getReassociationForFlattening( |
| 1328 | mem.getType(), stt.getBatchLvlRank()); |
| 1329 | mem = rewriter.create<memref::CastOp>( |
| 1330 | loc, fType, |
| 1331 | rewriter.create<memref::CollapseShapeOp>(loc, mem, reassoc)); |
| 1332 | } else { |
| 1333 | mem = rewriter.create<memref::CastOp>(loc, fType, mem); |
| 1334 | } |
            fields.push_back(mem);
| 1336 | } |
| 1337 | return true; |
| 1338 | }); |
| 1339 | |
| 1340 | MutSparseTensorDescriptor desc(stt, fields); |
    Value c0 = constantIndex(rewriter, loc, 0);
    Value c1 = constantIndex(rewriter, loc, 1);
    Value c2 = constantIndex(rewriter, loc, 2);
| 1344 | Value posBack = c0; // index to the last value in the position array |
| 1345 | Value memSize = c1; // memory size for current array |
| 1346 | |
| 1347 | Level trailCOOStart = stt.getAoSCOOStart(); |
| 1348 | Level trailCOORank = stt.getLvlRank() - trailCOOStart; |
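    // Worked example (illustrative): assembling a 4x8 CSR tensor with 5
    // stored entries (levels: dense, compressed), the loop below computes
    //   lvl 0 (dense):      memSize = 4 * 1 = 4, posBack = 3
    //   lvl 1 (compressed): posBack = 4, pos memSize = 5,
    //                       memSize = positions[4] (= 5 = nse),
    //                       crd memSize = 5
    // and afterwards val memSize = 5.
    // (With a trailing AoS COO block, a single coordinates memSize of
    // memSize * trailCOORank is recorded at trailCOOStart instead.)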
| 1349 | // Sets up SparseTensorSpecifier. |
| 1350 | for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) { |
| 1351 | assert(!ShapedType::isDynamic(stt.getDimShape()[lvl])); |
| 1352 | |
| 1353 | // Sets up the level size. |
| 1354 | auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]); |
      desc.setLvlSize(rewriter, loc, lvl, lvlSize);
| 1356 | // We use a single AOS array to store the trailing COO, so there is only |
| 1357 | // one memory size to set for the entire COO section. |
| 1358 | if (lvl > trailCOOStart) |
| 1359 | continue; |
| 1360 | |
| 1361 | // Sets up the memory size by reading the last value in position array. |
| 1362 | LevelType lt = stt.getLvlType(lvl); |
| 1363 | // Simply forwards the position index when this is a dense level. |
| 1364 | if (lt.isa<LevelFormat::Dense>()) { |
| 1365 | memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize); |
| 1366 | posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1); |
| 1367 | continue; |
| 1368 | } |
| 1369 | if (lt.isa<LevelFormat::Batch>()) { |
        // Skips batch levels as they are not linearized.
        // FIXME: this assumes that every batch has the same number of nse;
        // needs to be generalized to handle varied-size batches.
| 1373 | continue; |
| 1374 | } |
| 1375 | |
| 1376 | if (isWithPosLT(lt)) { |
| 1377 | assert(isCompressedLT(lt) || isLooseCompressedLT(lt)); |
| 1378 | if (isLooseCompressedLT(lt)) { |
| 1379 | memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2); |
| 1380 | posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1); |
| 1381 | } else { |
| 1382 | assert(isCompressedLT(lt)); |
| 1383 | posBack = memSize; |
| 1384 | memSize = rewriter.create<arith::AddIOp>(loc, memSize, c1); |
| 1385 | } |
        desc.setPosMemSize(rewriter, loc, lvl, memSize);
        // The last value in the position array is the memory size for the
        // next level.
        // FIXME: this assumes that every batch has the same number of nse;
        // needs to be generalized to handle varied-size batches.
        SmallVector<Value> batched(stt.getBatchLvlRank(),
                                   constantIndex(rewriter, loc, 0));
        batched.push_back(posBack);
| 1393 | memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), batched); |
| 1394 | posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1); |
| 1395 | } |
| 1396 | assert(isWithCrdLT(lt) && lvl <= trailCOOStart); |
| 1397 | // FIXME: This seems to be unnecessarily complex, can we simplify it? |
| 1398 | if (lvl == trailCOOStart) { |
| 1399 | Value cooSz = rewriter.create<arith::MulIOp>( |
| 1400 | loc, memSize, constantIndex(rewriter, loc, trailCOORank)); |
        desc.setCrdMemSize(rewriter, loc, lvl, cooSz);
| 1402 | } else { |
        desc.setCrdMemSize(rewriter, loc, lvl, memSize);
| 1404 | } |
| 1405 | } |
    desc.setValMemSize(rewriter, loc, memSize);
| 1407 | |
| 1408 | rewriter.replaceOpWithMultiple(op, {desc.getFields()}); |
| 1409 | return success(); |
| 1410 | } |
| 1411 | }; |
| 1412 | |
| 1413 | struct SparseDisassembleOpConverter |
| 1414 | : public OpConversionPattern<DisassembleOp> { |
| 1415 | using OpConversionPattern::OpConversionPattern; |
| 1416 | SparseDisassembleOpConverter(const TypeConverter &typeConverter, |
| 1417 | MLIRContext *context) |
| 1418 | : OpConversionPattern(typeConverter, context) {} |
| 1419 | |
| 1420 | LogicalResult |
| 1421 | matchAndRewrite(DisassembleOp op, OneToNOpAdaptor adaptor, |
| 1422 | ConversionPatternRewriter &rewriter) const override { |
| 1423 | auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(), |
| 1424 | op.getTensor().getType()); |
| 1425 | Location loc = op.getLoc(); |
| 1426 | SmallVector<Value> retMem; |
| 1427 | SmallVector<Value> retLen; |
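    // Illustrative example: disassembling a CSR tensor copies, per field, the
    // used prefix of the positions, coordinates, and values buffers into the
    // caller-provided out-tensors, and additionally returns one scalar tensor
    // per buffer holding the used length (capacity beyond that is not copied).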
| 1428 | desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem, |
| 1429 | &retLen](FieldIndex fid, |
| 1430 | SparseTensorFieldKind fKind, |
| 1431 | Level lvl, LevelType lt) -> bool { |
| 1432 | if (fKind == SparseTensorFieldKind::StorageSpec) |
| 1433 | return true; |
| 1434 | SparseTensorType stt(desc.getRankedTensorType()); |
| 1435 | Value sz, src; |
| 1436 | TypedValue<BaseMemRefType> dst; |
| 1437 | if (fKind == SparseTensorFieldKind::ValMemRef) { |
| 1438 | sz = desc.getValMemSize(rewriter, loc); |
| 1439 | src = desc.getValMemRef(); |
| 1440 | dst = genToMemref(rewriter, loc, op.getOutValues()); |
| 1441 | |
        retMem.push_back(dst);
        Type valLenTp = op.getValLen().getType();
        retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
| 1445 | } else { |
| 1446 | assert(fKind == SparseTensorFieldKind::PosMemRef || |
| 1447 | fKind == SparseTensorFieldKind::CrdMemRef); |
| 1448 | |
| 1449 | sz = fKind == SparseTensorFieldKind::PosMemRef |
| 1450 | ? desc.getPosMemSize(rewriter, loc, lvl) |
| 1451 | : desc.getCrdMemSize(rewriter, loc, lvl); |
| 1452 | src = desc.getMemRefField(fid); |
| 1453 | dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]); |
        retMem.push_back(dst);
        // Retrieves the corresponding level length type.
        Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
| 1458 | } |
| 1459 | Value flatOut = dst; |
| 1460 | if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) { |
| 1461 | auto reassoc = |
| 1462 | getReassociationForFlattening(dst.getType(), stt.getBatchLvlRank()); |
| 1463 | flatOut = rewriter.create<memref::CollapseShapeOp>(loc, dst, reassoc); |
| 1464 | } |
      Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
      Value srcMem = genSliceToSize(rewriter, loc, src, sz);
| 1467 | rewriter.create<memref::CopyOp>(loc, srcMem, dstMem); |
| 1468 | return true; |
| 1469 | }); |
| 1470 | |
| 1471 | // Converts MemRefs back to Tensors. |
    SmallVector<Value> retValues = llvm::to_vector(
        llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
          return rewriter.create<bufferization::ToTensorOp>(loc, v);
        }));
    // Appends the actual memory length used in each buffer returned.
    retValues.append(retLen.begin(), retLen.end());
| 1478 | rewriter.replaceOp(op, retValues); |
| 1479 | return success(); |
| 1480 | } |
| 1481 | }; |
| 1482 | |
| 1483 | struct SparseNewConverter : public OpConversionPattern<NewOp> { |
| 1484 | using OpConversionPattern::OpConversionPattern; |
| 1485 | LogicalResult |
| 1486 | matchAndRewrite(NewOp op, OpAdaptor adaptor, |
| 1487 | ConversionPatternRewriter &rewriter) const override { |
| 1488 | Location loc = op.getLoc(); |
| 1489 | const auto dstTp = getSparseTensorType(op.getResult()); |
| 1490 | // Creating COO with NewOp is handled by direct IR codegen. All other cases |
| 1491 | // are handled by rewriting. |
| 1492 | if (!dstTp.hasEncoding() || dstTp.getAoSCOOStart() != 0) |
| 1493 | return failure(); |
| 1494 | |
| 1495 | // Implement as follows: |
| 1496 | // %reader = @createCheckedSparseTensorReader(%filename) |
| 1497 | // %nse = @getSparseTensorNSE(%reader) |
| 1498 | // %coo = bufferization.alloc_tensor an ordered COO with |
| 1499 | // dst dim ordering, size_hint = %nse |
| 1500 | // %coordinates = sparse_tensor.coordinates_buffer(%coo) |
| 1501 | // %values = sparse_tensor.values(%coo) |
| 1502 | // %isSorted = @sparseTensorReaderReadToBuffers(%coordinates, %values) |
| 1503 | // if (! %isSorted) sparse_tensor.sort_coo(%nse, %coordinates, %values) |
| 1504 | // update storage specifier |
| 1505 | // @delSparseTensorReader(%reader) |
| 1506 | SmallVector<Value> dimSizesValues; |
| 1507 | Value dimSizesBuffer; |
| 1508 | Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0], |
| 1509 | dimSizesValues, dimSizesBuffer); |
| 1510 | |
| 1511 | // Get the number of stored entries. |
| 1512 | const Type indexTp = rewriter.getIndexType(); |
    Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
| 1514 | {indexTp}, {reader}, EmitCInterface::Off) |
| 1515 | .getResult(0); |
| 1516 | |
| 1517 | // Construct the lvl sizes and the dim2lvl/lvl2dim buffers. |
| 1518 | SmallVector<Value> lvlSizesValues; |
| 1519 | Value dim2lvlBuffer; |
| 1520 | Value lvl2dimBuffer; |
| 1521 | genMapBuffers(rewriter, loc, dstTp, dimSizesValues, dimSizesBuffer, |
| 1522 | lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer); |
| 1523 | |
| 1524 | // Construct allocation for each field. |
| 1525 | Value sizeHint = nse; |
| 1526 | SmallVector<Value> fields; |
| 1527 | createAllocFields(rewriter, loc, dstTp, /*enableInit=*/false, sizeHint, |
| 1528 | lvlSizesValues, fields); |
| 1529 | |
| 1530 | // Read the COO tensor data. |
| 1531 | MutSparseTensorDescriptor desc(dstTp, fields); |
| 1532 | Value xs = desc.getAOSMemRef(); |
| 1533 | Value ys = desc.getValMemRef(); |
| 1534 | const Type boolTp = rewriter.getIntegerType(1); |
| 1535 | const Type elemTp = dstTp.getElementType(); |
| 1536 | const Type crdTp = dstTp.getCrdType(); |
    SmallString<32> readToBuffersFuncName{"getSparseTensorReaderReadToBuffers",
                                          overheadTypeFunctionSuffix(crdTp),
| 1539 | primaryTypeFunctionSuffix(elemTp)}; |
| 1540 | Value isSorted = |
| 1541 | createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp}, |
| 1542 | {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys}, |
| 1543 | EmitCInterface::On) |
| 1544 | .getResult(0); |
| 1545 | |
| 1546 | // If the destination tensor is a sorted COO, we need to sort the COO tensor |
| 1547 | // data if the input elements aren't sorted yet. |
| 1548 | const Level lvlRank = dstTp.getLvlRank(); |
| 1549 | if (dstTp.isOrderedLvl(lvlRank - 1)) { |
      Value kFalse = constantI1(rewriter, loc, false);
| 1551 | Value notSorted = rewriter.create<arith::CmpIOp>( |
| 1552 | loc, arith::CmpIPredicate::eq, isSorted, kFalse); |
| 1553 | scf::IfOp ifOp = |
| 1554 | rewriter.create<scf::IfOp>(loc, notSorted, /*else*/ false); |
| 1555 | rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
      auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
| 1557 | rewriter.create<SortOp>(loc, nse, xs, ValueRange{ys}, xPerm, |
| 1558 | rewriter.getIndexAttr(0), |
| 1559 | SparseTensorSortKind::HybridQuickSort); |
| 1560 | rewriter.setInsertionPointAfter(ifOp); |
| 1561 | } |
| 1562 | |
| 1563 | // Set PosMemRef0[1] = nse. |
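    // (Illustrative: for a sorted COO tensor with 5 stored entries, the
    // level-0 positions buffer ends up as [0, 5].)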
    const Value c1 = constantIndex(rewriter, loc, 1);
    const Value posMemref0 = desc.getPosMemRef(0);
    const Type posTp = dstTp.getPosType();
    const Value posNse = genCast(rewriter, loc, nse, posTp);
| 1568 | rewriter.create<memref::StoreOp>(loc, posNse, posMemref0, c1); |
| 1569 | |
| 1570 | // Update storage specifier. |
| 1571 | Value coordinatesSize = rewriter.create<arith::MulIOp>( |
| 1572 | loc, nse, constantIndex(rewriter, loc, lvlRank)); |
| 1573 | desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::CrdMemSize, 0, |
| 1574 | coordinatesSize); |
| 1575 | desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::ValMemSize, |
| 1576 | std::nullopt, nse); |
| 1577 | |
| 1578 | // Release the sparse tensor reader. |
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);
| 1581 | |
| 1582 | // Replace operation with resulting memrefs. |
| 1583 | rewriter.replaceOpWithMultiple(op, {fields}); |
| 1584 | return success(); |
| 1585 | } |
| 1586 | }; |
| 1587 | |
| 1588 | struct SparseHasRuntimeLibraryConverter |
| 1589 | : public OpConversionPattern<HasRuntimeLibraryOp> { |
| 1590 | using OpConversionPattern::OpConversionPattern; |
| 1591 | LogicalResult |
| 1592 | matchAndRewrite(HasRuntimeLibraryOp op, OpAdaptor adaptor, |
| 1593 | ConversionPatternRewriter &rewriter) const override { |
| 1594 | auto i1Type = rewriter.getI1Type(); |
| 1595 | rewriter.replaceOpWithNewOp<arith::ConstantOp>( |
| 1596 | op, i1Type, rewriter.getIntegerAttr(i1Type, 0)); |
| 1597 | return success(); |
| 1598 | } |
| 1599 | }; |
| 1600 | |
| 1601 | } // namespace |
| 1602 | |
| 1603 | //===----------------------------------------------------------------------===// |
| 1604 | // Public method for populating conversion rules. |
| 1605 | //===----------------------------------------------------------------------===// |
| 1606 | |
/// Populates the given patterns list with conversion rules required for
/// the direct IR codegen of sparse tensor types and primitives.
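///
/// A minimal usage sketch (the type-converter name and surrounding pass
/// setup below are assumptions, not code from this file):
///
///   SparseTensorTypeToBufferConverter converter; // hypothetical name
///   ConversionTarget target(*ctx);
///   RewritePatternSet patterns(ctx);
///   populateSparseTensorCodegenPatterns(converter, patterns,
///                                       /*createSparseDeallocs=*/true,
///                                       /*enableBufferInitialization=*/false);
///   if (failed(applyPartialConversion(module, target, std::move(patterns))))
///     signalPassFailure();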
| 1609 | void mlir::populateSparseTensorCodegenPatterns( |
| 1610 | const TypeConverter &typeConverter, RewritePatternSet &patterns, |
| 1611 | bool createSparseDeallocs, bool enableBufferInitialization) { |
| 1612 | patterns.add< |
| 1613 | SparseAssembleOpConverter, SparseDisassembleOpConverter, |
| 1614 | SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter, |
| 1615 | SparseCastConverter, SparseExtractSliceConverter, |
| 1616 | SparseTensorLoadConverter, SparseExpandConverter, SparseCompressConverter, |
| 1617 | SparseInsertConverter, SparseReorderCOOConverter, SparseReMapConverter, |
| 1618 | SparseSliceGetterOpConverter<ToSliceOffsetOp, |
| 1619 | StorageSpecifierKind::DimOffset>, |
| 1620 | SparseSliceGetterOpConverter<ToSliceStrideOp, |
| 1621 | StorageSpecifierKind::DimStride>, |
| 1622 | SparseToPositionsConverter, SparseToCoordinatesConverter, |
| 1623 | SparseToCoordinatesBufferConverter, SparseToValuesConverter, |
| 1624 | SparseConvertConverter, SparseNewConverter, |
| 1625 | SparseNumberOfEntriesConverter, SparseHasRuntimeLibraryConverter>( |
| 1626 | typeConverter, patterns.getContext()); |
  patterns.add<SparseTensorDeallocConverter>(
      typeConverter, patterns.getContext(), createSparseDeallocs);
  patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
      typeConverter, patterns.getContext(), enableBufferInitialization);
| 1631 | } |
| 1632 | |