1//===- mlir-linalg-ods-yaml-gen.cpp - Linalg ODS generation from yaml ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements an ODS (and C++) generator from a YAML form
// derived from the mathematical expression of linalg named ops. Typically a
// math-oriented DSL will be used to export the essential representation to
// this form, and maintaining the single source of truth (SOT) at the math
// level (versus recreating it in MLIR) is deemed to have systemic value.
14//
15//===----------------------------------------------------------------------===//
16
17#include "mlir/AsmParser/AsmParser.h"
18#include "mlir/IR/AffineMap.h"
19#include "mlir/IR/Diagnostics.h"
20#include "mlir/IR/MLIRContext.h"
21#include "mlir/Support/FileUtilities.h"
22#include "mlir/Support/LLVM.h"
23#include "llvm/ADT/StringRef.h"
24#include "llvm/Support/CommandLine.h"
25#include "llvm/Support/Debug.h"
26#include "llvm/Support/FormatVariadic.h"
27#include "llvm/Support/ToolOutputFile.h"
28#include "llvm/Support/YAMLTraits.h"
29#include <optional>
30
31using namespace mlir;
32
33using llvm::yaml::Input;
34
35#define DEBUG_TYPE "linalg-ods-gen"
36
37//===----------------------------------------------------------------------===//
38// Mapping structs (correspond to data types in the YAML description).
39// TODO: Since this is a schema/part of the contract, it should be moved to
40// a real header.
41//===----------------------------------------------------------------------===//
42
43namespace {
44
45struct LinalgYAMLContext {
46 MLIRContext *mlirContext;
47};
48
49struct LinalgOpMetadata {
50 std::string name;
51 std::string cppClassName;
52 std::optional<std::string> doc;
53 SmallVector<std::string> implements;
54 SmallVector<std::string> defines;
55};
56
57struct SerializedAffineMap {
58 AffineMapAttr affineMapAttr;
59
60 AffineMap affineMap() { return affineMapAttr.getValue(); }
61};
62
63enum class LinalgOperandDefKind {
64 InputTensor,
65 Scalar,
66 OutputTensor,
67 IndexAttr,
68 UnaryFnAttr,
69 BinaryFnAttr,
70 TernaryFnAttr,
71 TypeFnAttr
72};
73
74struct LinalgOperandDef {
75 std::string name;
76 LinalgOperandDefKind kind;
77 std::optional<std::string> typeVar;
78 std::optional<SerializedAffineMap> shapeMap;
79 std::optional<SerializedAffineMap> indexAttrMap;
80 std::optional<SmallVector<int64_t>> defaultIndices;
81 std::optional<std::string> defaultFn;
82};
83
84enum class LinalgIteratorTypeDef {
85 parallel,
86 reduction,
87};
88
89struct LinalgIndexingMapsConfig {
90 std::optional<SmallVector<SerializedAffineMap>> staticIndexingMaps;
91};
92
93struct ScalarExpression;
94
95enum class ScalarFnKind { Unary, Binary, Ternary, Type };
96
97struct ScalarFn {
98 ScalarFnKind kind;
99 std::optional<std::string> fnName;
100 std::optional<std::string> attrName;
101 std::optional<std::string> typeVar;
102 // NOTE: This must be of arity 1, but to break the self-referential cycle,
103 // we use a heap allocated vector.
104 std::vector<ScalarExpression> operands;
105};
106
107struct ScalarExpression {
108 std::optional<std::string> arg;
109 std::optional<std::string> constant;
110 std::optional<int64_t> index;
111 std::optional<ScalarFn> scalarFn;
112};
113
114struct ScalarAssign {
115 std::string arg;
116 ScalarExpression value;
117};
118
119struct LinalgStructuredOpConfig {
120 SmallVector<LinalgOperandDef, 4> args;
121 LinalgIndexingMapsConfig indexingMaps;
122 SmallVector<LinalgIteratorTypeDef, 4> iteratorTypes;
123 std::vector<ScalarAssign> assignments;
124};
125
126struct LinalgOpConfig {
127 std::optional<LinalgOpMetadata> metadata;
128 std::optional<LinalgStructuredOpConfig> structuredOp;
129};
130
131} // namespace
132
133//===----------------------------------------------------------------------===//
134// Mapping traits.
135//===----------------------------------------------------------------------===//
136
137LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgOperandDef)
138LLVM_YAML_IS_SEQUENCE_VECTOR(SerializedAffineMap)
139LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgIteratorTypeDef)
140LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarAssign)
141LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarExpression)
142LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(LinalgOpConfig)
143
144namespace llvm {
145namespace yaml {
146
147/// Top-level type containing op metadata and one of a concrete op type.
148/// Currently, the only defined op type is `structured_op` (maps to
149/// `LinalgStructuredOpConfig`).
150template <>
151struct MappingTraits<LinalgOpConfig> {
152 static void mapping(IO &io, LinalgOpConfig &info) {
153 io.mapOptional(Key: "metadata", Val&: info.metadata);
154 io.mapOptional(Key: "structured_op", Val&: info.structuredOp);
155 }
156};
157
158/// A structured op models (at most) a single contraction by modeling
159/// - A list of named arguments (`LinalgOperandDef`), which can be inputs,
160/// outputs, or index attributes.
161/// - List of indexing maps (see `LinalgIndexingMaps`).
162/// - Iterator types (see `LinalgIteratorTypeDef`).
163/// - List of scalar level assignment (see `ScalarAssign`).
164template <>
165struct MappingTraits<LinalgStructuredOpConfig> {
166 static void mapping(IO &io, LinalgStructuredOpConfig &info) {
167 io.mapRequired(Key: "args", Val&: info.args);
168 io.mapRequired(Key: "indexing_maps", Val&: info.indexingMaps);
169 io.mapRequired(Key: "iterator_types", Val&: info.iteratorTypes);
170 io.mapRequired(Key: "assignments", Val&: info.assignments);
171 }
172};
173
174/// Maps a named tensor, scalar or attribute argument to an operation,
175/// consisting of:
176/// - `name`: Must be unique within the operation.
177/// - `usage`: How the argument is used (input, output, attribute, etc).
178/// - `type_var`: The symbolic type variable that binds to the element or self
179/// type of the tensor or scalar argument, respectively.
180/// - `shape_map`: An optional AffineMap from all op symbols to the shape of
181/// the argument. Only tensor arguments have a `shape_map`. Each shape must
182/// be normalized over the same list of symbols and have no dimension
183/// inputs.
184/// - `index_attr_map`: An optional AffineMap from all op symbols to the
185/// index attribute symbols. During op creation these symbols are replaced
186/// by the corresponding `name` index attribue values. Only index attribute
187/// arguments have an `index_attr_map`.
188/// - `default_indices`: An optional default initialization for index
189/// attribute arguments.
190/// - `default_fn`: An optional default initialization for function attribute
191/// arguments.
192template <>
193struct MappingTraits<LinalgOperandDef> {
194 static void mapping(IO &io, LinalgOperandDef &info) {
195 io.mapRequired(Key: "name", Val&: info.name);
196 io.mapRequired(Key: "kind", Val&: info.kind);
197 io.mapOptional(Key: "type_var", Val&: info.typeVar);
198 io.mapOptional(Key: "shape_map", Val&: info.shapeMap);
199 io.mapOptional(Key: "index_attr_map", Val&: info.indexAttrMap);
200 io.mapOptional(Key: "default_indices", Val&: info.defaultIndices);
201 io.mapOptional(Key: "default_fn", Val&: info.defaultFn);
202 }
203};
204
205/// Usage enum for a named argument.
206template <>
207struct ScalarEnumerationTraits<LinalgOperandDefKind> {
208 static void enumeration(IO &io, LinalgOperandDefKind &value) {
209 io.enumCase(Val&: value, Str: "input_tensor", ConstVal: LinalgOperandDefKind::InputTensor);
210 io.enumCase(Val&: value, Str: "scalar", ConstVal: LinalgOperandDefKind::Scalar);
211 io.enumCase(Val&: value, Str: "output_tensor", ConstVal: LinalgOperandDefKind::OutputTensor);
212 io.enumCase(Val&: value, Str: "index_attr", ConstVal: LinalgOperandDefKind::IndexAttr);
213 io.enumCase(Val&: value, Str: "unary_fn_attr", ConstVal: LinalgOperandDefKind::UnaryFnAttr);
214 io.enumCase(Val&: value, Str: "binary_fn_attr", ConstVal: LinalgOperandDefKind::BinaryFnAttr);
215 io.enumCase(Val&: value, Str: "ternary_fn_attr", ConstVal: LinalgOperandDefKind::TernaryFnAttr);
216 io.enumCase(Val&: value, Str: "type_fn_attr", ConstVal: LinalgOperandDefKind::TypeFnAttr);
217 }
218};
219
220/// Iterator type enum.
221template <>
222struct ScalarEnumerationTraits<LinalgIteratorTypeDef> {
223 static void enumeration(IO &io, LinalgIteratorTypeDef &value) {
224 io.enumCase(Val&: value, Str: "parallel", ConstVal: LinalgIteratorTypeDef::parallel);
225 io.enumCase(Val&: value, Str: "reduction", ConstVal: LinalgIteratorTypeDef::reduction);
226 }
227};
228
229/// Metadata about the op (name, C++ name, and documentation).
230template <>
231struct MappingTraits<LinalgOpMetadata> {
232 static void mapping(IO &io, LinalgOpMetadata &info) {
233 io.mapRequired(Key: "name", Val&: info.name);
234 io.mapRequired(Key: "cpp_class_name", Val&: info.cppClassName);
235 io.mapOptional(Key: "doc", Val&: info.doc);
236 io.mapOptional(Key: "implements", Val&: info.implements);
237 io.mapOptional(Key: "defines", Val&: info.defines);
238 }
239};
240
241/// How the ops indexing maps are produced. Must be one of:
242/// - static_indexing_maps: A static list of AffineMaps, possibly with
243/// some symbols that bind to attributes of the op. Each indexing map must
244/// be normalized over the same list of dimensions, and its symbols must
245/// match the symbols for argument shapes.
246template <>
247struct MappingTraits<LinalgIndexingMapsConfig> {
248 static void mapping(IO &io, LinalgIndexingMapsConfig &info) {
249 io.mapOptional(Key: "static_indexing_maps", Val&: info.staticIndexingMaps);
250 }
251};
252
253/// Models an assignment to a named output.
254/// - The `arg` name must match a named output.
255/// - The `value` is a scalar expression for computing the value to
256/// assign (see `ScalarExpression`).
257template <>
258struct MappingTraits<ScalarAssign> {
259 static void mapping(IO &io, ScalarAssign &info) {
260 io.mapRequired(Key: "arg", Val&: info.arg);
261 io.mapRequired(Key: "value", Val&: info.value);
262 }
263};
264
265/// A scalar expression (RHS of an assignment). Must be one of:
266/// - `scalar_arg`: An operation argument.
267/// - `scalar_const`: A constant definition.
268/// - `scalar_index`: An iteration index.
269/// - `scalar_fn`: A named function (see `ScalarFn`).
270template <>
271struct MappingTraits<ScalarExpression> {
272 static void mapping(IO &io, ScalarExpression &info) {
273 io.mapOptional(Key: "scalar_arg", Val&: info.arg);
274 io.mapOptional(Key: "scalar_const", Val&: info.constant);
275 io.mapOptional(Key: "scalar_index", Val&: info.index);
276 io.mapOptional(Key: "scalar_fn", Val&: info.scalarFn);
277 }
278};
279
280/// Scalar function kind enum.
281template <>
282struct ScalarEnumerationTraits<ScalarFnKind> {
283 static void enumeration(IO &io, ScalarFnKind &value) {
284 io.enumCase(Val&: value, Str: "unary", ConstVal: ScalarFnKind::Unary);
285 io.enumCase(Val&: value, Str: "binary", ConstVal: ScalarFnKind::Binary);
286 io.enumCase(Val&: value, Str: "ternary", ConstVal: ScalarFnKind::Ternary);
287 io.enumCase(Val&: value, Str: "type", ConstVal: ScalarFnKind::Type);
288 }
289};
290
291/// A scalar expression that evaluates a named function.
292/// Functions are generally "math" level and type polymorphic. Builtin
293/// functions include:
294/// - `add(lhs, rhs)`
295/// - `mul(lhs, rhs)`
296template <>
297struct MappingTraits<ScalarFn> {
298 static void mapping(IO &io, ScalarFn &info) {
299 io.mapRequired(Key: "kind", Val&: info.kind);
300 io.mapOptional(Key: "fn_name", Val&: info.fnName);
301 io.mapOptional(Key: "attr_name", Val&: info.attrName);
302 io.mapOptional(Key: "type_var", Val&: info.typeVar);
303 io.mapRequired(Key: "operands", Val&: info.operands);
304 }
305};
306
307/// Helper mapping which accesses an AffineMapAttr as a serialized string of
308/// the same.
309template <>
310struct ScalarTraits<SerializedAffineMap> {
311 static void output(const SerializedAffineMap &value, void *rawYamlContext,
312 raw_ostream &out) {
313 assert(value.affineMapAttr);
314 value.affineMapAttr.print(os&: out);
315 }
316 static StringRef input(StringRef scalar, void *rawYamlContext,
317 SerializedAffineMap &value) {
318 assert(rawYamlContext);
319 auto *yamlContext = static_cast<LinalgYAMLContext *>(rawYamlContext);
320 if (auto attr = dyn_cast_or_null<AffineMapAttr>(
321 Val: mlir::parseAttribute(attrStr: scalar, context: yamlContext->mlirContext)))
322 value.affineMapAttr = attr;
323 else if (!value.affineMapAttr || !isa<AffineMapAttr>(Val: value.affineMapAttr))
324 return "could not parse as an affine map attribute";
325 return StringRef();
326 }
327 static QuotingType mustQuote(StringRef) { return QuotingType::None; }
328};
329
330} // namespace yaml
331} // namespace llvm
332
333namespace {
334
335//===----------------------------------------------------------------------===//
336// Generation utilities
337//===----------------------------------------------------------------------===//
338
339class GenerationContext {
340public:
341 GenerationContext(MLIRContext *context, raw_ostream *odsOut,
342 raw_ostream *defnOut)
343 : context(context), loc(UnknownLoc::get(context)), odsOut(odsOut),
344 defnOut(defnOut) {}
345
346 MLIRContext *getContext() { return context; }
347
348 void setLoc(Location loc) { this->loc = loc; }
349 Location getLoc() { return loc; }
350
351 bool shouldGenerateOds() { return odsOut; }
352 bool shouldGenerateDefns() { return defnOut; }
353
354 raw_ostream &odss() {
355 assert(odsOut && "ODS stream not defined");
356 return *odsOut;
357 }
358
359 raw_ostream &defns() {
360 assert(defnOut && "Definition stream not defined");
361 return *defnOut;
362 }
363
364private:
365 MLIRContext *context;
366 Location loc;
367 raw_ostream *odsOut;
368 raw_ostream *defnOut;
369};
370
371} // namespace
372
373static std::string generateCppExpression(SerializedAffineMap self,
374 StringRef contextName) {
375 std::string printedStr;
376 llvm::raw_string_ostream printedSs(printedStr);
377 self.affineMapAttr.print(os&: printedSs);
378
379 static const char exprFormat[] =
380 R"FMT(llvm::cast<AffineMapAttr>(mlir::parseAttribute("{0}", {1})).getValue())FMT";
381 return llvm::formatv(Fmt: exprFormat, Vals&: printedStr, Vals&: contextName);
382}
383
384template <typename Container>
385static std::string interleaveToString(Container &container,
386 StringRef separator) {
387 std::string result;
388 llvm::raw_string_ostream ss(result);
389 llvm::interleave(container, ss, separator);
390 return result;
391}
392
393static std::optional<int>
394findTensorDefArgIndex(StringRef name, SmallVectorImpl<LinalgOperandDef> &args) {
395 for (const auto &it : llvm::enumerate(First&: args)) {
396 if (it.value().name == name)
397 return it.index();
398 }
399 return std::nullopt;
400}
401
402// Try to map the TypeVar to a predefined or an argument type.
403static std::optional<std::string>
404findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgOperandDef> &args) {
405 // Handle all predefined types.
406 if (typeVar == "I32")
407 return std::string("helper.getIntegerType(32)");
408 if (typeVar == "I64")
409 return std::string("helper.getIntegerType(64)");
410 if (typeVar == "F32")
411 return std::string("helper.getFloat32Type()");
412 if (typeVar == "F64")
413 return std::string("helper.getFloat64Type()");
414
415 // Search all argument types.
416 for (const auto &it : llvm::enumerate(First&: args)) {
417 if (it.value().kind != LinalgOperandDefKind::InputTensor &&
418 it.value().kind != LinalgOperandDefKind::Scalar &&
419 it.value().kind != LinalgOperandDefKind::OutputTensor)
420 continue;
421 if (*it.value().typeVar == typeVar)
422 return llvm::formatv(Fmt: "block.getArgument({0}).getType()", Vals: it.index())
423 .str();
424 }
425
426 return std::nullopt;
427}
428
429static ScalarAssign *findAssignment(StringRef name,
430 std::vector<ScalarAssign> &assignments) {
431 for (auto &assign : assignments) {
432 if (assign.arg == name)
433 return &assign;
434 }
435 return nullptr;
436}
437
438// Return true if the operand is a function attribute.
439static bool isFunctionAttribute(LinalgOperandDefKind kind) {
440 return kind == LinalgOperandDefKind::UnaryFnAttr ||
441 kind == LinalgOperandDefKind::BinaryFnAttr ||
442 kind == LinalgOperandDefKind::TernaryFnAttr ||
443 kind == LinalgOperandDefKind::TypeFnAttr;
444}
445
446// Return true if the operand is an attribute.
447static bool isAttribute(LinalgOperandDefKind kind) {
448 return kind == LinalgOperandDefKind::IndexAttr || isFunctionAttribute(kind);
449}
450
451// Get the enum name for the given operand kind.
452std::string convertOperandKindToEnumName(LinalgOperandDefKind kind) {
453 switch (kind) {
454 case LinalgOperandDefKind::UnaryFnAttr:
455 return std::string("UnaryFn");
456 case LinalgOperandDefKind::BinaryFnAttr:
457 return std::string("BinaryFn");
458 case LinalgOperandDefKind::TernaryFnAttr:
459 return std::string("TernaryFn");
460 case LinalgOperandDefKind::TypeFnAttr:
461 return std::string("TypeFn");
462 default:
463 break;
464 }
465 llvm_unreachable("unsupported function attribute kind");
466}
467
468// Get the enum name for the given function kind.
469std::string convertFunctionKindToEnumName(ScalarFnKind kind) {
470 switch (kind) {
471 case ScalarFnKind::Unary:
472 return std::string("UnaryFn");
473 case ScalarFnKind::Binary:
474 return std::string("BinaryFn");
475 case ScalarFnKind::Ternary:
476 return std::string("TernaryFn");
477 case ScalarFnKind::Type:
478 return std::string("TypeFn");
479 }
480 llvm_unreachable("unsupported function kind");
481}
482
483//===----------------------------------------------------------------------===//
484// Templates
485//===----------------------------------------------------------------------===//
486
// Banner surrounding a one-line comment, emitted ahead of generated code.
// Parameters:
//   {0}: the single line comment text
static constexpr char bannerFormat[] = R"FMT(
//===----------------------------------------------------------------------===//
// {0}
//===----------------------------------------------------------------------===//
)FMT";
494
495//===----------------------------------------------------------------------===//
496// Named generic op generation.
497// These ops map at most a single contraction that complies with the limitations
498// of a linalg.generic.
499//===----------------------------------------------------------------------===//
500
// Template for Linalg named ops' ODS definitions. Parameters:
//   {0}: ODS/C++ op name
//   {1}: assembly op mnemonic
//   {2}: op interface list
//   {3}: documentation (summary + description)
//   {4}: op attribute list; either empty or a leading ",\n" followed by the
//        attribute definitions
//   {5}: builder methods taking standalone attribute parameters
//   {6}: additional ODS definitions, one `let <name> = 1;` per entry of the
//        op metadata's `defines` list
//   {7}: additional methods for attributes used by indexing maps
static constexpr char structuredOpOdsHeaderFormat[] = R"FMT(
//===----------------------------------------------------------------------===//
// Op definition for {0}
//===----------------------------------------------------------------------===//

def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
  /*extraInterfaces=*/[{2}])> {
    {3}
    let arguments = (ins
      Variadic<AnyType>:$inputs,
      Variadic<AnyShaped>:$outputs{4}
    );
    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
    let regions = (region AnyRegion:$region);

    let skipDefaultBuilders = 1;
    let builders = [
      OpBuilder<
      (ins "ValueRange":$inputs, "ValueRange":$outputs,
            CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
      [{{
        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
          attributes, {0}::getRegionBuilder());
      }]>,
      OpBuilder<
      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
            "ValueRange":$outputs,
            CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
      [{{
        buildStructuredOp($_builder, $_state, resultTensorTypes,
          inputs, outputs, attributes, {0}::getRegionBuilder());
      }]>,
      OpBuilder<
      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
            CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
      [{{
        $_state.addOperands(operands);
        $_state.addAttributes(attributes);
        $_state.addTypes(resultTensorTypes);
        (void)$_state.addRegion();
      }]>
      {5}
    ];
    let hasCustomAssemblyFormat = 1;
    let hasFolder = 1;
    {6}

    let extraClassDeclaration = structuredOpsBaseDecls # [{{
      // Auto-generated.
      SmallVector<utils::IteratorType> getIteratorTypesArray();
      ArrayAttr getIndexingMaps();
      static void regionBuilder(ImplicitLocOpBuilder &b,
                                Block &block, ArrayRef<NamedAttribute> attrs,
                                function_ref<InFlightDiagnostic()> emitError);
      static std::function<void(ImplicitLocOpBuilder &,
                                Block &, ArrayRef<NamedAttribute>, function_ref<InFlightDiagnostic()> emitError)>
      getRegionBuilder() {{
        return regionBuilder;
      }

      ::mlir::MutableOperandRange getDpsInitsMutable() {{
        return getOutputsMutable();
      }

      // Generic methods.
      static unsigned getNumRegionArgs();
      std::string getLibraryCallName();
      {7}
    }];
}
)FMT";
581
// Extra ODS builder whose parameters include the op's attributes.
// Parameters:
//   {0}: class name
//   {1}: comma-separated attribute parameters
//   {2}: statements that add the attributes to `$_state`
static constexpr char structuredOpBuilderFormat[] = R"FMT(
  , OpBuilder<
  (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
       "ValueRange":$outputs, {1},
       CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
  [{{
    {2}
    buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
      attributes, {0}::getRegionBuilder());
  }]>
)FMT";
597
// Definition of getIteratorTypesArray() for ops with a fixed iteration
// domain. Parameters:
//   {0}: class name
//   {1}: comma-separated `utils::IteratorType` enumerators
static constexpr char structuredOpIteratorTypesFormat[] = R"FMT(
SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
  return SmallVector<utils::IteratorType>{{ {1} };
}
)FMT";
607
// Definition of getIteratorTypesArray() for rank-polymorphic ops: one
// parallel iterator per dimension of the first init operand.
// Parameters:
//   {0}: class name
static constexpr char rankPolyStructuredOpIteratorTypesFormat[] = R"FMT(
SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
  int64_t rank = getRank(getDpsInitOperand(0));
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}
)FMT";
618
// Definition of getIndexingMaps() for ops with static indexing maps. The
// result is memoized in the "linalg.memoized_indexing_maps" attribute.
// Parameters:
//   {0}: class name
//   {1}: statements that append the indexing maps to `maps`; they may use
//        `context` and `symbolBindings`
static constexpr char structuredOpIndexingMapsFormat[] = R"FMT(
ArrayAttr {0}::getIndexingMaps() {{
  static const char memoizeAttr[] = "linalg.memoized_indexing_maps";
  ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
  if (cached)
    return cached;

  MLIRContext *context = getContext();
  auto symbolBindings = getSymbolBindings(*this);
  SmallVector<AffineMap> maps;
  {1}
  cached = Builder(context).getAffineMapArrayAttr(maps);
  getOperation()->setAttr(memoizeAttr, cached);
  return cached;
}
)FMT";
639
// Definition of getIndexingMaps() for rank-polymorphic ops. Rank-0 operands
// get a map with no results; all other operands get the identity map over
// the parallel loops.
// Parameters:
//   {0}: class name
static constexpr char rankPolyStructuredOpIndexingMapsFormat[] = R"FMT(
ArrayAttr {0}::getIndexingMaps() {{
  MLIRContext *context = getContext();
  AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context);
  AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(
    getNumParallelLoops(), context);
  SmallVector<AffineMap> indexingMaps;
  for (OpOperand &opOperand : getOperation()->getOpOperands())
    indexingMaps.push_back(getRank(&opOperand) == 0 ? scalarMap : tensorMap);
  return Builder(getContext()).getAffineMapArrayAttr(indexingMaps);
}
)FMT";
654
// Implementations of fold, getEffects and getSpeculatability.
// Declared `static` like the other templates in this file.
// Parameters:
//   {0}: class name
static constexpr char structuredOpFoldersFormat[] = R"FMT(
LogicalResult {0}::fold(FoldAdaptor,
                        SmallVectorImpl<OpFoldResult> &) {{
  return memref::foldMemRefCast(*this);
}
void {0}::getEffects(SmallVectorImpl<
    SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
  if (hasPureTensorSemantics()) return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
Speculation::Speculatability {0}::getSpeculatability() {{
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
)FMT";
672
// Definitions of the custom parse() and print() methods.
// Parameters:
//   {0}: class name
static constexpr char structuredOpParserFormat[] = R"FMT(
ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{
  return ::parseNamedStructuredOp(parser, result,
    {0}::getNumRegionArgs(), {0}::getRegionBuilder());
}
void {0}::print(OpAsmPrinter &p) {{
  SmallVector<StringRef, 3> elidedAttrs = {{"operandSegmentSizes",
                                           "linalg.memoized_indexing_maps"};
  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                           elidedAttrs);
}
)FMT";
688
689static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
690 GenerationContext &genContext) {
691 if (!genContext.shouldGenerateOds())
692 return success();
693
694 raw_ostream &os = genContext.odss();
695
696 std::string interfaceNameList;
697 std::string attrList;
698 std::string attrMethods;
699 std::string attrBuilder;
700
701 std::string doc;
702 if (opConfig.metadata->doc) {
703 static const char structuredOpDocFmt[] = R"FMT(
704 let summary = [{{{0}}];
705 let description = [{{{1}}];
706)FMT";
707 StringRef summary, description;
708 std::tie(args&: summary, args&: description) =
709 StringRef(*opConfig.metadata->doc).trim().split(Separator: "\n\n");
710
711 doc = llvm::formatv(Fmt: structuredOpDocFmt, Vals: summary.trim(), Vals: description.trim());
712 }
713
714 interfaceNameList = interleaveToString(container&: opConfig.metadata->implements, separator: ", ");
715
716 std::string definitionList;
717 for (const std::string &definition : opConfig.metadata->defines) {
718 static const char definitionFmt[] = "let {0} = 1;\n";
719 definitionList.append(str: llvm::formatv(Fmt: definitionFmt, Vals: definition));
720 }
721
722 if (llvm::any_of(Range&: opConfig.structuredOp->args, P: [](LinalgOperandDef &arg) {
723 return isAttribute(kind: arg.kind);
724 })) {
725 SmallVector<std::string> attrDefs;
726 SmallVector<std::string> attrParams;
727 SmallVector<std::string> attrStmts;
728 for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
729 static const char paramFmt[] = "\"Attribute\":${0}";
730 static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});";
731 // Add the type conversion attributes to the op definition and builders.
732 if (isFunctionAttribute(kind: arg.kind)) {
733 assert(arg.defaultFn);
734 std::string enumName = convertOperandKindToEnumName(kind: arg.kind);
735 static const char typeFmt[] = "{0}::{1}";
736 static const char defFmt[] =
737 "DefaultValuedOptionalAttr<{0}, \"{1}\">:${2}";
738 attrDefs.push_back(Elt: llvm::formatv(
739 Fmt: defFmt, Vals: llvm::formatv(Fmt: "{0}Attr", Vals&: enumName),
740 Vals: llvm::formatv(Fmt: typeFmt, Vals&: enumName, Vals&: arg.defaultFn), Vals&: arg.name));
741 attrParams.push_back(Elt: llvm::formatv(Fmt: paramFmt, Vals&: arg.name));
742 attrStmts.push_back(Elt: llvm::formatv(Fmt: stmtFmt, Vals&: arg.name));
743 }
744 // Add the index attributes to the op definition and builders.
745 if (arg.kind == LinalgOperandDefKind::IndexAttr) {
746 assert(arg.indexAttrMap.has_value());
747 assert(arg.defaultIndices.has_value());
748 size_t size = arg.indexAttrMap->affineMap().getNumResults();
749 assert(arg.defaultIndices->size() == size);
750 static const char typeFmt[] = "RankedI64ElementsAttr<[{0}]>";
751 static const char defFmt[] =
752 "DefaultValuedOptionalAttr<{0}, \"{ {1} }\">:${2}";
753 std::string defaultVals;
754 llvm::raw_string_ostream ss(defaultVals);
755 llvm::interleave(
756 c: *arg.defaultIndices, os&: ss,
757 each_fn: [&](int64_t val) { ss << "static_cast<int64_t>(" << val << ")"; },
758 separator: ", ");
759 attrDefs.push_back(Elt: llvm::formatv(Fmt: defFmt, Vals: llvm::formatv(Fmt: typeFmt, Vals&: size),
760 Vals&: ss.str(), Vals&: arg.name));
761 attrParams.push_back(Elt: llvm::formatv(Fmt: paramFmt, Vals&: arg.name));
762 attrStmts.push_back(Elt: llvm::formatv(Fmt: stmtFmt, Vals&: arg.name));
763 }
764 }
765 if (llvm::any_of(Range&: opConfig.structuredOp->args, P: [](LinalgOperandDef &arg) {
766 return arg.kind == LinalgOperandDefKind::IndexAttr;
767 })) {
768 attrMethods = R"(
769 bool hasDynamicIndexingMaps();
770 LogicalResult verifyIndexingMapRequiredAttributes();
771 )";
772 }
773 attrList = ",\n" + llvm::join(R&: attrDefs, Separator: ",\n");
774 attrBuilder = llvm::formatv(
775 Fmt: structuredOpBuilderFormat, Vals&: opConfig.metadata->cppClassName,
776 Vals: llvm::join(R&: attrParams, Separator: ", "), Vals: llvm::join(R&: attrStmts, Separator: "\n"));
777 }
778
779 os << llvm::formatv(Fmt: structuredOpOdsHeaderFormat,
780 Vals&: opConfig.metadata->cppClassName, Vals&: opConfig.metadata->name,
781 Vals&: interfaceNameList, Vals&: doc, Vals&: attrList, Vals&: attrBuilder,
782 Vals&: definitionList, Vals&: attrMethods);
783
784 return success();
785}
786
787static LogicalResult
788generateNamedGenericOpDefns(LinalgOpConfig &opConfig,
789 GenerationContext &genContext) {
790 if (!genContext.shouldGenerateDefns())
791 return success();
792
793 raw_ostream &os = genContext.defns();
794 StringRef className = opConfig.metadata->cppClassName;
795
796 // Implementation banner.
797 std::string bannerComment = llvm::formatv(Fmt: "Implementation of {0}", Vals&: className);
798 os << llvm::formatv(Fmt: bannerFormat, Vals&: bannerComment);
799
800 // Compute the number of scalar and tensor arguments.
801 int64_t numOfArgs =
802 llvm::count_if(Range&: opConfig.structuredOp->args, P: [](LinalgOperandDef &arg) {
803 return arg.kind == LinalgOperandDefKind::InputTensor ||
804 arg.kind == LinalgOperandDefKind::Scalar ||
805 arg.kind == LinalgOperandDefKind::OutputTensor;
806 });
807
808 // An operation that accesses only scalars and scalar/rank zero tensors is
809 // rank polymorhpic. We implement rank polymorphism by generating different
810 // indexing maps and iterators that match the rank of the first output tensor.
811 // An operation is rank polymorphic if the iteration domain has rank zero.
812 bool isRankPolymorphic = opConfig.structuredOp->iteratorTypes.empty();
813
814 // Generate the iterator_types() method.
815 if (!isRankPolymorphic) {
816 std::string iteratorsStr;
817 llvm::raw_string_ostream ss(iteratorsStr);
818 llvm::interleaveComma(c: opConfig.structuredOp->iteratorTypes, os&: ss,
819 each_fn: [&](LinalgIteratorTypeDef it) {
820 switch (it) {
821 case LinalgIteratorTypeDef::parallel:
822 ss << "utils::IteratorType::parallel";
823 break;
824 case LinalgIteratorTypeDef::reduction:
825 ss << "utils::IteratorType::reduction";
826 break;
827 }
828 });
829 os << llvm::formatv(Fmt: structuredOpIteratorTypesFormat, Vals&: className,
830 Vals&: iteratorsStr);
831 } else {
832 os << llvm::formatv(Fmt: rankPolyStructuredOpIteratorTypesFormat, Vals&: className);
833 }
834
835 // Generating the getIndexingMaps() method.
836 if (auto &staticMaps =
837 opConfig.structuredOp->indexingMaps.staticIndexingMaps) {
838 if (staticMaps->empty())
839 return emitError(loc: genContext.getLoc()) << "op has no indexing maps";
840 if (!isRankPolymorphic) {
841 AffineMap firstMap = staticMaps->front().affineMap();
842
843 // Symbol bindings.
844 {
845 // For each symbol, generate a declaration for it, either with an
846 // AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from
847 // an attribute).
848 // TODO: Possibly lift into a top-level method.
849 static const char structuredOpSymbolBindingsFormat[] = R"FMT(
850static SmallVector<AffineExpr> getSymbolBindings({0} self) {
851 MLIRContext *context = self.getContext();
852 SmallVector<AffineExpr> exprs;
853{1}
854 return exprs;
855}
856)FMT";
857
858 unsigned symbolCount = firstMap.getNumSymbols();
859 SmallVector<std::string> symbolBindings;
860 for (unsigned i = 0; i < symbolCount; ++i) {
861 symbolBindings.push_back(Elt: llvm::formatv(
862 Fmt: " exprs.push_back(getAffineSymbolExpr({0}, context));", Vals&: i));
863 }
864
865 // Access an index attribute. Parameters:
866 // {0}: Attribute name
867 // {1}: Symbol position
868 // {2}: Attribute index
869 static const char structuredOpAccessAttrFormat[] = R"FMT(
870int64_t cst{1} = self.get{0}().getValues<int64_t>()[{2}];
871exprs.push_back(getAffineConstantExpr(cst{1}, context));
872)FMT";
873 // Update all symbol bindings mapped to an attribute.
874 for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
875 if (arg.kind != LinalgOperandDefKind::IndexAttr)
876 continue;
877 assert(arg.indexAttrMap);
878 for (auto [idx, result] :
879 llvm::enumerate(First: arg.indexAttrMap->affineMap().getResults())) {
880 if (auto symbol = dyn_cast<AffineSymbolExpr>(Val: result)) {
881 std::string argName = arg.name;
882 argName[0] = toupper(c: argName[0]);
883 symbolBindings[symbol.getPosition()] =
884 llvm::formatv(Fmt: structuredOpAccessAttrFormat, Vals&: argName,
885 Vals: symbol.getPosition(), Vals&: idx);
886 }
887 }
888 }
889
890 std::string symbolBindingsStr;
891 llvm::raw_string_ostream symbolBindingsSs(symbolBindingsStr);
892 llvm::interleave(c: symbolBindings, os&: symbolBindingsSs, separator: "\n");
893
894 os << llvm::formatv(Fmt: structuredOpSymbolBindingsFormat, Vals&: className,
895 Vals&: symbolBindingsStr);
896 }
897
898 // Indexing maps.
899 {
900 unsigned dimCount = firstMap.getNumDims();
901
902 // Generate a comma-separated list of dim identifiers to be passed to
903 // bindDims, ensuring tht AffineExpr identifiers are bound in the right
904 // order to the proper AffineDimExpr.
905 // This results in vars in scope like: d0, d1, d2...
906 SmallVector<unsigned> dimIndices;
907 for (unsigned i = 0; i < dimCount; ++i)
908 dimIndices.push_back(Elt: i);
909 std::string dimIdentsStr;
910 llvm::raw_string_ostream dimIdentsSs(dimIdentsStr);
911 llvm::interleaveComma(c: dimIndices, os&: dimIdentsSs,
912 each_fn: [&](unsigned i) { dimIdentsSs << "d" << i; });
913
914 // Statements to add and simplify each affine map.
915 SmallVector<std::string> stmts;
916 for (auto &indexingMap : *staticMaps) {
917 // TODO: Assert that dim and symbol count match the first.
918 stmts.push_back(
919 Elt: llvm::formatv(Fmt: "maps.push_back({0});",
920 Vals: generateCppExpression(self: indexingMap, contextName: "context")));
921 stmts.push_back(Elt: llvm::formatv(
922 Fmt: "maps.back() = "
923 "simplifyAffineMap(maps.back().replaceDimsAndSymbols({{}, "
924 "symbolBindings, {0}, 0));",
925 Vals&: dimCount));
926 }
927
928 // TODO: This needs to be memoized and/or converted to non-parser based
929 // C++ codegen prior to real use.
930 os << llvm::formatv(Fmt: structuredOpIndexingMapsFormat, Vals&: className,
931 Vals: interleaveToString(container&: stmts, separator: "\n "));
932 }
933 } else {
934 os << llvm::formatv(Fmt: rankPolyStructuredOpIndexingMapsFormat, Vals&: className);
935 }
936 } else {
937 return emitError(loc: genContext.getLoc())
938 << "generating code for non static indexing maps not currently "
939 "supported";
940 }
941
942 // getNumRegionArgs()
943 {
944 // Generates a getNumRegionArgs() method. Parameters:
945 // {0}: Class name
946 // {1}: Number of region args
947 static const char structuredOpGetNumRegionArgsFormat[] = R"FMT(
948unsigned {0}::getNumRegionArgs() {{ return {1}; }
949)FMT";
950 os << llvm::formatv(Fmt: structuredOpGetNumRegionArgsFormat, Vals&: className,
951 Vals&: numOfArgs);
952 }
953
954 // getLibraryCallName()
955 {
956 // Generates a getLibraryCallName method. Parameters:
957 // {0}: Class name
958 static const char structuredOpGetLibraryCallFormat[] = R"FMT(
959std::string {0}::getLibraryCallName() {{
960 return generateLibraryCallName(getOperation());
961}
962)FMT";
963 os << llvm::formatv(Fmt: structuredOpGetLibraryCallFormat, Vals&: className);
964 }
965
966 // hasDynamicIndexingMaps() and verifyIndexingMapRequiredAttributes()
967 if (llvm::any_of(Range&: opConfig.structuredOp->args, P: [](LinalgOperandDef &arg) {
968 return arg.kind == LinalgOperandDefKind::IndexAttr;
969 })) {
970 std::vector<std::string> attrVerifications;
971 for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
972 if (arg.kind != LinalgOperandDefKind::IndexAttr)
973 continue;
974 assert(arg.indexAttrMap);
975 // Verify index attribute. Paramters:
976 // {0}: Attribute name
977 // {1}: Attribute size
978 static const char attrFmt[] = R"FMT(
979if (auto attr = op->getAttrOfType<DenseElementsAttr>("{0}")) {{
980 if (!attr.getType().getElementType().isInteger(64))
981 return op->emitError("incorrect element type for index attribute '{0}'");
982 if (attr.getType().getShape() != ArrayRef<int64_t>{{ {1} })
983 return op->emitError("incorrect shape for index attribute '{0}'");
984}
985)FMT";
986 attrVerifications.push_back(x: llvm::formatv(
987 Fmt: attrFmt, Vals&: arg.name, Vals: arg.indexAttrMap->affineMap().getNumResults()));
988 }
989
990 // Generates the verifyIndexingMapRequiredAttributes method. Parameters:
991 // {0}: Class name
992 // {1}: Attribute verification
993 static const char structuredOpVerifyIndexingMapRequiredAttributes[] = R"FMT(
994bool {0}::hasDynamicIndexingMaps() {{ return true; }
995LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{
996 Operation *op = getOperation();
997 {1}
998 return success();
999}
1000)FMT";
1001 os << llvm::formatv(Fmt: structuredOpVerifyIndexingMapRequiredAttributes,
1002 Vals&: className, Vals: llvm::join(R&: attrVerifications, Separator: "\n"));
1003 }
1004
1005 // regionBuilder()
1006 {
1007 // Generates a regionBuilder method. Parameters.
1008 // {0}: Class name
1009 // {1}: Number of args
1010 // {2}: Attributes
1011 // {3}: Statements
1012 static const char structuredOpRegionBuilderFormat[] = R"FMT(
1013void {0}::regionBuilder(ImplicitLocOpBuilder &b,
1014 Block &block, ArrayRef<NamedAttribute> attrs,
1015 function_ref<InFlightDiagnostic()> emitError) {{
1016 assert({1} > 0 && block.getNumArguments() == {1} &&
1017 "{0} regionBuilder expects {1} (>=0) args");
1018 RegionBuilderHelper helper(b, block);
1019 SmallVector<Value> yields;
1020 {2}
1021 {3}
1022 helper.yieldOutputs(yields);
1023}
1024)FMT";
1025 auto &args = opConfig.structuredOp->args;
1026 auto &assignments = opConfig.structuredOp->assignments;
1027 size_t generatedAssignmentCount = 0;
1028 int localCounter = 0;
1029 SmallVector<std::string> attrs;
1030 SmallVector<std::string> stmts;
1031 for (LinalgOperandDef &arg : args) {
1032 if (!isFunctionAttribute(kind: arg.kind))
1033 continue;
1034 // Obtain the type function attribute values. Parameters.
1035 // {0}: enum name
1036 // {1}: attribute name
1037 // {2}: default type function name
1038 static const char attrDef[] = R"FMT(
1039 {0} {1}Val = {0}::{2};
1040 auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
1041 return attr.getName() == "{1}"; });
1042 if ({1}Iter != attrs.end()) {{
1043 if (auto attr = llvm::dyn_cast<{0}Attr>({1}Iter->getValue()))
1044 {1}Val = attr.getValue();
1045 }
1046)FMT";
1047 std::string enumName = convertOperandKindToEnumName(kind: arg.kind);
1048 attrs.push_back(
1049 Elt: llvm::formatv(Fmt: attrDef, Vals&: enumName, Vals&: arg.name, Vals&: arg.defaultFn));
1050 }
1051 for (LinalgOperandDef &arg : args) {
1052 if (arg.kind != LinalgOperandDefKind::OutputTensor)
1053 continue;
1054
1055 // Find the assignment that correlates with the argument.
1056 ScalarAssign *assignment = findAssignment(name: arg.name, assignments);
1057 if (!assignment)
1058 return emitError(loc: genContext.getLoc())
1059 << "no assignment found for output argument " << arg.name;
1060 ++generatedAssignmentCount;
1061
1062 // Recursively generate the expression.
1063 std::function<std::optional<std::string>(ScalarExpression &)>
1064 generateExpression =
1065 [&](ScalarExpression &expression) -> std::optional<std::string> {
1066 if (expression.arg) {
1067 // Argument reference.
1068 std::optional<int> argIndex =
1069 findTensorDefArgIndex(name: *expression.arg, args);
1070 if (!argIndex) {
1071 emitError(loc: genContext.getLoc())
1072 << "scalar argument not defined on the op: " << *expression.arg;
1073 return std::nullopt;
1074 }
1075 return std::string(
1076 llvm::formatv(Fmt: "block.getArgument({0})", Vals&: *argIndex));
1077 }
1078 if (expression.constant) {
1079 std::string cppIdent = llvm::formatv(Fmt: "value{0}", Vals&: ++localCounter);
1080 stmts.push_back(
1081 Elt: llvm::formatv(Fmt: R"FMT(Value {0} = helper.constant("{1}");)FMT",
1082 Vals&: cppIdent, Vals&: expression.constant));
1083 return cppIdent;
1084 }
1085 if (expression.index) {
1086 // Access an iteration index.
1087 std::string cppIdent = llvm::formatv(Fmt: "value{0}", Vals&: ++localCounter);
1088 stmts.push_back(Elt: llvm::formatv(Fmt: "Value {0} = helper.index({1});",
1089 Vals&: cppIdent, Vals&: *expression.index));
1090 return cppIdent;
1091 }
1092 if (expression.scalarFn) {
1093 std::string enumName =
1094 convertFunctionKindToEnumName(kind: expression.scalarFn->kind);
1095
1096 // Get the function or attribute name.
1097 assert(expression.scalarFn->fnName || expression.scalarFn->attrName);
1098 std::string funcType;
1099 if (expression.scalarFn->fnName) {
1100 funcType = llvm::formatv(Fmt: "{0}::{1}", Vals&: enumName,
1101 Vals&: *expression.scalarFn->fnName);
1102 }
1103 if (expression.scalarFn->attrName) {
1104 if (llvm::none_of(Range&: args, P: [&](LinalgOperandDef &arg) {
1105 return isFunctionAttribute(kind: arg.kind) &&
1106 arg.name == *expression.scalarFn->attrName;
1107 })) {
1108 emitError(loc: genContext.getLoc()) << "missing function attribute "
1109 << *expression.scalarFn->attrName;
1110 }
1111 funcType = llvm::formatv(Fmt: "{0}Val", Vals&: *expression.scalarFn->attrName);
1112 }
1113 assert(!funcType.empty());
1114
1115 // Add the optional type parameter to the operands.
1116 SmallVector<std::string> operandCppValues;
1117 if (expression.scalarFn->kind == ScalarFnKind::Type) {
1118 assert(expression.scalarFn->typeVar.has_value());
1119 std::optional<std::string> typeCppValue =
1120 findTypeValue(typeVar: *expression.scalarFn->typeVar, args);
1121 if (!typeCppValue) {
1122 emitError(loc: genContext.getLoc())
1123 << "type variable " << *expression.scalarFn->typeVar
1124 << ", used in a type conversion, must map to a predefined or "
1125 << "an argument type but it does not";
1126 return std::nullopt;
1127 }
1128 operandCppValues.push_back(Elt: *typeCppValue);
1129 }
1130
1131 // Collect the scalar operands.
1132 for (ScalarExpression &operand : expression.scalarFn->operands) {
1133 auto operandCppValue = generateExpression(operand);
1134 if (!operandCppValue)
1135 return std::nullopt;
1136 operandCppValues.push_back(Elt: *operandCppValue);
1137 }
1138
1139 // Call the function builder.
1140 std::string cppIdent = llvm::formatv(Fmt: "value{0}", Vals&: ++localCounter);
1141 stmts.push_back(Elt: llvm::formatv(
1142 Fmt: R"mlir(
1143 Value {0} = helper.build{1}({2}, {3}, emitError);
1144 if (!{0})
1145 return;
1146 )mlir",
1147 Vals&: cppIdent, Vals&: enumName, Vals&: funcType,
1148 Vals: interleaveToString(container&: operandCppValues, separator: ", ")));
1149 return cppIdent;
1150 }
1151 emitError(loc: genContext.getLoc()) << "unknown ScalarExpression type";
1152 return std::nullopt;
1153 };
1154 std::optional<std::string> cppValue =
1155 generateExpression(assignment->value);
1156 if (!cppValue)
1157 return failure();
1158 stmts.push_back(Elt: llvm::formatv(Fmt: "yields.push_back({0});", Vals&: *cppValue));
1159 }
1160
1161 if (generatedAssignmentCount != assignments.size())
1162 return emitError(loc: genContext.getLoc())
1163 << "mismatched number of assignments vs output arguments";
1164
1165 os << llvm::formatv(Fmt: structuredOpRegionBuilderFormat, Vals&: className, Vals&: numOfArgs,
1166 Vals: interleaveToString(container&: attrs, separator: "\n "),
1167 Vals: interleaveToString(container&: stmts, separator: "\n "));
1168 }
1169
1170 // Parser and printer.
1171 os << llvm::formatv(Fmt: structuredOpParserFormat, Vals&: className);
1172
1173 // Canonicalizers and folders.
1174 os << llvm::formatv(Fmt: structuredOpFoldersFormat, Vals&: className);
1175
1176 return success();
1177}
1178
1179static LogicalResult generateOp(LinalgOpConfig &opConfig,
1180 GenerationContext &genContext) {
1181 // Switch on op type being generated.
1182 if (opConfig.structuredOp) {
1183 return success(
1184 IsSuccess: succeeded(Result: generateNamedGenericOpOds(opConfig, genContext)) &&
1185 succeeded(Result: generateNamedGenericOpDefns(opConfig, genContext)));
1186 }
1187 return emitError(loc: genContext.getLoc()) << "unsupported operation type";
1188}
1189
1190//===----------------------------------------------------------------------===//
1191// Command line options and main
1192//===----------------------------------------------------------------------===//
1193
1194static llvm::cl::opt<std::string>
1195 inputFilename(llvm::cl::Positional, llvm::cl::desc("<input file>"),
1196 llvm::cl::init(Val: "-"), llvm::cl::value_desc("YAML filename"));
1197
1198static llvm::cl::opt<std::string>
1199 outputOdsDeclFilename("o-ods-decl", llvm::cl::desc("ODS output filename"),
1200 llvm::cl::value_desc("filename"), llvm::cl::init(Val: ""));
1201
1202static llvm::cl::opt<std::string>
1203 outputCppImplFilename("o-impl",
1204 llvm::cl::desc("C++ implementation file name"),
1205 llvm::cl::value_desc("filename"), llvm::cl::init(Val: ""));
1206
1207int main(int argc, char **argv) {
1208 llvm::cl::ParseCommandLineOptions(argc, argv, Overview: "Linalg ODS Gen from YAML");
1209
1210 // Set up the input file.
1211 std::string errorMessage;
1212 std::unique_ptr<llvm::MemoryBuffer> file =
1213 mlir::openInputFile(inputFilename, errorMessage: &errorMessage);
1214 if (!file) {
1215 llvm::errs() << errorMessage << "\n";
1216 return 1;
1217 }
1218
1219 MLIRContext mlirContext;
1220 LinalgYAMLContext yamlContext{.mlirContext: &mlirContext};
1221
1222 std::vector<LinalgOpConfig> opConfigs;
1223
1224 // Parse input.
1225 Input yin(file->getBuffer(), &yamlContext);
1226 yin >> opConfigs;
1227
1228 if (yin.error())
1229 return 1;
1230
1231 // Open output files.
1232 std::unique_ptr<llvm::ToolOutputFile> outputOdsDecl;
1233 if (!outputOdsDeclFilename.empty()) {
1234 outputOdsDecl = openOutputFile(outputFilename: outputOdsDeclFilename, errorMessage: &errorMessage);
1235 if (!outputOdsDecl) {
1236 llvm::errs() << errorMessage << "\n";
1237 return 1;
1238 }
1239 }
1240
1241 std::unique_ptr<llvm::ToolOutputFile> outputCppImpl;
1242 if (!outputCppImplFilename.empty()) {
1243 outputCppImpl = openOutputFile(outputFilename: outputCppImplFilename, errorMessage: &errorMessage);
1244 if (!outputCppImpl) {
1245 llvm::errs() << errorMessage << "\n";
1246 return 1;
1247 }
1248 }
1249
1250 if (!outputOdsDecl && !outputCppImpl) {
1251 llvm::errs() << "error: No output files specified\n";
1252 return 1;
1253 }
1254
1255 // Generate.
1256 GenerationContext genContext(&mlirContext,
1257 outputOdsDecl ? &outputOdsDecl->os() : nullptr,
1258 outputCppImpl ? &outputCppImpl->os() : nullptr);
1259
1260 for (auto &opConfig : opConfigs) {
1261 if (!opConfig.metadata) {
1262 emitError(loc: genContext.getLoc())
1263 << "missing operation metadata on subsequent op";
1264 return 1;
1265 }
1266
1267 genContext.setLoc(NameLoc::get(
1268 name: StringAttr::get(context: &mlirContext, bytes: opConfig.metadata->cppClassName)));
1269 if (failed(Result: generateOp(opConfig, genContext))) {
1270 return 1;
1271 }
1272 }
1273
1274 if (outputOdsDecl)
1275 outputOdsDecl->keep();
1276 if (outputCppImpl)
1277 outputCppImpl->keep();
1278
1279 return 0;
1280}
1281

// Source code of mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp