1 | //===- mlir-linalg-ods-yaml-gen.cpp - Linalg ODS generation from yaml ----===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements an ODS (and C++) generator from a YAML form |
10 | // derived from the mathematical expression of linalg named ops. Typically a |
11 | // math oriented DSL will be used to export the essential representation to |
// this form, and maintaining the source of truth (SOT) at the math level
// (versus recreating it in MLIR) is deemed to have systemic value.
14 | // |
15 | //===----------------------------------------------------------------------===// |
16 | |
17 | #include "mlir/AsmParser/AsmParser.h" |
18 | #include "mlir/IR/AffineMap.h" |
19 | #include "mlir/IR/Diagnostics.h" |
20 | #include "mlir/IR/MLIRContext.h" |
21 | #include "mlir/Support/FileUtilities.h" |
22 | #include "mlir/Support/LLVM.h" |
23 | #include "llvm/ADT/StringRef.h" |
24 | #include "llvm/Support/CommandLine.h" |
25 | #include "llvm/Support/Debug.h" |
26 | #include "llvm/Support/FormatVariadic.h" |
27 | #include "llvm/Support/ToolOutputFile.h" |
28 | #include "llvm/Support/YAMLTraits.h" |
29 | #include <optional> |
30 | |
31 | using namespace mlir; |
32 | |
33 | using llvm::yaml::Input; |
34 | |
35 | #define DEBUG_TYPE "linalg-ods-gen" |
36 | |
37 | //===----------------------------------------------------------------------===// |
38 | // Mapping structs (correspond to data types in the YAML description). |
39 | // TODO: Since this is a schema/part of the contract, it should be moved to |
40 | // a real header. |
41 | //===----------------------------------------------------------------------===// |
42 | |
namespace {

/// Opaque context handed to the YAML parser. The
/// ScalarTraits<SerializedAffineMap>::input hook casts the raw YAML context
/// pointer back to this type to reach the MLIRContext it parses with.
struct LinalgYAMLContext {
  MLIRContext *mlirContext;
};

/// Op-level metadata (the `metadata` mapping): op mnemonic, C++ class name,
/// optional documentation, the list of implemented interfaces, and the list
/// of ODS flags that are emitted as `let <name> = 1;`.
struct LinalgOpMetadata {
  std::string name;
  std::string cppClassName;
  std::optional<std::string> doc;
  SmallVector<std::string> implements;
  SmallVector<std::string> defines;
};

/// An AffineMapAttr that is read from / written to YAML in its textual form.
struct SerializedAffineMap {
  AffineMapAttr affineMapAttr;

  /// Returns the wrapped affine map (the attribute must be non-null).
  AffineMap affineMap() { return affineMapAttr.getValue(); }
};

/// How an op argument is used. YAML spellings: input_tensor, scalar,
/// output_tensor, index_attr, unary_fn_attr, binary_fn_attr, ternary_fn_attr,
/// type_fn_attr.
enum class LinalgOperandDefKind {
  InputTensor,
  Scalar,
  OutputTensor,
  IndexAttr,
  UnaryFnAttr,
  BinaryFnAttr,
  TernaryFnAttr,
  TypeFnAttr
};

/// A named op argument (tensor, scalar, index attribute or function
/// attribute). Which optional fields are meaningful depends on `kind`.
struct LinalgOperandDef {
  std::string name;
  LinalgOperandDefKind kind;
  // Symbolic element/self type variable (tensor and scalar arguments).
  std::optional<std::string> typeVar;
  // Map from op symbols to the argument shape (tensor arguments only).
  std::optional<SerializedAffineMap> shapeMap;
  // Map from op symbols to index attribute symbols (index attributes only).
  std::optional<SerializedAffineMap> indexAttrMap;
  // Default value of an index attribute.
  std::optional<SmallVector<int64_t>> defaultIndices;
  // Default enum case of a function attribute.
  std::optional<std::string> defaultFn;
};

/// Loop iterator kind. YAML spellings: `parallel`, `reduction`.
enum class LinalgIteratorTypeDef {
  parallel,
  reduction,
};

/// How the op's indexing maps are produced. Only a static list of maps is
/// currently supported.
struct LinalgIndexingMapsConfig {
  std::optional<SmallVector<SerializedAffineMap>> staticIndexingMaps;
};

struct ScalarExpression;

/// Kind of a named scalar function. YAML spellings: unary, binary, ternary,
/// type.
enum class ScalarFnKind { Unary, Binary, Ternary, Type };

/// A call to a named scalar function. The function is identified either
/// directly by `fnName` or indirectly through the function attribute
/// `attrName`; `typeVar` is used by type functions.
struct ScalarFn {
  ScalarFnKind kind;
  std::optional<std::string> fnName;
  std::optional<std::string> attrName;
  std::optional<std::string> typeVar;
  // NOTE: The number of operands depends on `kind`. A heap-allocating
  // std::vector is used instead of by-value members because ScalarExpression
  // is still incomplete here and itself (optionally) contains a ScalarFn.
  std::vector<ScalarExpression> operands;
};

/// Right-hand side of an assignment. Exactly one of the fields is expected
/// to be set: an argument reference, a constant, an iteration index, or a
/// scalar function call.
struct ScalarExpression {
  std::optional<std::string> arg;
  std::optional<std::string> constant;
  std::optional<int64_t> index;
  std::optional<ScalarFn> scalarFn;
};

/// Assigns `value` to the output argument named `arg`.
struct ScalarAssign {
  std::string arg;
  ScalarExpression value;
};

/// Full description of a structured op: arguments, indexing maps, iterator
/// types and scalar assignments.
struct LinalgStructuredOpConfig {
  SmallVector<LinalgOperandDef, 4> args;
  LinalgIndexingMapsConfig indexingMaps;
  SmallVector<LinalgIteratorTypeDef, 4> iteratorTypes;
  std::vector<ScalarAssign> assignments;
};

/// One YAML document: op metadata plus the concrete op description
/// (currently only `structured_op`).
struct LinalgOpConfig {
  std::optional<LinalgOpMetadata> metadata;
  std::optional<LinalgStructuredOpConfig> structuredOp;
};

} // namespace
132 | |
133 | //===----------------------------------------------------------------------===// |
134 | // Mapping traits. |
135 | //===----------------------------------------------------------------------===// |
136 | |
// Let YAML I/O treat vectors of these element types as YAML sequences, and a
// vector of LinalgOpConfig as a stream of YAML documents (one op per document).
LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgOperandDef)
LLVM_YAML_IS_SEQUENCE_VECTOR(SerializedAffineMap)
LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgIteratorTypeDef)
LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarAssign)
LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarExpression)
LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(LinalgOpConfig)
143 | |
144 | namespace llvm { |
145 | namespace yaml { |
146 | |
147 | /// Top-level type containing op metadata and one of a concrete op type. |
148 | /// Currently, the only defined op type is `structured_op` (maps to |
149 | /// `LinalgStructuredOpConfig`). |
150 | template <> |
151 | struct MappingTraits<LinalgOpConfig> { |
152 | static void mapping(IO &io, LinalgOpConfig &info) { |
153 | io.mapOptional(Key: "metadata", Val&: info.metadata); |
154 | io.mapOptional(Key: "structured_op", Val&: info.structuredOp); |
155 | } |
156 | }; |
157 | |
158 | /// A structured op models (at most) a single contraction by modeling |
159 | /// - A list of named arguments (`LinalgOperandDef`), which can be inputs, |
160 | /// outputs, or index attributes. |
161 | /// - List of indexing maps (see `LinalgIndexingMaps`). |
162 | /// - Iterator types (see `LinalgIteratorTypeDef`). |
163 | /// - List of scalar level assignment (see `ScalarAssign`). |
164 | template <> |
165 | struct MappingTraits<LinalgStructuredOpConfig> { |
166 | static void mapping(IO &io, LinalgStructuredOpConfig &info) { |
167 | io.mapRequired(Key: "args", Val&: info.args); |
168 | io.mapRequired(Key: "indexing_maps", Val&: info.indexingMaps); |
169 | io.mapRequired(Key: "iterator_types", Val&: info.iteratorTypes); |
170 | io.mapRequired(Key: "assignments", Val&: info.assignments); |
171 | } |
172 | }; |
173 | |
174 | /// Maps a named tensor, scalar or attribute argument to an operation, |
175 | /// consisting of: |
176 | /// - `name`: Must be unique within the operation. |
177 | /// - `usage`: How the argument is used (input, output, attribute, etc). |
178 | /// - `type_var`: The symbolic type variable that binds to the element or self |
179 | /// type of the tensor or scalar argument, respectively. |
180 | /// - `shape_map`: An optional AffineMap from all op symbols to the shape of |
181 | /// the argument. Only tensor arguments have a `shape_map`. Each shape must |
182 | /// be normalized over the same list of symbols and have no dimension |
183 | /// inputs. |
184 | /// - `index_attr_map`: An optional AffineMap from all op symbols to the |
185 | /// index attribute symbols. During op creation these symbols are replaced |
186 | /// by the corresponding `name` index attribue values. Only index attribute |
187 | /// arguments have an `index_attr_map`. |
188 | /// - `default_indices`: An optional default initialization for index |
189 | /// attribute arguments. |
190 | /// - `default_fn`: An optional default initialization for function attribute |
191 | /// arguments. |
192 | template <> |
193 | struct MappingTraits<LinalgOperandDef> { |
194 | static void mapping(IO &io, LinalgOperandDef &info) { |
195 | io.mapRequired(Key: "name", Val&: info.name); |
196 | io.mapRequired(Key: "kind", Val&: info.kind); |
197 | io.mapOptional(Key: "type_var", Val&: info.typeVar); |
198 | io.mapOptional(Key: "shape_map", Val&: info.shapeMap); |
199 | io.mapOptional(Key: "index_attr_map", Val&: info.indexAttrMap); |
200 | io.mapOptional(Key: "default_indices", Val&: info.defaultIndices); |
201 | io.mapOptional(Key: "default_fn", Val&: info.defaultFn); |
202 | } |
203 | }; |
204 | |
205 | /// Usage enum for a named argument. |
206 | template <> |
207 | struct ScalarEnumerationTraits<LinalgOperandDefKind> { |
208 | static void enumeration(IO &io, LinalgOperandDefKind &value) { |
209 | io.enumCase(Val&: value, Str: "input_tensor", ConstVal: LinalgOperandDefKind::InputTensor); |
210 | io.enumCase(Val&: value, Str: "scalar", ConstVal: LinalgOperandDefKind::Scalar); |
211 | io.enumCase(Val&: value, Str: "output_tensor", ConstVal: LinalgOperandDefKind::OutputTensor); |
212 | io.enumCase(Val&: value, Str: "index_attr", ConstVal: LinalgOperandDefKind::IndexAttr); |
213 | io.enumCase(Val&: value, Str: "unary_fn_attr", ConstVal: LinalgOperandDefKind::UnaryFnAttr); |
214 | io.enumCase(Val&: value, Str: "binary_fn_attr", ConstVal: LinalgOperandDefKind::BinaryFnAttr); |
215 | io.enumCase(Val&: value, Str: "ternary_fn_attr", ConstVal: LinalgOperandDefKind::TernaryFnAttr); |
216 | io.enumCase(Val&: value, Str: "type_fn_attr", ConstVal: LinalgOperandDefKind::TypeFnAttr); |
217 | } |
218 | }; |
219 | |
220 | /// Iterator type enum. |
221 | template <> |
222 | struct ScalarEnumerationTraits<LinalgIteratorTypeDef> { |
223 | static void enumeration(IO &io, LinalgIteratorTypeDef &value) { |
224 | io.enumCase(Val&: value, Str: "parallel", ConstVal: LinalgIteratorTypeDef::parallel); |
225 | io.enumCase(Val&: value, Str: "reduction", ConstVal: LinalgIteratorTypeDef::reduction); |
226 | } |
227 | }; |
228 | |
229 | /// Metadata about the op (name, C++ name, and documentation). |
230 | template <> |
231 | struct MappingTraits<LinalgOpMetadata> { |
232 | static void mapping(IO &io, LinalgOpMetadata &info) { |
233 | io.mapRequired(Key: "name", Val&: info.name); |
234 | io.mapRequired(Key: "cpp_class_name", Val&: info.cppClassName); |
235 | io.mapOptional(Key: "doc", Val&: info.doc); |
236 | io.mapOptional(Key: "implements", Val&: info.implements); |
237 | io.mapOptional(Key: "defines", Val&: info.defines); |
238 | } |
239 | }; |
240 | |
241 | /// How the ops indexing maps are produced. Must be one of: |
242 | /// - static_indexing_maps: A static list of AffineMaps, possibly with |
243 | /// some symbols that bind to attributes of the op. Each indexing map must |
244 | /// be normalized over the same list of dimensions, and its symbols must |
245 | /// match the symbols for argument shapes. |
246 | template <> |
247 | struct MappingTraits<LinalgIndexingMapsConfig> { |
248 | static void mapping(IO &io, LinalgIndexingMapsConfig &info) { |
249 | io.mapOptional(Key: "static_indexing_maps", Val&: info.staticIndexingMaps); |
250 | } |
251 | }; |
252 | |
253 | /// Models an assignment to a named output. |
254 | /// - The `arg` name must match a named output. |
255 | /// - The `value` is a scalar expression for computing the value to |
256 | /// assign (see `ScalarExpression`). |
257 | template <> |
258 | struct MappingTraits<ScalarAssign> { |
259 | static void mapping(IO &io, ScalarAssign &info) { |
260 | io.mapRequired(Key: "arg", Val&: info.arg); |
261 | io.mapRequired(Key: "value", Val&: info.value); |
262 | } |
263 | }; |
264 | |
265 | /// A scalar expression (RHS of an assignment). Must be one of: |
266 | /// - `scalar_arg`: An operation argument. |
267 | /// - `scalar_const`: A constant definition. |
268 | /// - `scalar_index`: An iteration index. |
269 | /// - `scalar_fn`: A named function (see `ScalarFn`). |
270 | template <> |
271 | struct MappingTraits<ScalarExpression> { |
272 | static void mapping(IO &io, ScalarExpression &info) { |
273 | io.mapOptional(Key: "scalar_arg", Val&: info.arg); |
274 | io.mapOptional(Key: "scalar_const", Val&: info.constant); |
275 | io.mapOptional(Key: "scalar_index", Val&: info.index); |
276 | io.mapOptional(Key: "scalar_fn", Val&: info.scalarFn); |
277 | } |
278 | }; |
279 | |
280 | /// Scalar function kind enum. |
281 | template <> |
282 | struct ScalarEnumerationTraits<ScalarFnKind> { |
283 | static void enumeration(IO &io, ScalarFnKind &value) { |
284 | io.enumCase(Val&: value, Str: "unary", ConstVal: ScalarFnKind::Unary); |
285 | io.enumCase(Val&: value, Str: "binary", ConstVal: ScalarFnKind::Binary); |
286 | io.enumCase(Val&: value, Str: "ternary", ConstVal: ScalarFnKind::Ternary); |
287 | io.enumCase(Val&: value, Str: "type", ConstVal: ScalarFnKind::Type); |
288 | } |
289 | }; |
290 | |
291 | /// A scalar expression that evaluates a named function. |
292 | /// Functions are generally "math" level and type polymorphic. Builtin |
293 | /// functions include: |
294 | /// - `add(lhs, rhs)` |
295 | /// - `mul(lhs, rhs)` |
296 | template <> |
297 | struct MappingTraits<ScalarFn> { |
298 | static void mapping(IO &io, ScalarFn &info) { |
299 | io.mapRequired(Key: "kind", Val&: info.kind); |
300 | io.mapOptional(Key: "fn_name", Val&: info.fnName); |
301 | io.mapOptional(Key: "attr_name", Val&: info.attrName); |
302 | io.mapOptional(Key: "type_var", Val&: info.typeVar); |
303 | io.mapRequired(Key: "operands", Val&: info.operands); |
304 | } |
305 | }; |
306 | |
307 | /// Helper mapping which accesses an AffineMapAttr as a serialized string of |
308 | /// the same. |
309 | template <> |
310 | struct ScalarTraits<SerializedAffineMap> { |
311 | static void output(const SerializedAffineMap &value, void *rawYamlContext, |
312 | raw_ostream &out) { |
313 | assert(value.affineMapAttr); |
314 | value.affineMapAttr.print(os&: out); |
315 | } |
316 | static StringRef input(StringRef scalar, void *rawYamlContext, |
317 | SerializedAffineMap &value) { |
318 | assert(rawYamlContext); |
319 | auto *yamlContext = static_cast<LinalgYAMLContext *>(rawYamlContext); |
320 | if (auto attr = dyn_cast_or_null<AffineMapAttr>( |
321 | Val: mlir::parseAttribute(attrStr: scalar, context: yamlContext->mlirContext))) |
322 | value.affineMapAttr = attr; |
323 | else if (!value.affineMapAttr || !isa<AffineMapAttr>(Val: value.affineMapAttr)) |
324 | return "could not parse as an affine map attribute"; |
325 | return StringRef(); |
326 | } |
327 | static QuotingType mustQuote(StringRef) { return QuotingType::None; } |
328 | }; |
329 | |
330 | } // namespace yaml |
331 | } // namespace llvm |
332 | |
333 | namespace { |
334 | |
335 | //===----------------------------------------------------------------------===// |
336 | // Generation utilities |
337 | //===----------------------------------------------------------------------===// |
338 | |
339 | class GenerationContext { |
340 | public: |
341 | GenerationContext(MLIRContext *context, raw_ostream *odsOut, |
342 | raw_ostream *defnOut) |
343 | : context(context), loc(UnknownLoc::get(context)), odsOut(odsOut), |
344 | defnOut(defnOut) {} |
345 | |
346 | MLIRContext *getContext() { return context; } |
347 | |
348 | void setLoc(Location loc) { this->loc = loc; } |
349 | Location getLoc() { return loc; } |
350 | |
351 | bool shouldGenerateOds() { return odsOut; } |
352 | bool shouldGenerateDefns() { return defnOut; } |
353 | |
354 | raw_ostream &odss() { |
355 | assert(odsOut && "ODS stream not defined"); |
356 | return *odsOut; |
357 | } |
358 | |
359 | raw_ostream &defns() { |
360 | assert(defnOut && "Definition stream not defined"); |
361 | return *defnOut; |
362 | } |
363 | |
364 | private: |
365 | MLIRContext *context; |
366 | Location loc; |
367 | raw_ostream *odsOut; |
368 | raw_ostream *defnOut; |
369 | }; |
370 | |
371 | } // namespace |
372 | |
373 | static std::string generateCppExpression(SerializedAffineMap self, |
374 | StringRef contextName) { |
375 | std::string printedStr; |
376 | llvm::raw_string_ostream printedSs(printedStr); |
377 | self.affineMapAttr.print(os&: printedSs); |
378 | |
379 | static const char exprFormat[] = |
380 | R"FMT(llvm::cast<AffineMapAttr>(mlir::parseAttribute("{0}", {1})).getValue())FMT"; |
381 | return llvm::formatv(Fmt: exprFormat, Vals&: printedStr, Vals&: contextName); |
382 | } |
383 | |
384 | template <typename Container> |
385 | static std::string interleaveToString(Container &container, |
386 | StringRef separator) { |
387 | std::string result; |
388 | llvm::raw_string_ostream ss(result); |
389 | llvm::interleave(container, ss, separator); |
390 | return result; |
391 | } |
392 | |
393 | static std::optional<int> |
394 | findTensorDefArgIndex(StringRef name, SmallVectorImpl<LinalgOperandDef> &args) { |
395 | for (const auto &it : llvm::enumerate(First&: args)) { |
396 | if (it.value().name == name) |
397 | return it.index(); |
398 | } |
399 | return std::nullopt; |
400 | } |
401 | |
402 | // Try to map the TypeVar to a predefined or an argument type. |
403 | static std::optional<std::string> |
404 | findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgOperandDef> &args) { |
405 | // Handle all predefined types. |
406 | if (typeVar == "I32") |
407 | return std::string("helper.getIntegerType(32)"); |
408 | if (typeVar == "I64") |
409 | return std::string("helper.getIntegerType(64)"); |
410 | if (typeVar == "F32") |
411 | return std::string("helper.getFloat32Type()"); |
412 | if (typeVar == "F64") |
413 | return std::string("helper.getFloat64Type()"); |
414 | |
415 | // Search all argument types. |
416 | for (const auto &it : llvm::enumerate(First&: args)) { |
417 | if (it.value().kind != LinalgOperandDefKind::InputTensor && |
418 | it.value().kind != LinalgOperandDefKind::Scalar && |
419 | it.value().kind != LinalgOperandDefKind::OutputTensor) |
420 | continue; |
421 | if (*it.value().typeVar == typeVar) |
422 | return llvm::formatv(Fmt: "block.getArgument({0}).getType()", Vals: it.index()) |
423 | .str(); |
424 | } |
425 | |
426 | return std::nullopt; |
427 | } |
428 | |
429 | static ScalarAssign *findAssignment(StringRef name, |
430 | std::vector<ScalarAssign> &assignments) { |
431 | for (auto &assign : assignments) { |
432 | if (assign.arg == name) |
433 | return &assign; |
434 | } |
435 | return nullptr; |
436 | } |
437 | |
438 | // Return true if the operand is a function attribute. |
439 | static bool isFunctionAttribute(LinalgOperandDefKind kind) { |
440 | return kind == LinalgOperandDefKind::UnaryFnAttr || |
441 | kind == LinalgOperandDefKind::BinaryFnAttr || |
442 | kind == LinalgOperandDefKind::TernaryFnAttr || |
443 | kind == LinalgOperandDefKind::TypeFnAttr; |
444 | } |
445 | |
446 | // Return true if the operand is an attribute. |
447 | static bool isAttribute(LinalgOperandDefKind kind) { |
448 | return kind == LinalgOperandDefKind::IndexAttr || isFunctionAttribute(kind); |
449 | } |
450 | |
451 | // Get the enum name for the given operand kind. |
452 | std::string convertOperandKindToEnumName(LinalgOperandDefKind kind) { |
453 | switch (kind) { |
454 | case LinalgOperandDefKind::UnaryFnAttr: |
455 | return std::string("UnaryFn"); |
456 | case LinalgOperandDefKind::BinaryFnAttr: |
457 | return std::string("BinaryFn"); |
458 | case LinalgOperandDefKind::TernaryFnAttr: |
459 | return std::string("TernaryFn"); |
460 | case LinalgOperandDefKind::TypeFnAttr: |
461 | return std::string("TypeFn"); |
462 | default: |
463 | break; |
464 | } |
465 | llvm_unreachable("unsupported function attribute kind"); |
466 | } |
467 | |
468 | // Get the enum name for the given function kind. |
469 | std::string convertFunctionKindToEnumName(ScalarFnKind kind) { |
470 | switch (kind) { |
471 | case ScalarFnKind::Unary: |
472 | return std::string("UnaryFn"); |
473 | case ScalarFnKind::Binary: |
474 | return std::string("BinaryFn"); |
475 | case ScalarFnKind::Ternary: |
476 | return std::string("TernaryFn"); |
477 | case ScalarFnKind::Type: |
478 | return std::string("TypeFn"); |
479 | } |
480 | llvm_unreachable("unsupported function kind"); |
481 | } |
482 | |
483 | //===----------------------------------------------------------------------===// |
484 | // Templates |
485 | //===----------------------------------------------------------------------===// |
486 | |
// A single line banner format, written as a comment block into the generated
// C++ definitions. Parameters:
// {0}: Single line comment
static const char bannerFormat[] = R"FMT(
//===----------------------------------------------------------------------===//
// {0}
//===----------------------------------------------------------------------===//
)FMT";
494 | |
495 | //===----------------------------------------------------------------------===// |
496 | // Named generic op generation. |
497 | // These ops map at most a single contraction that complies with the limitations |
498 | // of a linalg.generic. |
499 | //===----------------------------------------------------------------------===// |
500 | |
// Template for Linalg named ops' ODS definitions. Parameters:
// {0}: ODS/C++ op name
// {1}: assembly op mnemonic
// {2}: op interface list
// {3}: documentation (summary + description)
// {4}: op attribute list
// {5}: builder methods taking standalone attribute parameters
// {6}: additional `let <name> = 1;` definitions (from the op's `defines`)
// {7}: additional methods for attributes used by indexing maps
// The template is expanded with llvm::formatv, so literal braces that would
// otherwise start a replacement field are escaped as `{{`.
static const char structuredOpOdsHeaderFormat[] = R"FMT(
//===----------------------------------------------------------------------===//
// Op definition for {0}
//===----------------------------------------------------------------------===//

def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
/*extraInterfaces=*/[{2}])> {
{3}
let arguments = (ins
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs{4}
);
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
let regions = (region AnyRegion:$region);

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<
(ins "ValueRange":$inputs, "ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
[{{
buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
attributes, {0}::getRegionBuilder());
}]>,
OpBuilder<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
"ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
[{{
buildStructuredOp($_builder, $_state, resultTensorTypes,
inputs, outputs, attributes, {0}::getRegionBuilder());
}]>,
OpBuilder<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
[{{
$_state.addOperands(operands);
$_state.addAttributes(attributes);
$_state.addTypes(resultTensorTypes);
(void)$_state.addRegion();
}]>
{5}
];
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
{6}

let extraClassDeclaration = structuredOpsBaseDecls # [{{
// Auto-generated.
SmallVector<utils::IteratorType> getIteratorTypesArray();
ArrayAttr getIndexingMaps();
static void regionBuilder(ImplicitLocOpBuilder &b,
Block &block, ArrayRef<NamedAttribute> attrs);
static std::function<void(ImplicitLocOpBuilder &,
Block &, ArrayRef<NamedAttribute>)>
getRegionBuilder() {{
return regionBuilder;
}

::mlir::MutableOperandRange getDpsInitsMutable() {{
return getOutputsMutable();
}

// Generic methods.
static unsigned getNumRegionArgs();
std::string getLibraryCallName();
{7}
}];
}
)FMT";
580 | |
// Builder method taking attribute parameters. It is appended to the
// `builders` list of the ODS header template, hence the leading comma.
// Parameters:
// {0}: Class name
// {1}: Comma interleaved attribute parameters
// {2}: Attribute initialization
static const char structuredOpBuilderFormat[] = R"FMT(
, OpBuilder<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
"ValueRange":$outputs, {1},
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
[{{
{2}
buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
attributes, {0}::getRegionBuilder());
}]>
)FMT";
596 | |
// The getIteratorTypesArray() method for structured ops. Parameters:
// {0}: Class name
// {1}: Comma interleaved iterator type names (presumably
//      `utils::IteratorType::...` enumerators, since they initialize a
//      SmallVector<utils::IteratorType>).
static const char structuredOpIteratorTypesFormat[] =
R"FMT(
SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
return SmallVector<utils::IteratorType>{{ {1} };
}
)FMT";
606 | |
// The getIteratorTypesArray() method for rank polymorphic structured ops: all
// loops are parallel, one per dimension of the first init operand.
// Parameters:
// {0}: Class name
static const char rankPolyStructuredOpIteratorTypesFormat[] =
R"FMT(
SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
int64_t rank = getRank(getDpsInitOperand(0));
return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}
)FMT";
617 | |
// The getIndexingMaps() method for structured ops. The result is memoized on
// the op in the "linalg.memoized_indexing_maps" attribute. Parameters:
// {0}: Class name
// {1}: Statements that populate the local `maps` vector (presumably using
//      the `context` and `symbolBindings` locals declared before them).
static const char structuredOpIndexingMapsFormat[] = R"FMT(
ArrayAttr {0}::getIndexingMaps() {{
static const char memoizeAttr[] = "linalg.memoized_indexing_maps";
ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
if (cached)
return cached;

MLIRContext *context = getContext();
auto symbolBindings = getSymbolBindings(*this);
SmallVector<AffineMap> maps;
{1}
cached = Builder(context).getAffineMapArrayAttr(maps);
getOperation()->setAttr(memoizeAttr, cached);
return cached;
}
)FMT";
638 | |
// The getIndexingMaps() method for rank polymorphic structured ops: each
// operand gets the identity map over the parallel loops, or a map with no
// results if the operand has rank zero. Parameters:
// {0}: Class name
static const char rankPolyStructuredOpIndexingMapsFormat[] = R"FMT(
ArrayAttr {0}::getIndexingMaps() {{
MLIRContext *context = getContext();
AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context);
AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(
getNumParallelLoops(), context);
SmallVector<AffineMap> indexingMaps;
for (OpOperand &opOperand : getOperation()->getOpOperands())
indexingMaps.push_back(getRank(&opOperand) == 0 ? scalarMap : tensorMap);
return Builder(getContext()).getAffineMapArrayAttr(indexingMaps);
}
)FMT";
653 | |
// Implementations of fold, getEffects and getSpeculatability. fold only
// folds memref casts into the op; getEffects reports no effects for ops with
// pure tensor semantics and otherwise defers to the generic Linalg helper.
// Parameters:
// {0}: Class name
const char structuredOpFoldersFormat[] = R"FMT(
LogicalResult {0}::fold(FoldAdaptor,
SmallVectorImpl<OpFoldResult> &) {{
return memref::foldMemRefCast(*this);
}
void {0}::getEffects(SmallVectorImpl<
SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
if (hasPureTensorSemantics()) return;
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
Speculation::Speculatability {0}::getSpeculatability() {{
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
)FMT";
671 | |
// Implementation of parse/print via the shared named-structured-op helpers.
// The printer elides the operand segment sizes and the memoized indexing maps
// attribute.
// Parameters:
// {0}: Class name
static const char structuredOpParserFormat[] = R"FMT(
ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{
return ::parseNamedStructuredOp(parser, result,
{0}::getNumRegionArgs(), {0}::getRegionBuilder());
}
void {0}::print(OpAsmPrinter &p) {{
SmallVector<StringRef, 3> elidedAttrs = {{"operandSegmentSizes",
"linalg.memoized_indexing_maps"};
::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
elidedAttrs);
}
)FMT";
687 | |
688 | static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig, |
689 | GenerationContext &genContext) { |
690 | if (!genContext.shouldGenerateOds()) |
691 | return success(); |
692 | |
693 | raw_ostream &os = genContext.odss(); |
694 | |
695 | std::string interfaceNameList; |
696 | std::string attrList; |
697 | std::string attrMethods; |
698 | std::string attrBuilder; |
699 | |
700 | std::string doc; |
701 | if (opConfig.metadata->doc) { |
702 | static const char structuredOpDocFmt[] = R"FMT( |
703 | let summary = [{{{0}}]; |
704 | let description = [{{{1}}]; |
705 | )FMT"; |
706 | StringRef summary, description; |
707 | std::tie(args&: summary, args&: description) = |
708 | StringRef(*opConfig.metadata->doc).trim().split(Separator: "\n\n"); |
709 | |
710 | doc = llvm::formatv(Fmt: structuredOpDocFmt, Vals: summary.trim(), Vals: description.trim()); |
711 | } |
712 | |
713 | interfaceNameList = interleaveToString(container&: opConfig.metadata->implements, separator: ", "); |
714 | |
715 | std::string definitionList; |
716 | for (const std::string &definition : opConfig.metadata->defines) { |
717 | static const char definitionFmt[] = "let {0} = 1;\n"; |
718 | definitionList.append(str: llvm::formatv(Fmt: definitionFmt, Vals: definition)); |
719 | } |
720 | |
721 | if (llvm::any_of(Range&: opConfig.structuredOp->args, P: [](LinalgOperandDef &arg) { |
722 | return isAttribute(kind: arg.kind); |
723 | })) { |
724 | SmallVector<std::string> attrDefs; |
725 | SmallVector<std::string> attrParams; |
726 | SmallVector<std::string> attrStmts; |
727 | for (LinalgOperandDef &arg : opConfig.structuredOp->args) { |
728 | static const char paramFmt[] = "\"Attribute\":${0}"; |
729 | static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});"; |
730 | // Add the type conversion attributes to the op definition and builders. |
731 | if (isFunctionAttribute(kind: arg.kind)) { |
732 | assert(arg.defaultFn); |
733 | std::string enumName = convertOperandKindToEnumName(kind: arg.kind); |
734 | static const char typeFmt[] = "{0}::{1}"; |
735 | static const char defFmt[] = |
736 | "DefaultValuedOptionalAttr<{0}, \"{1}\">:${2}"; |
737 | attrDefs.push_back(Elt: llvm::formatv( |
738 | Fmt: defFmt, Vals: llvm::formatv(Fmt: "{0}Attr", Vals&: enumName), |
739 | Vals: llvm::formatv(Fmt: typeFmt, Vals&: enumName, Vals&: arg.defaultFn), Vals&: arg.name)); |
740 | attrParams.push_back(Elt: llvm::formatv(Fmt: paramFmt, Vals&: arg.name)); |
741 | attrStmts.push_back(Elt: llvm::formatv(Fmt: stmtFmt, Vals&: arg.name)); |
742 | } |
743 | // Add the index attributes to the op definition and builders. |
744 | if (arg.kind == LinalgOperandDefKind::IndexAttr) { |
745 | assert(arg.indexAttrMap.has_value()); |
746 | assert(arg.defaultIndices.has_value()); |
747 | size_t size = arg.indexAttrMap->affineMap().getNumResults(); |
748 | assert(arg.defaultIndices->size() == size); |
749 | static const char typeFmt[] = "RankedI64ElementsAttr<[{0}]>"; |
750 | static const char defFmt[] = |
751 | "DefaultValuedOptionalAttr<{0}, \"{ {1} }\">:${2}"; |
752 | std::string defaultVals; |
753 | llvm::raw_string_ostream ss(defaultVals); |
754 | llvm::interleave( |
755 | c: *arg.defaultIndices, os&: ss, |
756 | each_fn: [&](int64_t val) { ss << "static_cast<int64_t>("<< val << ")"; }, |
757 | separator: ", "); |
758 | attrDefs.push_back(Elt: llvm::formatv(Fmt: defFmt, Vals: llvm::formatv(Fmt: typeFmt, Vals&: size), |
759 | Vals&: ss.str(), Vals&: arg.name)); |
760 | attrParams.push_back(Elt: llvm::formatv(Fmt: paramFmt, Vals&: arg.name)); |
761 | attrStmts.push_back(Elt: llvm::formatv(Fmt: stmtFmt, Vals&: arg.name)); |
762 | } |
763 | } |
764 | if (llvm::any_of(Range&: opConfig.structuredOp->args, P: [](LinalgOperandDef &arg) { |
765 | return arg.kind == LinalgOperandDefKind::IndexAttr; |
766 | })) { |
767 | attrMethods = R"( |
768 | bool hasDynamicIndexingMaps(); |
769 | LogicalResult verifyIndexingMapRequiredAttributes(); |
770 | )"; |
771 | } |
772 | attrList = ",\n"+ llvm::join(R&: attrDefs, Separator: ",\n"); |
773 | attrBuilder = llvm::formatv( |
774 | Fmt: structuredOpBuilderFormat, Vals&: opConfig.metadata->cppClassName, |
775 | Vals: llvm::join(R&: attrParams, Separator: ", "), Vals: llvm::join(R&: attrStmts, Separator: "\n")); |
776 | } |
777 | |
778 | os << llvm::formatv(Fmt: structuredOpOdsHeaderFormat, |
779 | Vals&: opConfig.metadata->cppClassName, Vals&: opConfig.metadata->name, |
780 | Vals&: interfaceNameList, Vals&: doc, Vals&: attrList, Vals&: attrBuilder, |
781 | Vals&: definitionList, Vals&: attrMethods); |
782 | |
783 | return success(); |
784 | } |
785 | |
786 | static LogicalResult |
787 | generateNamedGenericOpDefns(LinalgOpConfig &opConfig, |
788 | GenerationContext &genContext) { |
789 | if (!genContext.shouldGenerateDefns()) |
790 | return success(); |
791 | |
792 | raw_ostream &os = genContext.defns(); |
793 | StringRef className = opConfig.metadata->cppClassName; |
794 | |
795 | // Implementation banner. |
796 | std::string bannerComment = llvm::formatv(Fmt: "Implementation of {0}", Vals&: className); |
797 | os << llvm::formatv(Fmt: bannerFormat, Vals&: bannerComment); |
798 | |
799 | // Compute the number of scalar and tensor arguments. |
800 | int64_t numOfArgs = |
801 | llvm::count_if(Range&: opConfig.structuredOp->args, P: [](LinalgOperandDef &arg) { |
802 | return arg.kind == LinalgOperandDefKind::InputTensor || |
803 | arg.kind == LinalgOperandDefKind::Scalar || |
804 | arg.kind == LinalgOperandDefKind::OutputTensor; |
805 | }); |
806 | |
807 | // An operation that accesses only scalars and scalar/rank zero tensors is |
808 | // rank polymorhpic. We implement rank polymorphism by generating different |
809 | // indexing maps and iterators that match the rank of the first output tensor. |
810 | // An operation is rank polymorphic if the iteration domain has rank zero. |
811 | bool isRankPolymorphic = opConfig.structuredOp->iteratorTypes.empty(); |
812 | |
813 | // Generate the iterator_types() method. |
814 | if (!isRankPolymorphic) { |
815 | std::string iteratorsStr; |
816 | llvm::raw_string_ostream ss(iteratorsStr); |
817 | llvm::interleaveComma(c: opConfig.structuredOp->iteratorTypes, os&: ss, |
818 | each_fn: [&](LinalgIteratorTypeDef it) { |
819 | switch (it) { |
820 | case LinalgIteratorTypeDef::parallel: |
821 | ss << "utils::IteratorType::parallel"; |
822 | break; |
823 | case LinalgIteratorTypeDef::reduction: |
824 | ss << "utils::IteratorType::reduction"; |
825 | break; |
826 | } |
827 | }); |
828 | os << llvm::formatv(Fmt: structuredOpIteratorTypesFormat, Vals&: className, |
829 | Vals&: iteratorsStr); |
830 | } else { |
831 | os << llvm::formatv(Fmt: rankPolyStructuredOpIteratorTypesFormat, Vals&: className); |
832 | } |
833 | |
834 | // Generating the getIndexingMaps() method. |
835 | if (auto &staticMaps = |
836 | opConfig.structuredOp->indexingMaps.staticIndexingMaps) { |
837 | if (staticMaps->empty()) |
838 | return emitError(loc: genContext.getLoc()) << "op has no indexing maps"; |
839 | if (!isRankPolymorphic) { |
840 | AffineMap firstMap = staticMaps->front().affineMap(); |
841 | |
842 | // Symbol bindings. |
843 | { |
844 | // For each symbol, generate a declaration for it, either with an |
845 | // AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from |
846 | // an attribute). |
847 | // TODO: Possibly lift into a top-level method. |
848 | static const char structuredOpSymbolBindingsFormat[] = R"FMT( |
849 | static SmallVector<AffineExpr> getSymbolBindings({0} self) { |
850 | MLIRContext *context = self.getContext(); |
851 | SmallVector<AffineExpr> exprs; |
852 | {1} |
853 | return exprs; |
854 | } |
855 | )FMT"; |
856 | |
857 | unsigned symbolCount = firstMap.getNumSymbols(); |
858 | SmallVector<std::string> symbolBindings; |
859 | for (unsigned i = 0; i < symbolCount; ++i) { |
860 | symbolBindings.push_back(Elt: llvm::formatv( |
861 | Fmt: " exprs.push_back(getAffineSymbolExpr({0}, context));", Vals&: i)); |
862 | } |
863 | |
864 | // Access an index attribute. Parameters: |
865 | // {0}: Attribute name |
866 | // {1}: Symbol position |
867 | // {2}: Attribute index |
868 | static const char structuredOpAccessAttrFormat[] = R"FMT( |
869 | int64_t cst{1} = self.get{0}().getValues<int64_t>()[{2}]; |
870 | exprs.push_back(getAffineConstantExpr(cst{1}, context)); |
871 | )FMT"; |
872 | // Update all symbol bindings mapped to an attribute. |
873 | for (LinalgOperandDef &arg : opConfig.structuredOp->args) { |
874 | if (arg.kind != LinalgOperandDefKind::IndexAttr) |
875 | continue; |
876 | assert(arg.indexAttrMap); |
877 | for (auto [idx, result] : |
878 | llvm::enumerate(First: arg.indexAttrMap->affineMap().getResults())) { |
879 | if (auto symbol = dyn_cast<AffineSymbolExpr>(Val: result)) { |
880 | std::string argName = arg.name; |
881 | argName[0] = toupper(c: argName[0]); |
882 | symbolBindings[symbol.getPosition()] = |
883 | llvm::formatv(Fmt: structuredOpAccessAttrFormat, Vals&: argName, |
884 | Vals: symbol.getPosition(), Vals&: idx); |
885 | } |
886 | } |
887 | } |
888 | |
889 | std::string symbolBindingsStr; |
890 | llvm::raw_string_ostream symbolBindingsSs(symbolBindingsStr); |
891 | llvm::interleave(c: symbolBindings, os&: symbolBindingsSs, separator: "\n"); |
892 | |
893 | os << llvm::formatv(Fmt: structuredOpSymbolBindingsFormat, Vals&: className, |
894 | Vals&: symbolBindingsStr); |
895 | } |
896 | |
897 | // Indexing maps. |
898 | { |
899 | unsigned dimCount = firstMap.getNumDims(); |
900 | |
901 | // Generate a comma-separated list of dim identifiers to be passed to |
902 | // bindDims, ensuring tht AffineExpr identifiers are bound in the right |
903 | // order to the proper AffineDimExpr. |
904 | // This results in vars in scope like: d0, d1, d2... |
905 | SmallVector<unsigned> dimIndices; |
906 | for (unsigned i = 0; i < dimCount; ++i) |
907 | dimIndices.push_back(Elt: i); |
908 | std::string dimIdentsStr; |
909 | llvm::raw_string_ostream dimIdentsSs(dimIdentsStr); |
910 | llvm::interleaveComma(c: dimIndices, os&: dimIdentsSs, |
911 | each_fn: [&](unsigned i) { dimIdentsSs << "d"<< i; }); |
912 | |
913 | // Statements to add and simplify each affine map. |
914 | SmallVector<std::string> stmts; |
915 | for (auto &indexingMap : *staticMaps) { |
916 | // TODO: Assert that dim and symbol count match the first. |
917 | stmts.push_back( |
918 | Elt: llvm::formatv(Fmt: "maps.push_back({0});", |
919 | Vals: generateCppExpression(self: indexingMap, contextName: "context"))); |
920 | stmts.push_back(Elt: llvm::formatv( |
921 | Fmt: "maps.back() = " |
922 | "simplifyAffineMap(maps.back().replaceDimsAndSymbols({{}, " |
923 | "symbolBindings, {0}, 0));", |
924 | Vals&: dimCount)); |
925 | } |
926 | |
927 | // TODO: This needs to be memoized and/or converted to non-parser based |
928 | // C++ codegen prior to real use. |
929 | os << llvm::formatv(Fmt: structuredOpIndexingMapsFormat, Vals&: className, |
930 | Vals: interleaveToString(container&: stmts, separator: "\n ")); |
931 | } |
932 | } else { |
933 | os << llvm::formatv(Fmt: rankPolyStructuredOpIndexingMapsFormat, Vals&: className); |
934 | } |
935 | } else { |
936 | return emitError(loc: genContext.getLoc()) |
937 | << "generating code for non static indexing maps not currently " |
938 | "supported"; |
939 | } |
940 | |
941 | // getNumRegionArgs() |
942 | { |
943 | // Generates a getNumRegionArgs() method. Parameters: |
944 | // {0}: Class name |
945 | // {1}: Number of region args |
946 | static const char structuredOpGetNumRegionArgsFormat[] = R"FMT( |
947 | unsigned {0}::getNumRegionArgs() {{ return {1}; } |
948 | )FMT"; |
949 | os << llvm::formatv(Fmt: structuredOpGetNumRegionArgsFormat, Vals&: className, |
950 | Vals&: numOfArgs); |
951 | } |
952 | |
953 | // getLibraryCallName() |
954 | { |
955 | // Generates a getLibraryCallName method. Parameters: |
956 | // {0}: Class name |
957 | static const char structuredOpGetLibraryCallFormat[] = R"FMT( |
958 | std::string {0}::getLibraryCallName() {{ |
959 | return generateLibraryCallName(getOperation()); |
960 | } |
961 | )FMT"; |
962 | os << llvm::formatv(Fmt: structuredOpGetLibraryCallFormat, Vals&: className); |
963 | } |
964 | |
965 | // hasDynamicIndexingMaps() and verifyIndexingMapRequiredAttributes() |
966 | if (llvm::any_of(Range&: opConfig.structuredOp->args, P: [](LinalgOperandDef &arg) { |
967 | return arg.kind == LinalgOperandDefKind::IndexAttr; |
968 | })) { |
969 | std::vector<std::string> attrVerifications; |
970 | for (LinalgOperandDef &arg : opConfig.structuredOp->args) { |
971 | if (arg.kind != LinalgOperandDefKind::IndexAttr) |
972 | continue; |
973 | assert(arg.indexAttrMap); |
974 | // Verify index attribute. Paramters: |
975 | // {0}: Attribute name |
976 | // {1}: Attribute size |
977 | static const char attrFmt[] = R"FMT( |
978 | if (auto attr = op->getAttrOfType<DenseElementsAttr>("{0}")) {{ |
979 | if (!attr.getType().getElementType().isInteger(64)) |
980 | return op->emitError("incorrect element type for index attribute '{0}'"); |
981 | if (attr.getType().getShape() != ArrayRef<int64_t>{{ {1} }) |
982 | return op->emitError("incorrect shape for index attribute '{0}'"); |
983 | } |
984 | )FMT"; |
985 | attrVerifications.push_back(x: llvm::formatv( |
986 | Fmt: attrFmt, Vals&: arg.name, Vals: arg.indexAttrMap->affineMap().getNumResults())); |
987 | } |
988 | |
989 | // Generates the verifyIndexingMapRequiredAttributes method. Parameters: |
990 | // {0}: Class name |
991 | // {1}: Attribute verification |
992 | static const char structuredOpVerifyIndexingMapRequiredAttributes[] = R"FMT( |
993 | bool {0}::hasDynamicIndexingMaps() {{ return true; } |
994 | LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{ |
995 | Operation *op = getOperation(); |
996 | {1} |
997 | return success(); |
998 | } |
999 | )FMT"; |
1000 | os << llvm::formatv(Fmt: structuredOpVerifyIndexingMapRequiredAttributes, |
1001 | Vals&: className, Vals: llvm::join(R&: attrVerifications, Separator: "\n")); |
1002 | } |
1003 | |
1004 | // regionBuilder() |
1005 | { |
1006 | // Generates a regionBuilder method. Parameters. |
1007 | // {0}: Class name |
1008 | // {1}: Number of args |
1009 | // {2}: Attributes |
1010 | // {3}: Statements |
1011 | static const char structuredOpRegionBuilderFormat[] = R"FMT( |
1012 | void {0}::regionBuilder(ImplicitLocOpBuilder &b, |
1013 | Block &block, ArrayRef<NamedAttribute> attrs) {{ |
1014 | assert({1} > 0 && block.getNumArguments() == {1} && |
1015 | "{0} regionBuilder expects {1} (>=0) args"); |
1016 | RegionBuilderHelper helper(b, block); |
1017 | SmallVector<Value> yields; |
1018 | {2} |
1019 | {3} |
1020 | helper.yieldOutputs(yields); |
1021 | } |
1022 | )FMT"; |
1023 | auto &args = opConfig.structuredOp->args; |
1024 | auto &assignments = opConfig.structuredOp->assignments; |
1025 | size_t generatedAssignmentCount = 0; |
1026 | int localCounter = 0; |
1027 | SmallVector<std::string> attrs; |
1028 | SmallVector<std::string> stmts; |
1029 | for (LinalgOperandDef &arg : args) { |
1030 | if (!isFunctionAttribute(kind: arg.kind)) |
1031 | continue; |
1032 | // Obtain the type function attribute values. Parameters. |
1033 | // {0}: enum name |
1034 | // {1}: attribute name |
1035 | // {2}: default type function name |
1036 | static const char attrDef[] = R"FMT( |
1037 | {0} {1}Val = {0}::{2}; |
1038 | auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{ |
1039 | return attr.getName() == "{1}"; }); |
1040 | if ({1}Iter != attrs.end()) {{ |
1041 | if (auto attr = llvm::dyn_cast<{0}Attr>({1}Iter->getValue())) |
1042 | {1}Val = attr.getValue(); |
1043 | } |
1044 | )FMT"; |
1045 | std::string enumName = convertOperandKindToEnumName(kind: arg.kind); |
1046 | attrs.push_back( |
1047 | Elt: llvm::formatv(Fmt: attrDef, Vals&: enumName, Vals&: arg.name, Vals&: arg.defaultFn)); |
1048 | } |
1049 | for (LinalgOperandDef &arg : args) { |
1050 | if (arg.kind != LinalgOperandDefKind::OutputTensor) |
1051 | continue; |
1052 | |
1053 | // Find the assignment that correlates with the argument. |
1054 | ScalarAssign *assignment = findAssignment(name: arg.name, assignments); |
1055 | if (!assignment) |
1056 | return emitError(loc: genContext.getLoc()) |
1057 | << "no assignment found for output argument "<< arg.name; |
1058 | ++generatedAssignmentCount; |
1059 | |
1060 | // Recursively generate the expression. |
1061 | std::function<std::optional<std::string>(ScalarExpression &)> |
1062 | generateExpression = |
1063 | [&](ScalarExpression &expression) -> std::optional<std::string> { |
1064 | if (expression.arg) { |
1065 | // Argument reference. |
1066 | std::optional<int> argIndex = |
1067 | findTensorDefArgIndex(name: *expression.arg, args); |
1068 | if (!argIndex) { |
1069 | emitError(loc: genContext.getLoc()) |
1070 | << "scalar argument not defined on the op: "<< *expression.arg; |
1071 | return std::nullopt; |
1072 | } |
1073 | return std::string( |
1074 | llvm::formatv(Fmt: "block.getArgument({0})", Vals&: *argIndex)); |
1075 | } |
1076 | if (expression.constant) { |
1077 | std::string cppIdent = llvm::formatv(Fmt: "value{0}", Vals&: ++localCounter); |
1078 | stmts.push_back( |
1079 | Elt: llvm::formatv(Fmt: R"FMT(Value {0} = helper.constant("{1}");)FMT", |
1080 | Vals&: cppIdent, Vals&: expression.constant)); |
1081 | return cppIdent; |
1082 | } |
1083 | if (expression.index) { |
1084 | // Access an iteration index. |
1085 | std::string cppIdent = llvm::formatv(Fmt: "value{0}", Vals&: ++localCounter); |
1086 | stmts.push_back(Elt: llvm::formatv(Fmt: "Value {0} = helper.index({1});", |
1087 | Vals&: cppIdent, Vals&: *expression.index)); |
1088 | return cppIdent; |
1089 | } |
1090 | if (expression.scalarFn) { |
1091 | std::string enumName = |
1092 | convertFunctionKindToEnumName(kind: expression.scalarFn->kind); |
1093 | |
1094 | // Get the function or attribute name. |
1095 | assert(expression.scalarFn->fnName || expression.scalarFn->attrName); |
1096 | std::string funcType; |
1097 | if (expression.scalarFn->fnName) { |
1098 | funcType = llvm::formatv(Fmt: "{0}::{1}", Vals&: enumName, |
1099 | Vals&: *expression.scalarFn->fnName); |
1100 | } |
1101 | if (expression.scalarFn->attrName) { |
1102 | if (llvm::none_of(Range&: args, P: [&](LinalgOperandDef &arg) { |
1103 | return isFunctionAttribute(kind: arg.kind) && |
1104 | arg.name == *expression.scalarFn->attrName; |
1105 | })) { |
1106 | emitError(loc: genContext.getLoc()) << "missing function attribute " |
1107 | << *expression.scalarFn->attrName; |
1108 | } |
1109 | funcType = llvm::formatv(Fmt: "{0}Val", Vals&: *expression.scalarFn->attrName); |
1110 | } |
1111 | assert(!funcType.empty()); |
1112 | |
1113 | // Add the optional type parameter to the operands. |
1114 | SmallVector<std::string> operandCppValues; |
1115 | if (expression.scalarFn->kind == ScalarFnKind::Type) { |
1116 | assert(expression.scalarFn->typeVar.has_value()); |
1117 | std::optional<std::string> typeCppValue = |
1118 | findTypeValue(typeVar: *expression.scalarFn->typeVar, args); |
1119 | if (!typeCppValue) { |
1120 | emitError(loc: genContext.getLoc()) |
1121 | << "type variable "<< *expression.scalarFn->typeVar |
1122 | << ", used in a type conversion, must map to a predefined or " |
1123 | << "an argument type but it does not"; |
1124 | return std::nullopt; |
1125 | } |
1126 | operandCppValues.push_back(Elt: *typeCppValue); |
1127 | } |
1128 | |
1129 | // Collect the scalar operands. |
1130 | for (ScalarExpression &operand : expression.scalarFn->operands) { |
1131 | auto operandCppValue = generateExpression(operand); |
1132 | if (!operandCppValue) |
1133 | return std::nullopt; |
1134 | operandCppValues.push_back(Elt: *operandCppValue); |
1135 | } |
1136 | |
1137 | // Call the function builder. |
1138 | std::string cppIdent = llvm::formatv(Fmt: "value{0}", Vals&: ++localCounter); |
1139 | stmts.push_back(Elt: llvm::formatv( |
1140 | Fmt: "Value {0} = helper.build{1}({2}, {3});", Vals&: cppIdent, Vals&: enumName, |
1141 | Vals&: funcType, Vals: interleaveToString(container&: operandCppValues, separator: ", "))); |
1142 | return cppIdent; |
1143 | } |
1144 | emitError(loc: genContext.getLoc()) << "unknown ScalarExpression type"; |
1145 | return std::nullopt; |
1146 | }; |
1147 | std::optional<std::string> cppValue = |
1148 | generateExpression(assignment->value); |
1149 | if (!cppValue) |
1150 | return failure(); |
1151 | stmts.push_back(Elt: llvm::formatv(Fmt: "yields.push_back({0});", Vals&: *cppValue)); |
1152 | } |
1153 | |
1154 | if (generatedAssignmentCount != assignments.size()) |
1155 | return emitError(loc: genContext.getLoc()) |
1156 | << "mismatched number of assignments vs output arguments"; |
1157 | |
1158 | os << llvm::formatv(Fmt: structuredOpRegionBuilderFormat, Vals&: className, Vals&: numOfArgs, |
1159 | Vals: interleaveToString(container&: attrs, separator: "\n "), |
1160 | Vals: interleaveToString(container&: stmts, separator: "\n ")); |
1161 | } |
1162 | |
1163 | // Parser and printer. |
1164 | os << llvm::formatv(Fmt: structuredOpParserFormat, Vals&: className); |
1165 | |
1166 | // Canonicalizers and folders. |
1167 | os << llvm::formatv(Fmt: structuredOpFoldersFormat, Vals&: className); |
1168 | |
1169 | return success(); |
1170 | } |
1171 | |
1172 | static LogicalResult generateOp(LinalgOpConfig &opConfig, |
1173 | GenerationContext &genContext) { |
1174 | // Switch on op type being generated. |
1175 | if (opConfig.structuredOp) { |
1176 | return success( |
1177 | IsSuccess: succeeded(Result: generateNamedGenericOpOds(opConfig, genContext)) && |
1178 | succeeded(Result: generateNamedGenericOpDefns(opConfig, genContext))); |
1179 | } |
1180 | return emitError(loc: genContext.getLoc()) << "unsupported operation type"; |
1181 | } |
1182 | |
1183 | //===----------------------------------------------------------------------===// |
1184 | // Command line options and main |
1185 | //===----------------------------------------------------------------------===// |
1186 | |
1187 | static llvm::cl::opt<std::string> |
1188 | inputFilename(llvm::cl::Positional, llvm::cl::desc("<input file>"), |
1189 | llvm::cl::init(Val: "-"), llvm::cl::value_desc( "YAML filename")); |
1190 | |
1191 | static llvm::cl::opt<std::string> |
1192 | outputOdsDeclFilename("o-ods-decl", llvm::cl::desc( "ODS output filename"), |
1193 | llvm::cl::value_desc("filename"), llvm::cl::init(Val: "")); |
1194 | |
1195 | static llvm::cl::opt<std::string> |
1196 | outputCppImplFilename("o-impl", |
1197 | llvm::cl::desc("C++ implementation file name"), |
1198 | llvm::cl::value_desc("filename"), llvm::cl::init(Val: "")); |
1199 | |
1200 | int main(int argc, char **argv) { |
1201 | llvm::cl::ParseCommandLineOptions(argc, argv, Overview: "Linalg ODS Gen from YAML"); |
1202 | |
1203 | // Set up the input file. |
1204 | std::string errorMessage; |
1205 | std::unique_ptr<llvm::MemoryBuffer> file = |
1206 | mlir::openInputFile(inputFilename, errorMessage: &errorMessage); |
1207 | if (!file) { |
1208 | llvm::errs() << errorMessage << "\n"; |
1209 | return 1; |
1210 | } |
1211 | |
1212 | MLIRContext mlirContext; |
1213 | LinalgYAMLContext yamlContext{.mlirContext: &mlirContext}; |
1214 | |
1215 | std::vector<LinalgOpConfig> opConfigs; |
1216 | |
1217 | // Parse input. |
1218 | Input yin(file->getBuffer(), &yamlContext); |
1219 | yin >> opConfigs; |
1220 | |
1221 | if (yin.error()) |
1222 | return 1; |
1223 | |
1224 | // Open output files. |
1225 | std::unique_ptr<llvm::ToolOutputFile> outputOdsDecl; |
1226 | if (!outputOdsDeclFilename.empty()) { |
1227 | outputOdsDecl = openOutputFile(outputFilename: outputOdsDeclFilename, errorMessage: &errorMessage); |
1228 | if (!outputOdsDecl) { |
1229 | llvm::errs() << errorMessage << "\n"; |
1230 | return 1; |
1231 | } |
1232 | } |
1233 | |
1234 | std::unique_ptr<llvm::ToolOutputFile> outputCppImpl; |
1235 | if (!outputCppImplFilename.empty()) { |
1236 | outputCppImpl = openOutputFile(outputFilename: outputCppImplFilename, errorMessage: &errorMessage); |
1237 | if (!outputCppImpl) { |
1238 | llvm::errs() << errorMessage << "\n"; |
1239 | return 1; |
1240 | } |
1241 | } |
1242 | |
1243 | if (!outputOdsDecl && !outputCppImpl) { |
1244 | llvm::errs() << "error: No output files specified\n"; |
1245 | return 1; |
1246 | } |
1247 | |
1248 | // Generate. |
1249 | GenerationContext genContext(&mlirContext, |
1250 | outputOdsDecl ? &outputOdsDecl->os() : nullptr, |
1251 | outputCppImpl ? &outputCppImpl->os() : nullptr); |
1252 | |
1253 | for (auto &opConfig : opConfigs) { |
1254 | if (!opConfig.metadata) { |
1255 | emitError(loc: genContext.getLoc()) |
1256 | << "missing operation metadata on subsequent op"; |
1257 | return 1; |
1258 | } |
1259 | |
1260 | genContext.setLoc(NameLoc::get( |
1261 | StringAttr::get(&mlirContext, opConfig.metadata->cppClassName))); |
1262 | if (failed(Result: generateOp(opConfig, genContext))) { |
1263 | return 1; |
1264 | } |
1265 | } |
1266 | |
1267 | if (outputOdsDecl) |
1268 | outputOdsDecl->keep(); |
1269 | if (outputCppImpl) |
1270 | outputCppImpl->keep(); |
1271 | |
1272 | return 0; |
1273 | } |
1274 |
Definitions
- LinalgYAMLContext
- LinalgOpMetadata
- SerializedAffineMap
- affineMap
- LinalgOperandDefKind
- LinalgOperandDef
- LinalgIteratorTypeDef
- LinalgIndexingMapsConfig
- ScalarFnKind
- ScalarFn
- ScalarExpression
- ScalarAssign
- LinalgStructuredOpConfig
- LinalgOpConfig
- MappingTraits
- mapping
- MappingTraits
- mapping
- MappingTraits
- mapping
- ScalarEnumerationTraits
- enumeration
- ScalarEnumerationTraits
- enumeration
- MappingTraits
- mapping
- MappingTraits
- mapping
- MappingTraits
- mapping
- MappingTraits
- mapping
- ScalarEnumerationTraits
- enumeration
- MappingTraits
- mapping
- ScalarTraits
- output
- input
- mustQuote
- GenerationContext
- GenerationContext
- getContext
- setLoc
- getLoc
- shouldGenerateOds
- shouldGenerateDefns
- odss
- defns
- generateCppExpression
- interleaveToString
- findTensorDefArgIndex
- findTypeValue
- findAssignment
- isFunctionAttribute
- isAttribute
- convertOperandKindToEnumName
- convertFunctionKindToEnumName
- bannerFormat
- structuredOpOdsHeaderFormat
- structuredOpBuilderFormat
- structuredOpIteratorTypesFormat
- rankPolyStructuredOpIteratorTypesFormat
- structuredOpIndexingMapsFormat
- rankPolyStructuredOpIndexingMapsFormat
- structuredOpFoldersFormat
- structuredOpParserFormat
- generateNamedGenericOpOds
- generateNamedGenericOpDefns
- generateOp
- inputFilename
- outputOdsDeclFilename
- outputCppImplFilename
Learn to use CMake with our Intro Training
Find out more