1//===- CodegenUtils.h - Utilities for generating MLIR -----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This header file defines utilities for generating MLIR.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENUTILS_H_
14#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENUTILS_H_
15
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Complex/IR/Complex.h"
18#include "mlir/Dialect/Func/IR/FuncOps.h"
19#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20#include "mlir/Dialect/SparseTensor/IR/Enums.h"
21#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
22#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
23#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
24#include "mlir/IR/Builders.h"
25
26namespace mlir {
27
28class Location;
29class Type;
30class Value;
31
32namespace sparse_tensor {
33
/// Self-describing replacement for a bare `bool` when passing the
/// `emitCInterface` argument of `getFunc()`, `createFuncCall()`, and
/// `replaceOpWithFuncCall()`.
enum class EmitCInterface : bool {
  Off = false, // Do not request the C-interface wrapper.
  On = true,   // Request the C-interface wrapper on the declaration.
};
37
38//===----------------------------------------------------------------------===//
39// ExecutionEngine/SparseTensorUtils helper functions.
40//===----------------------------------------------------------------------===//
41
/// Converts an overhead storage bitwidth to its internal type-encoding.
OverheadType overheadTypeEncoding(unsigned width);

/// Converts an overhead storage type to its internal type-encoding.
OverheadType overheadTypeEncoding(Type tp);

/// Converts the internal type-encoding for overhead storage to an mlir::Type.
Type getOverheadType(Builder &builder, OverheadType ot);

/// Returns the OverheadType for position overhead storage of `enc`.
OverheadType posTypeEncoding(SparseTensorEncodingAttr enc);

/// Returns the OverheadType for coordinate overhead storage of `enc`.
OverheadType crdTypeEncoding(SparseTensorEncodingAttr enc);

/// Converts an OverheadType to its function-name suffix.
StringRef overheadTypeFunctionSuffix(OverheadType ot);

/// Converts an overhead storage type to its function-name suffix.
StringRef overheadTypeFunctionSuffix(Type overheadTp);

/// Converts a primary storage type to its internal type-encoding.
PrimaryType primaryTypeEncoding(Type elemTp);

/// Converts a PrimaryType to its function-name suffix.
StringRef primaryTypeFunctionSuffix(PrimaryType pt);

/// Converts a primary storage type to its function-name suffix.
StringRef primaryTypeFunctionSuffix(Type elemTp);
71
72//===----------------------------------------------------------------------===//
73// Misc code generators and utilities.
74//===----------------------------------------------------------------------===//
75
76/// A helper class to simplify lowering operations with/without function calls.
77template <class SubClass>
78class FuncCallOrInlineGenerator {
79public:
80 FuncCallOrInlineGenerator(TypeRange retTypes, ValueRange params, bool genCall)
81 : retTypes(retTypes), params(params), genCall(genCall) {}
82
83 // The main API invoked by clients, which abstracts away the details of
84 // creating function calls from clients.
85 SmallVector<Value> genCallOrInline(OpBuilder &builder, Location loc) {
86 if (!genCall)
87 return genImplementation(retTypes, params, builder, loc);
88
89 // Looks up the function.
90 std::string funcName = getMangledFuncName();
91 ModuleOp module = getParentOpOf<ModuleOp>(builder);
92 MLIRContext *context = module.getContext();
93 auto result = SymbolRefAttr::get(context, funcName);
94 auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
95
96 if (!func) {
97 // Create the function if not already exist.
98 OpBuilder::InsertionGuard insertionGuard(builder);
99 builder.setInsertionPoint(getParentOpOf<func::FuncOp>(builder));
100 func = builder.create<func::FuncOp>(
101 loc, funcName,
102 FunctionType::get(context, params.getTypes(), retTypes));
103 func.setPrivate();
104 // Set the insertion point to the body of the function.
105 Block *entryBB = func.addEntryBlock();
106 builder.setInsertionPointToStart(entryBB);
107 ValueRange args = entryBB->getArguments();
108 // Delegates to user to generate the actually implementation.
109 SmallVector<Value> result =
110 genImplementation(retTypes, params: args, builder, loc);
111 builder.create<func::ReturnOp>(loc, result);
112 }
113 // Returns the CallOp result.
114 func::CallOp call = builder.create<func::CallOp>(loc, func, params);
115 return call.getResults();
116 }
117
118private:
119 template <class OpTp>
120 OpTp getParentOpOf(OpBuilder &builder) {
121 return builder.getInsertionBlock()->getParent()->getParentOfType<OpTp>();
122 }
123
124 // CRTP: get the mangled function name (only called when genCall=true).
125 std::string getMangledFuncName() {
126 return static_cast<SubClass *>(this)->getMangledFuncName();
127 }
128
129 // CRTP: Client implementation.
130 SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange params,
131 OpBuilder &builder, Location loc) {
132 return static_cast<SubClass *>(this)->genImplementation(retTypes, params,
133 builder, loc);
134 }
135
136private:
137 TypeRange retTypes; // The types of all returned results
138 ValueRange params; // The values of all input parameters
139 bool genCall; // Should the implemetantion be wrapped in a function
140};
141
/// Adds type casting between arith and index types when needed.
Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy);

/// Adds conversion from a scalar to the given type (possibly a 0-rank tensor).
Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
                        Type dstTp);

/// Generates a position/coordinate load from the sparse storage scheme.
/// Narrower data types need to be zero-extended before casting the value
/// into the index type used for looping and indexing.
/// NOTE(review): `s` presumably holds the indices into `mem` — confirm.
Value genIndexLoad(OpBuilder &builder, Location loc, Value mem, ValueRange s);

/// Generates a 1-valued attribute of the given type. This supports
/// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`,
/// for unsupported types we raise `llvm_unreachable` rather than
/// returning a null attribute.
TypedAttr getOneAttr(Builder &builder, Type tp);

/// Generates the comparison `v != 0` where `v` is of numeric type.
/// For floating types, we use the "unordered" comparator (i.e., returns
/// true if `v` is NaN).
Value genIsNonzero(OpBuilder &builder, Location loc, Value v);

/// Computes the shape of the destination tensor of a reshape operator. This
/// is only used when operands have dynamic shape. The shape of the
/// destination is stored into `dstShape`.
void genReshapeDstShape(OpBuilder &builder, Location loc,
                        SmallVectorImpl<Value> &dstShape,
                        ArrayRef<Value> srcShape, ArrayRef<Size> staticDstShape,
                        ArrayRef<ReassociationIndices> reassociation);

/// Reshapes coordinates during a reshaping operation: given the source
/// coordinate values `srcCvs` of a tensor with sizes `srcSizes`, computes the
/// corresponding coordinate values `dstCvs` of a tensor with sizes
/// `dstSizes`, according to `reassociation`.
/// NOTE(review): whether `dstCvs` is appended to or cleared first is not
/// visible here — confirm against the implementation.
void reshapeCvs(OpBuilder &builder, Location loc,
                ArrayRef<ReassociationIndices> reassociation,
                ValueRange srcSizes, ValueRange srcCvs, // NOLINT
                ValueRange dstSizes, SmallVectorImpl<Value> &dstCvs);
178
/// Returns a function reference (first hit also inserts into module). Sets
/// the "_emit_c_interface" attribute on the function declaration when
/// requested, so that LLVM lowering generates a wrapper function that takes
/// care of ABI complications with passing in and returning MemRefs to C
/// functions.
FlatSymbolRefAttr getFunc(ModuleOp module, StringRef name, TypeRange resultType,
                          ValueRange operands, EmitCInterface emitCInterface);

/// Creates a `CallOp` to the function reference returned by `getFunc()` in
/// the builder's module.
func::CallOp createFuncCall(OpBuilder &builder, Location loc, StringRef name,
                            TypeRange resultType, ValueRange operands,
                            EmitCInterface emitCInterface);

/// Returns the equivalent of `void*` for opaque arguments to the
/// execution engine.
Type getOpaquePointerType(MLIRContext *ctx);
Type getOpaquePointerType(Builder &builder);

/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
Value genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp);

/// Generates an uninitialized temporary buffer of the given size and
/// type, and returns it as type `memref<? x $tp>` (staticShape=false) or
/// `memref<$sz x $tp>` (staticShape=true).
Value genAlloca(OpBuilder &builder, Location loc, unsigned sz, Type tp,
                bool staticShape = false);

/// Generates an uninitialized temporary buffer with room for one value
/// of the given type, and returns the `memref<$tp>`.
Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp);

/// Generates a temporary buffer, initializes it with the given contents,
/// and returns it as type `memref<? x $tp>` (rather than specifying the
/// size of the buffer).
Value allocaBuffer(OpBuilder &builder, Location loc, ValueRange values);

/// Generates code to allocate a buffer of the given type, and zero-initialize
/// it. If the buffer type has any dynamic sizes, then the `sizes` parameter
/// should be as filled by sizesFromPtr(); that way we can reuse the
/// genDimSizeCall() results generated by sizesFromPtr().
/// NOTE(review): `sizesFromPtr()` is not declared in this header (only
/// `sizesFromSrc()` is); this reference may be stale — confirm.
Value allocDenseTensor(OpBuilder &builder, Location loc,
                       RankedTensorType tensorTp, ValueRange sizes);

/// Generates code to deallocate a dense buffer.
void deallocDenseTensor(OpBuilder &builder, Location loc, Value buffer);

/// Populates the given `sizes` array from the dense tensor or sparse tensor
/// constant `src`.
void sizesFromSrc(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
                  Location loc, Value src);

/// Scans to the top of the generated loop.
Operation *getTop(Operation *op);
233
/// Iterates over a sparse constant, generating constant ops for each value
/// and its coordinates and passing them to `callback`. E.g.,
///   sparse<[ [0], [28], [31] ],
///          [ (-5.13, 2.0), (3.0, 4.0), (5.0, 6.0) ] >
/// =>
///   %c1 = arith.constant 0
///   %v1 = complex.constant (-5.13, 2.0)
///   callback({%c1}, %v1)
///
///   %c2 = arith.constant 28
///   %v2 = complex.constant (3.0, 4.0)
///   callback({%c2}, %v2)
///
///   %c3 = arith.constant 31
///   %v3 = complex.constant (5.0, 6.0)
///   callback({%c3}, %v3)
///
/// NOTE(review): the role of `order` (presumably a permutation applied to the
/// coordinates before they are handed to `callback`) is not visible here —
/// confirm against the implementation.
void foreachInSparseConstant(
    OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order,
    function_ref<void(ArrayRef<Value>, Value)> callback);
253
/// Loads `size`-many values from the memref, which must have rank-1 and
/// size greater-or-equal to `size`. If the optional `(offsetIdx,offsetVal)`
/// arguments are provided, then the `offsetVal` will be added to the
/// `offsetIdx`-th value after loading.
SmallVector<Value> loadAll(OpBuilder &builder, Location loc, size_t size,
                           Value mem, size_t offsetIdx = 0,
                           Value offsetVal = Value());

/// Stores all the values of `vs` into the memref `mem`, which must have
/// rank-1 and size greater-or-equal to `vs.size()`. If the optional
/// `(offsetIdx,offsetVal)` arguments are provided, then the `offsetVal`
/// will be added to the `offsetIdx`-th value before storing.
void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs,
              size_t offsetIdx = 0, Value offsetVal = Value());

/// Generates code to cast a tensor to a memref.
TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
                                       Value tensor);

/// Generates code to retrieve the slice offset for the sparse tensor slice;
/// returns a constant if the offset is statically known.
Value createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, Value tensor,
                                Dimension dim);

/// Generates code to retrieve the slice stride for the sparse tensor slice;
/// returns a constant if the stride is statically known.
Value createOrFoldSliceStrideOp(OpBuilder &builder, Location loc, Value tensor,
                                Dimension dim);

/// Generates code that opens a reader and sets the dimension sizes, which are
/// reported through the `dimSizesValues` and `dimSizesBuffer` out-parameters.
Value genReader(OpBuilder &builder, Location loc, SparseTensorType stt,
                Value tensor,
                /*out*/ SmallVectorImpl<Value> &dimSizesValues,
                /*out*/ Value &dimSizesBuffer);

/// Generates code to set up the buffer parameters for a map. The level sizes
/// and the dim2lvl / lvl2dim buffers are reported through the corresponding
/// out-parameters.
Value genMapBuffers(OpBuilder &builder, Location loc, SparseTensorType stt,
                    ArrayRef<Value> dimSizesValues, Value dimSizesBuffer,
                    /*out*/ SmallVectorImpl<Value> &lvlSizesValues,
                    /*out*/ Value &dim2lvlBuffer,
                    /*out*/ Value &lvl2dimBuffer);
295
296//===----------------------------------------------------------------------===//
297// Inlined constant generators.
298//
299// All these functions are just wrappers to improve code legibility;
300// therefore, we mark them as `inline` to avoid introducing any additional
301// overhead due to the legibility. Ideally these should move upstream.
302//
303//===----------------------------------------------------------------------===//
304
305/// Generates a 0-valued constant of the given type. In addition to
306/// the scalar types (`ComplexType`, `FloatType`, `IndexType`,
307/// `IntegerType`), this also works for `RankedTensorType` and `VectorType`
308/// (for which it generates a constant `DenseElementsAttr` of zeros).
309inline Value constantZero(OpBuilder &builder, Location loc, Type tp) {
310 if (auto ctp = dyn_cast<ComplexType>(tp)) {
311 auto zeroe = builder.getZeroAttr(type: ctp.getElementType());
312 auto zeroa = builder.getArrayAttr(value: {zeroe, zeroe});
313 return builder.create<complex::ConstantOp>(loc, tp, zeroa);
314 }
315 return builder.create<arith::ConstantOp>(loc, tp, builder.getZeroAttr(tp));
316}
317
318/// Generates a 1-valued constant of the given type. This supports all
319/// the same types as `constantZero`.
320inline Value constantOne(OpBuilder &builder, Location loc, Type tp) {
321 if (auto ctp = dyn_cast<ComplexType>(tp)) {
322 auto zeroe = builder.getZeroAttr(type: ctp.getElementType());
323 auto onee = getOneAttr(builder, ctp.getElementType());
324 auto zeroa = builder.getArrayAttr(value: {onee, zeroe});
325 return builder.create<complex::ConstantOp>(loc, tp, zeroa);
326 }
327 return builder.create<arith::ConstantOp>(loc, tp, getOneAttr(builder, tp));
328}
329
330/// Generates a constant of `index` type.
331inline Value constantIndex(OpBuilder &builder, Location loc, int64_t i) {
332 return builder.create<arith::ConstantIndexOp>(location: loc, args&: i);
333}
334
335/// Generates a constant of `i64` type.
336inline Value constantI64(OpBuilder &builder, Location loc, int64_t i) {
337 return builder.create<arith::ConstantIntOp>(location: loc, args&: i, args: 64);
338}
339
340/// Generates a constant of `i32` type.
341inline Value constantI32(OpBuilder &builder, Location loc, int32_t i) {
342 return builder.create<arith::ConstantIntOp>(location: loc, args&: i, args: 32);
343}
344
345/// Generates a constant of `i16` type.
346inline Value constantI16(OpBuilder &builder, Location loc, int16_t i) {
347 return builder.create<arith::ConstantIntOp>(location: loc, args&: i, args: 16);
348}
349
350/// Generates a constant of `i8` type.
351inline Value constantI8(OpBuilder &builder, Location loc, int8_t i) {
352 return builder.create<arith::ConstantIntOp>(location: loc, args&: i, args: 8);
353}
354
355/// Generates a constant of `i1` type.
356inline Value constantI1(OpBuilder &builder, Location loc, bool b) {
357 return builder.create<arith::ConstantIntOp>(location: loc, args&: b, args: 1);
358}
359
360/// Generates a constant of the given `Action`.
361inline Value constantAction(OpBuilder &builder, Location loc, Action action) {
362 return constantI32(builder, loc, i: static_cast<uint32_t>(action));
363}
364
365/// Generates a constant of the internal type-encoding for overhead storage.
366inline Value constantOverheadTypeEncoding(OpBuilder &builder, Location loc,
367 unsigned width) {
368 return constantI32(builder, loc,
369 i: static_cast<uint32_t>(overheadTypeEncoding(width)));
370}
371
372/// Generates a constant of the internal type-encoding for position
373/// overhead storage.
374inline Value constantPosTypeEncoding(OpBuilder &builder, Location loc,
375 SparseTensorEncodingAttr enc) {
376 return constantOverheadTypeEncoding(builder, loc, enc.getPosWidth());
377}
378
379/// Generates a constant of the internal type-encoding for coordinate
380/// overhead storage.
381inline Value constantCrdTypeEncoding(OpBuilder &builder, Location loc,
382 SparseTensorEncodingAttr enc) {
383 return constantOverheadTypeEncoding(builder, loc, enc.getCrdWidth());
384}
385
386/// Generates a constant of the internal type-encoding for primary storage.
387inline Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc,
388 Type elemTp) {
389 return constantI32(builder, loc,
390 i: static_cast<uint32_t>(primaryTypeEncoding(elemTp)));
391}
392
393/// Generates a constant of the internal dimension level type encoding.
394inline Value constantLevelTypeEncoding(OpBuilder &builder, Location loc,
395 LevelType lt) {
396 return constantI64(builder, loc, i: static_cast<uint64_t>(lt));
397}
398
399// Generates a constant from a validated value carrying attribute.
400inline Value genValFromAttr(OpBuilder &builder, Location loc, Attribute attr) {
401 if (auto complexAttr = dyn_cast<complex::NumberAttr>(attr)) {
402 Type tp = cast<ComplexType>(complexAttr.getType()).getElementType();
403 return builder.create<complex::ConstantOp>(
404 loc, complexAttr.getType(),
405 builder.getArrayAttr({FloatAttr::get(tp, complexAttr.getReal()),
406 FloatAttr::get(tp, complexAttr.getImag())}));
407 }
408 return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(attr));
409}
410
411// TODO: is this at the right place?
412inline bool isZeroRankedTensorOrScalar(Type type) {
413 auto rtp = dyn_cast<RankedTensorType>(type);
414 return !rtp || rtp.getRank() == 0;
415}
416
417} // namespace sparse_tensor
418} // namespace mlir
419
420#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENUTILS_H_
421

source code of mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h