//===- SparseAssembler.cpp - adds wrapper method around sparse types ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace sparse_tensor;

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

// Convert type range to new types range, with sparse tensors externalized.
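//
// For example (an illustrative sketch, assuming a standard CSR encoding with
// default index bit widths), a sparse tensor type such as
//
//   tensor<10x20xf64, #CSR>
//
// is externalized (in the non-direct-out case) into one plain tensor per
// storage array:
//
//   tensor<?xindex> (positions), tensor<?xindex> (coordinates),
//   tensor<?xf64>   (values)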
static void convTypes(bool &hasAnnotation, TypeRange types,
                      SmallVectorImpl<Type> &convTypes,
                      SmallVectorImpl<Type> *extraTypes, bool directOut) {
  for (auto type : types) {
    // All "dense" data passes through unmodified.
    if (!getSparseTensorEncoding(type)) {
      convTypes.push_back(type);
      continue;
    }
    hasAnnotation = true;

    // Convert the external representations of the pos/crd/val arrays.
    const SparseTensorType stt(cast<RankedTensorType>(type));
    foreachFieldAndTypeInSparseTensor(
        stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
                                                 SparseTensorFieldKind kind,
                                                 Level, LevelType) {
          if (kind == SparseTensorFieldKind::PosMemRef ||
              kind == SparseTensorFieldKind::CrdMemRef ||
              kind == SparseTensorFieldKind::ValMemRef) {
            auto rtp = cast<ShapedType>(t);
            if (!directOut) {
              rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
              if (extraTypes)
                extraTypes->push_back(rtp);
            }
            convTypes.push_back(rtp);
          }
          return true;
        });
  }
}

// Convert input and output values to [dis]assemble ops for sparse tensors.
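//
// In the direct-out case (an illustrative sketch of the ops built below,
// assuming a CSR tensor %t), the underlying storage buffers are returned
// directly via queries such as
//
//   %pos = sparse_tensor.positions %t { level = 1 : index }
//      : tensor<10x20xf64, #CSR> to memref<?xindex>
//
// rather than routed through a sparse_tensor.disassemble operation.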
static void convVals(OpBuilder &builder, Location loc, TypeRange types,
                     ValueRange fromVals, ValueRange extraVals,
                     SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
                     bool directOut) {
  unsigned idx = 0;
  for (auto type : types) {
    // All "dense" data passes through unmodified.
    if (!getSparseTensorEncoding(type)) {
      toVals.push_back(fromVals[idx++]);
      continue;
    }
    // Handle sparse data.
    auto rtp = cast<RankedTensorType>(type);
    const SparseTensorType stt(rtp);
    SmallVector<Value> inputs;
    SmallVector<Type> retTypes;
    SmallVector<Type> cntTypes;
    if (!isIn)
      inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble.

    // Collect the external representations of the pos/crd/val arrays.
    foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
                                                     SparseTensorFieldKind kind,
                                                     Level lv, LevelType) {
      if (kind == SparseTensorFieldKind::PosMemRef ||
          kind == SparseTensorFieldKind::CrdMemRef ||
          kind == SparseTensorFieldKind::ValMemRef) {
        if (isIn) {
          inputs.push_back(fromVals[idx++]);
        } else if (directOut) {
          Value mem;
          if (kind == SparseTensorFieldKind::PosMemRef)
            mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0],
                                                               lv);
          else if (kind == SparseTensorFieldKind::CrdMemRef)
            mem = builder.create<sparse_tensor::ToCoordinatesOp>(loc, inputs[0],
                                                                 lv);
          else
            mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
          toVals.push_back(mem);
        } else {
          ShapedType rtp = cast<ShapedType>(t);
          rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
          inputs.push_back(extraVals[extra++]);
          retTypes.push_back(rtp);
          cntTypes.push_back(builder.getIndexType());
        }
      }
      return true;
    });

    if (isIn) {
      // Assemble multiple inputs into a single sparse tensor.
      auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
      toVals.push_back(a.getResult());
    } else if (!directOut) {
      // Disassemble a single sparse input into multiple outputs.
      // Note that this includes the counters, which are dropped.
      unsigned len = retTypes.size();
      retTypes.append(cntTypes);
      auto d =
          builder.create<sparse_tensor::DisassembleOp>(loc, retTypes, inputs);
      for (unsigned i = 0; i < len; i++)
        toVals.push_back(d.getResult(i));
    }
  }
}

//===----------------------------------------------------------------------===//
// Rewriting rules.
//===----------------------------------------------------------------------===//

namespace {

// A rewriting rule that converts public entry methods that use sparse tensors
// as input parameters and/or output return values into wrapper methods that
// [dis]assemble the individual tensors that constitute the actual storage used
// externally into MLIR sparse tensors before calling the original method.
//
// In particular, each sparse tensor input
//
// void foo(..., t, ...) { }
//
// makes the original foo() internal and adds the following wrapper method
//
// void foo(..., t1..tn, ...) {
//   t = assemble t1..tn
//   _internal_foo(..., t, ...)
// }
//
// and likewise, each output tensor
//
// ... T ... bar(...) { return ..., t, ...; }
//
// makes the original bar() internal and adds the following wrapper method
//
// ... T1..TN ... bar(..., t1'..tn') {
//   ..., t, ... = _internal_bar(...)
//   t1..tn = disassemble t, t1'..tn'
//   return ..., t1..tn, ...
// }
//
// (with a direct-out variant without the disassemble).
//
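// As a concrete illustration (a hedged sketch, assuming a standard CSR
// encoding with default index bit widths), a public method
//
//   func.func @foo(%t: tensor<10x20xf64, #CSR>) { ... }
//
// becomes a private @_internal_foo together with a public wrapper of roughly
// the form
//
//   func.func @foo(%pos: tensor<?xindex>, %crd: tensor<?xindex>,
//                  %val: tensor<?xf64>) {
//     %t = sparse_tensor.assemble (%pos, %crd), %val
//        : (tensor<?xindex>, tensor<?xindex>), tensor<?xf64>
//          to tensor<10x20xf64, #CSR>
//     call @_internal_foo(%t) : (tensor<10x20xf64, #CSR>) -> ()
//     return
//   }
//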
struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
  using OpRewritePattern::OpRewritePattern;

  SparseFuncAssembler(MLIRContext *context, bool dO)
      : OpRewritePattern(context), directOut(dO) {}

  LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                PatternRewriter &rewriter) const override {
    // Only rewrite public entry methods.
    if (funcOp.isPrivate())
      return failure();

    // Translate sparse tensor types to external types.
    SmallVector<Type> inputTypes;
    SmallVector<Type> outputTypes;
    SmallVector<Type> extraTypes;
    bool hasAnnotation = false;
    convTypes(hasAnnotation, funcOp.getArgumentTypes(), inputTypes, nullptr,
              false);
    convTypes(hasAnnotation, funcOp.getResultTypes(), outputTypes, &extraTypes,
              directOut);

    // Only sparse inputs or outputs need a wrapper method.
    if (!hasAnnotation)
      return failure();

    // Modify the original method into an internal, private method.
    auto orgName = funcOp.getName();
    std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
    funcOp.setName(wrapper);
    funcOp.setPrivate();

    // Start the new public wrapper method with the original name.
    Location loc = funcOp.getLoc();
    ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
    MLIRContext *context = modOp.getContext();
    OpBuilder moduleBuilder(modOp.getBodyRegion());
    unsigned extra = inputTypes.size();
    inputTypes.append(extraTypes);
    auto func = moduleBuilder.create<func::FuncOp>(
        loc, orgName, FunctionType::get(context, inputTypes, outputTypes));
    func.setPublic();

    // Construct new wrapper method body.
    OpBuilder::InsertionGuard insertionGuard(rewriter);
    Block *body = func.addEntryBlock();
    rewriter.setInsertionPointToStart(body);

    // Convert inputs.
    SmallVector<Value> inputs;
    convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
             ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut);

    // Call the original, now private method. A subsequent inlining pass can
    // determine whether cloning the method body in place is worthwhile.
    auto org = SymbolRefAttr::get(context, wrapper);
    auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
                                              inputs);

    // Convert outputs and return.
    SmallVector<Value> outputs;
    convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
             body->getArguments(), outputs, extra, /*isIn=*/false, directOut);
    rewriter.create<func::ReturnOp>(loc, outputs);

    // Finally, migrate a potential c-interface property.
    if (funcOp->getAttrOfType<UnitAttr>(
            LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(context));
      funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
    }
    return success();
  }

private:
  const bool directOut;
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

void mlir::populateSparseAssembler(RewritePatternSet &patterns,
                                   bool directOut) {
  patterns.add<SparseFuncAssembler>(patterns.getContext(), directOut);
}
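
// Example usage (an illustrative sketch, not part of this file): a pass
// driver could apply the pattern greedily over a module, e.g.
//
//   RewritePatternSet patterns(ctx);
//   populateSparseAssembler(patterns, /*directOut=*/false);
//   (void)applyPatternsAndFoldGreedily(module, std::move(patterns));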