1 | //===- CodegenUtils.h - Utilities for generating MLIR -----------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This header file defines utilities for generating MLIR. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENUTILS_H_ |
14 | #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENUTILS_H_ |
15 | |
16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
17 | #include "mlir/Dialect/Complex/IR/Complex.h" |
18 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
19 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
20 | #include "mlir/Dialect/SparseTensor/IR/Enums.h" |
21 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
22 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" |
23 | #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" |
24 | #include "mlir/IR/Builders.h" |
25 | |
26 | namespace mlir { |
27 | |
28 | class Location; |
29 | class Type; |
30 | class Value; |
31 | |
32 | namespace sparse_tensor { |
33 | |
/// Strongly-typed on/off switch used for the `emitCInterface` argument of
/// `getFunc()`, `createFuncCall()`, and `replaceOpWithFuncCall()`. The
/// underlying type is `bool`, with `Off` mapping to `false` and `On` to `true`.
enum class EmitCInterface : bool {
  Off = false,
  On = true,
};
37 | |
38 | //===----------------------------------------------------------------------===// |
39 | // ExecutionEngine/SparseTensorUtils helper functions. |
40 | //===----------------------------------------------------------------------===// |
41 | |
42 | /// Converts an overhead storage bitwidth to its internal type-encoding. |
43 | OverheadType overheadTypeEncoding(unsigned width); |
44 | |
45 | /// Converts an overhead storage type to its internal type-encoding. |
46 | OverheadType overheadTypeEncoding(Type tp); |
47 | |
48 | /// Converts the internal type-encoding for overhead storage to an mlir::Type. |
49 | Type getOverheadType(Builder &builder, OverheadType ot); |
50 | |
51 | /// Returns the OverheadType for position overhead storage. |
52 | OverheadType posTypeEncoding(SparseTensorEncodingAttr enc); |
53 | |
54 | /// Returns the OverheadType for coordinate overhead storage. |
55 | OverheadType crdTypeEncoding(SparseTensorEncodingAttr enc); |
56 | |
/// Converts an OverheadType to its function-name suffix.
58 | StringRef overheadTypeFunctionSuffix(OverheadType ot); |
59 | |
60 | /// Converts an overhead storage type to its function-name suffix. |
61 | StringRef overheadTypeFunctionSuffix(Type overheadTp); |
62 | |
63 | /// Converts a primary storage type to its internal type-encoding. |
64 | PrimaryType primaryTypeEncoding(Type elemTp); |
65 | |
/// Converts a PrimaryType to its function-name suffix.
67 | StringRef primaryTypeFunctionSuffix(PrimaryType pt); |
68 | |
69 | /// Converts a primary storage type to its function-name suffix. |
70 | StringRef primaryTypeFunctionSuffix(Type elemTp); |
71 | |
72 | //===----------------------------------------------------------------------===// |
73 | // Misc code generators and utilities. |
74 | //===----------------------------------------------------------------------===// |
75 | |
76 | /// A helper class to simplify lowering operations with/without function calls. |
77 | template <class SubClass> |
78 | class FuncCallOrInlineGenerator { |
79 | public: |
80 | FuncCallOrInlineGenerator(TypeRange retTypes, ValueRange params, bool genCall) |
81 | : retTypes(retTypes), params(params), genCall(genCall) {} |
82 | |
83 | // The main API invoked by clients, which abstracts away the details of |
84 | // creating function calls from clients. |
85 | SmallVector<Value> genCallOrInline(OpBuilder &builder, Location loc) { |
86 | if (!genCall) |
87 | return genImplementation(retTypes, params, builder, loc); |
88 | |
89 | // Looks up the function. |
90 | std::string funcName = getMangledFuncName(); |
91 | ModuleOp module = getParentOpOf<ModuleOp>(builder); |
92 | MLIRContext *context = module.getContext(); |
93 | auto result = SymbolRefAttr::get(context, funcName); |
94 | auto func = module.lookupSymbol<func::FuncOp>(result.getAttr()); |
95 | |
96 | if (!func) { |
97 | // Create the function if not already exist. |
98 | OpBuilder::InsertionGuard insertionGuard(builder); |
99 | builder.setInsertionPoint(getParentOpOf<func::FuncOp>(builder)); |
100 | func = builder.create<func::FuncOp>( |
101 | loc, funcName, |
102 | FunctionType::get(context, params.getTypes(), retTypes)); |
103 | func.setPrivate(); |
104 | // Set the insertion point to the body of the function. |
105 | Block *entryBB = func.addEntryBlock(); |
106 | builder.setInsertionPointToStart(entryBB); |
107 | ValueRange args = entryBB->getArguments(); |
108 | // Delegates to user to generate the actually implementation. |
109 | SmallVector<Value> result = |
110 | genImplementation(retTypes, params: args, builder, loc); |
111 | builder.create<func::ReturnOp>(loc, result); |
112 | } |
113 | // Returns the CallOp result. |
114 | func::CallOp call = builder.create<func::CallOp>(loc, func, params); |
115 | return call.getResults(); |
116 | } |
117 | |
118 | private: |
119 | template <class OpTp> |
120 | OpTp getParentOpOf(OpBuilder &builder) { |
121 | return builder.getInsertionBlock()->getParent()->getParentOfType<OpTp>(); |
122 | } |
123 | |
124 | // CRTP: get the mangled function name (only called when genCall=true). |
125 | std::string getMangledFuncName() { |
126 | return static_cast<SubClass *>(this)->getMangledFuncName(); |
127 | } |
128 | |
129 | // CRTP: Client implementation. |
130 | SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange params, |
131 | OpBuilder &builder, Location loc) { |
132 | return static_cast<SubClass *>(this)->genImplementation(retTypes, params, |
133 | builder, loc); |
134 | } |
135 | |
136 | private: |
137 | TypeRange retTypes; // The types of all returned results |
138 | ValueRange params; // The values of all input parameters |
139 | bool genCall; // Should the implemetantion be wrapped in a function |
140 | }; |
141 | |
142 | /// Add type casting between arith and index types when needed. |
143 | Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy); |
144 | |
145 | /// Add conversion from scalar to given type (possibly a 0-rank tensor). |
146 | Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem, |
147 | Type dstTp); |
148 | |
149 | /// Generates a pointer/index load from the sparse storage scheme. Narrower |
150 | /// data types need to be zero extended before casting the value into the |
151 | /// index type used for looping and indexing. |
152 | Value genIndexLoad(OpBuilder &builder, Location loc, Value mem, ValueRange s); |
153 | |
154 | /// Generates a 1-valued attribute of the given type. This supports |
155 | /// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`, |
156 | /// for unsupported types we raise `llvm_unreachable` rather than |
157 | /// returning a null attribute. |
158 | TypedAttr getOneAttr(Builder &builder, Type tp); |
159 | |
160 | /// Generates the comparison `v != 0` where `v` is of numeric type. |
161 | /// For floating types, we use the "unordered" comparator (i.e., returns |
162 | /// true if `v` is NaN). |
163 | Value genIsNonzero(OpBuilder &builder, Location loc, Value v); |
164 | |
/// Computes the shape of the destination tensor of a reshape operator. This is
/// only used when operands have dynamic shape. The shape of the destination is
/// stored into `dstShape`.
168 | void genReshapeDstShape(OpBuilder &builder, Location loc, |
169 | SmallVectorImpl<Value> &dstShape, |
170 | ArrayRef<Value> srcShape, ArrayRef<Size> staticDstShape, |
171 | ArrayRef<ReassociationIndices> reassociation); |
172 | |
173 | /// Reshape coordinates during a reshaping operation. |
174 | void reshapeCvs(OpBuilder &builder, Location loc, |
175 | ArrayRef<ReassociationIndices> reassociation, |
176 | ValueRange srcSizes, ValueRange srcCvs, // NOLINT |
177 | ValueRange dstSizes, SmallVectorImpl<Value> &dstCvs); |
178 | |
179 | /// Returns a function reference (first hit also inserts into module). Sets |
180 | /// the "_emit_c_interface" on the function declaration when requested, |
181 | /// so that LLVM lowering generates a wrapper function that takes care |
182 | /// of ABI complications with passing in and returning MemRefs to C functions. |
183 | FlatSymbolRefAttr getFunc(ModuleOp module, StringRef name, TypeRange resultType, |
184 | ValueRange operands, EmitCInterface emitCInterface); |
185 | |
186 | /// Creates a `CallOp` to the function reference returned by `getFunc()` in |
187 | /// the builder's module. |
188 | func::CallOp createFuncCall(OpBuilder &builder, Location loc, StringRef name, |
189 | TypeRange resultType, ValueRange operands, |
190 | EmitCInterface emitCInterface); |
191 | |
192 | /// Returns the equivalent of `void*` for opaque arguments to the |
193 | /// execution engine. |
194 | Type getOpaquePointerType(MLIRContext *ctx); |
195 | Type getOpaquePointerType(Builder &builder); |
196 | |
197 | /// Generates an uninitialized temporary buffer of the given size and |
198 | /// type, but returns it as type `memref<? x $tp>` (rather than as type |
199 | /// `memref<$sz x $tp>`). |
200 | Value genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp); |
201 | |
202 | /// Generates an uninitialized temporary buffer of the given size and |
203 | /// type, and returns it as type `memref<? x $tp>` (staticShape=false) or |
204 | /// `memref<$sz x $tp>` (staticShape=true). |
205 | Value genAlloca(OpBuilder &builder, Location loc, unsigned sz, Type tp, |
206 | bool staticShape = false); |
207 | |
208 | /// Generates an uninitialized temporary buffer with room for one value |
209 | /// of the given type, and returns the `memref<$tp>`. |
210 | Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp); |
211 | |
212 | /// Generates a temporary buffer, initializes it with the given contents, |
213 | /// and returns it as type `memref<? x $tp>` (rather than specifying the |
214 | /// size of the buffer). |
215 | Value allocaBuffer(OpBuilder &builder, Location loc, ValueRange values); |
216 | |
217 | /// Generates code to allocate a buffer of the given type, and zero |
218 | /// initialize it. If the buffer type has any dynamic sizes, then the |
219 | /// `sizes` parameter should be as filled by sizesFromPtr(); that way |
220 | /// we can reuse the genDimSizeCall() results generated by sizesFromPtr(). |
221 | Value allocDenseTensor(OpBuilder &builder, Location loc, |
222 | RankedTensorType tensorTp, ValueRange sizes); |
223 | |
224 | /// Generates code to deallocate a dense buffer. |
225 | void deallocDenseTensor(OpBuilder &builder, Location loc, Value buffer); |
226 | |
227 | /// Populates given sizes array from dense tensor or sparse tensor constant. |
228 | void sizesFromSrc(OpBuilder &builder, SmallVectorImpl<Value> &sizes, |
229 | Location loc, Value src); |
230 | |
231 | /// Scans to top of generated loop. |
232 | Operation *getTop(Operation *op); |
233 | |
/// Iterates over a sparse constant, generating a constant op for each
/// value and its coordinates and passing them to `callback`. E.g.,
236 | /// sparse<[ [0], [28], [31] ], |
237 | /// [ (-5.13, 2.0), (3.0, 4.0), (5.0, 6.0) ] > |
238 | /// => |
239 | /// %c1 = arith.constant 0 |
/// %v1 = complex.constant (-5.13, 2.0)
241 | /// callback({%c1}, %v1) |
242 | /// |
243 | /// %c2 = arith.constant 28 |
244 | /// %v2 = complex.constant (3.0, 4.0) |
245 | /// callback({%c2}, %v2) |
246 | /// |
247 | /// %c3 = arith.constant 31 |
248 | /// %v3 = complex.constant (5.0, 6.0) |
249 | /// callback({%c3}, %v3) |
250 | void foreachInSparseConstant( |
251 | OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order, |
252 | function_ref<void(ArrayRef<Value>, Value)> callback); |
253 | |
254 | /// Loads `size`-many values from the memref, which must have rank-1 and |
255 | /// size greater-or-equal to `size`. If the optional `(offsetIdx,offsetVal)` |
256 | /// arguments are provided, then the `offsetVal` will be added to the |
257 | /// `offsetIdx`-th value after loading. |
258 | SmallVector<Value> loadAll(OpBuilder &builder, Location loc, size_t size, |
259 | Value mem, size_t offsetIdx = 0, |
260 | Value offsetVal = Value()); |
261 | |
262 | /// Stores all the values of `vs` into the memref `mem`, which must have |
263 | /// rank-1 and size greater-or-equal to `vs.size()`. If the optional |
264 | /// `(offsetIdx,offsetVal)` arguments are provided, then the `offsetVal` |
265 | /// will be added to the `offsetIdx`-th value before storing. |
266 | void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs, |
267 | size_t offsetIdx = 0, Value offsetVal = Value()); |
268 | |
/// Generates code to cast a tensor to a memref.
270 | TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc, |
271 | Value tensor); |
272 | |
273 | /// Generates code to retrieve the slice offset for the sparse tensor slice, |
274 | /// return a constant if the offset is statically known. |
275 | Value createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, Value tensor, |
276 | Dimension dim); |
277 | |
/// Generates code to retrieve the slice stride for the sparse tensor slice,
/// or returns a constant if the stride is statically known.
280 | Value createOrFoldSliceStrideOp(OpBuilder &builder, Location loc, Value tensor, |
281 | Dimension dim); |
282 | |
283 | /// Generates code that opens a reader and sets the dimension sizes. |
284 | Value genReader(OpBuilder &builder, Location loc, SparseTensorType stt, |
285 | Value tensor, |
286 | /*out*/ SmallVectorImpl<Value> &dimSizesValues, |
287 | /*out*/ Value &dimSizesBuffer); |
288 | |
289 | /// Generates code to set up the buffer parameters for a map. |
290 | Value genMapBuffers(OpBuilder &builder, Location loc, SparseTensorType stt, |
291 | ArrayRef<Value> dimSizesValues, Value dimSizesBuffer, |
292 | /*out*/ SmallVectorImpl<Value> &lvlSizesValues, |
293 | /*out*/ Value &dim2lvlBuffer, |
294 | /*out*/ Value &lvl2dimBuffer); |
295 | |
296 | //===----------------------------------------------------------------------===// |
297 | // Inlined constant generators. |
298 | // |
299 | // All these functions are just wrappers to improve code legibility; |
300 | // therefore, we mark them as `inline` to avoid introducing any additional |
301 | // overhead due to the legibility. Ideally these should move upstream. |
302 | // |
303 | //===----------------------------------------------------------------------===// |
304 | |
305 | /// Generates a 0-valued constant of the given type. In addition to |
306 | /// the scalar types (`ComplexType`, `FloatType`, `IndexType`, |
307 | /// `IntegerType`), this also works for `RankedTensorType` and `VectorType` |
308 | /// (for which it generates a constant `DenseElementsAttr` of zeros). |
309 | inline Value constantZero(OpBuilder &builder, Location loc, Type tp) { |
310 | if (auto ctp = dyn_cast<ComplexType>(tp)) { |
311 | auto zeroe = builder.getZeroAttr(type: ctp.getElementType()); |
312 | auto zeroa = builder.getArrayAttr(value: {zeroe, zeroe}); |
313 | return builder.create<complex::ConstantOp>(loc, tp, zeroa); |
314 | } |
315 | return builder.create<arith::ConstantOp>(loc, tp, builder.getZeroAttr(tp)); |
316 | } |
317 | |
318 | /// Generates a 1-valued constant of the given type. This supports all |
319 | /// the same types as `constantZero`. |
320 | inline Value constantOne(OpBuilder &builder, Location loc, Type tp) { |
321 | if (auto ctp = dyn_cast<ComplexType>(tp)) { |
322 | auto zeroe = builder.getZeroAttr(type: ctp.getElementType()); |
323 | auto onee = getOneAttr(builder, ctp.getElementType()); |
324 | auto zeroa = builder.getArrayAttr(value: {onee, zeroe}); |
325 | return builder.create<complex::ConstantOp>(loc, tp, zeroa); |
326 | } |
327 | return builder.create<arith::ConstantOp>(loc, tp, getOneAttr(builder, tp)); |
328 | } |
329 | |
330 | /// Generates a constant of `index` type. |
331 | inline Value constantIndex(OpBuilder &builder, Location loc, int64_t i) { |
332 | return builder.create<arith::ConstantIndexOp>(location: loc, args&: i); |
333 | } |
334 | |
335 | /// Generates a constant of `i64` type. |
336 | inline Value constantI64(OpBuilder &builder, Location loc, int64_t i) { |
337 | return builder.create<arith::ConstantIntOp>(location: loc, args&: i, args: 64); |
338 | } |
339 | |
340 | /// Generates a constant of `i32` type. |
341 | inline Value constantI32(OpBuilder &builder, Location loc, int32_t i) { |
342 | return builder.create<arith::ConstantIntOp>(location: loc, args&: i, args: 32); |
343 | } |
344 | |
345 | /// Generates a constant of `i16` type. |
346 | inline Value constantI16(OpBuilder &builder, Location loc, int16_t i) { |
347 | return builder.create<arith::ConstantIntOp>(location: loc, args&: i, args: 16); |
348 | } |
349 | |
350 | /// Generates a constant of `i8` type. |
351 | inline Value constantI8(OpBuilder &builder, Location loc, int8_t i) { |
352 | return builder.create<arith::ConstantIntOp>(location: loc, args&: i, args: 8); |
353 | } |
354 | |
355 | /// Generates a constant of `i1` type. |
356 | inline Value constantI1(OpBuilder &builder, Location loc, bool b) { |
357 | return builder.create<arith::ConstantIntOp>(location: loc, args&: b, args: 1); |
358 | } |
359 | |
360 | /// Generates a constant of the given `Action`. |
361 | inline Value constantAction(OpBuilder &builder, Location loc, Action action) { |
362 | return constantI32(builder, loc, i: static_cast<uint32_t>(action)); |
363 | } |
364 | |
365 | /// Generates a constant of the internal type-encoding for overhead storage. |
366 | inline Value constantOverheadTypeEncoding(OpBuilder &builder, Location loc, |
367 | unsigned width) { |
368 | return constantI32(builder, loc, |
369 | i: static_cast<uint32_t>(overheadTypeEncoding(width))); |
370 | } |
371 | |
372 | /// Generates a constant of the internal type-encoding for position |
373 | /// overhead storage. |
374 | inline Value constantPosTypeEncoding(OpBuilder &builder, Location loc, |
375 | SparseTensorEncodingAttr enc) { |
376 | return constantOverheadTypeEncoding(builder, loc, enc.getPosWidth()); |
377 | } |
378 | |
379 | /// Generates a constant of the internal type-encoding for coordinate |
380 | /// overhead storage. |
381 | inline Value constantCrdTypeEncoding(OpBuilder &builder, Location loc, |
382 | SparseTensorEncodingAttr enc) { |
383 | return constantOverheadTypeEncoding(builder, loc, enc.getCrdWidth()); |
384 | } |
385 | |
386 | /// Generates a constant of the internal type-encoding for primary storage. |
387 | inline Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc, |
388 | Type elemTp) { |
389 | return constantI32(builder, loc, |
390 | i: static_cast<uint32_t>(primaryTypeEncoding(elemTp))); |
391 | } |
392 | |
393 | /// Generates a constant of the internal dimension level type encoding. |
394 | inline Value constantLevelTypeEncoding(OpBuilder &builder, Location loc, |
395 | LevelType lt) { |
396 | return constantI64(builder, loc, i: static_cast<uint64_t>(lt)); |
397 | } |
398 | |
399 | // Generates a constant from a validated value carrying attribute. |
400 | inline Value genValFromAttr(OpBuilder &builder, Location loc, Attribute attr) { |
401 | if (auto complexAttr = dyn_cast<complex::NumberAttr>(attr)) { |
402 | Type tp = cast<ComplexType>(complexAttr.getType()).getElementType(); |
403 | return builder.create<complex::ConstantOp>( |
404 | loc, complexAttr.getType(), |
405 | builder.getArrayAttr({FloatAttr::get(tp, complexAttr.getReal()), |
406 | FloatAttr::get(tp, complexAttr.getImag())})); |
407 | } |
408 | return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(attr)); |
409 | } |
410 | |
411 | // TODO: is this at the right place? |
412 | inline bool isZeroRankedTensorOrScalar(Type type) { |
413 | auto rtp = dyn_cast<RankedTensorType>(type); |
414 | return !rtp || rtp.getRank() == 0; |
415 | } |
416 | |
417 | } // namespace sparse_tensor |
418 | } // namespace mlir |
419 | |
420 | #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENUTILS_H_ |
421 | |