1//===- mlir-linalg-ods-yaml-gen.cpp - Linalg ODS generation from yaml ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements an ODS (and C++) generator from a YAML form
// derived from the mathematical expression of linalg named ops. Typically a
// math oriented DSL will be used to export the essential representation to
// this form, and maintaining the source of truth (SOT) at the math level
// (versus recreating it in MLIR) is deemed to have systemic value.
14//
15//===----------------------------------------------------------------------===//
16
17#include "mlir/AsmParser/AsmParser.h"
18#include "mlir/IR/AffineMap.h"
19#include "mlir/IR/Diagnostics.h"
20#include "mlir/IR/MLIRContext.h"
21#include "mlir/Support/FileUtilities.h"
22#include "mlir/Support/LLVM.h"
23#include "llvm/ADT/StringRef.h"
24#include "llvm/Support/CommandLine.h"
25#include "llvm/Support/Debug.h"
26#include "llvm/Support/FormatVariadic.h"
27#include "llvm/Support/ToolOutputFile.h"
28#include "llvm/Support/YAMLTraits.h"
29#include <optional>
30
31using namespace mlir;
32
33using llvm::yaml::Input;
34using llvm::yaml::MappingTraits;
35using llvm::yaml::ScalarEnumerationTraits;
36using llvm::yaml::ScalarTraits;
37
38#define DEBUG_TYPE "linalg-ods-gen"
39
40//===----------------------------------------------------------------------===//
41// Mapping structs (correspond to data types in the YAML description).
42// TODO: Since this is a schema/part of the contract, it should be moved to
43// a real header.
44//===----------------------------------------------------------------------===//
45
46namespace {
47
48struct LinalgYAMLContext {
49 MLIRContext *mlirContext;
50};
51
52struct LinalgOpMetadata {
53 std::string name;
54 std::string cppClassName;
55 std::optional<std::string> doc;
56 SmallVector<std::string> implements;
57 SmallVector<std::string> defines;
58};
59
60struct SerializedAffineMap {
61 AffineMapAttr affineMapAttr;
62
63 AffineMap affineMap() { return affineMapAttr.getValue(); }
64};
65
66enum class LinalgOperandDefKind {
67 InputTensor,
68 Scalar,
69 OutputTensor,
70 IndexAttr,
71 UnaryFnAttr,
72 BinaryFnAttr,
73 TypeFnAttr
74};
75
76struct LinalgOperandDef {
77 std::string name;
78 LinalgOperandDefKind kind;
79 std::optional<std::string> typeVar;
80 std::optional<SerializedAffineMap> shapeMap;
81 std::optional<SerializedAffineMap> indexAttrMap;
82 std::optional<SmallVector<int64_t>> defaultIndices;
83 std::optional<std::string> defaultFn;
84};
85
86enum class LinalgIteratorTypeDef {
87 parallel,
88 reduction,
89};
90
91struct LinalgIndexingMapsConfig {
92 std::optional<SmallVector<SerializedAffineMap>> staticIndexingMaps;
93};
94
95struct ScalarExpression;
96
97enum class ScalarFnKind { Unary, Binary, Type };
98
99struct ScalarFn {
100 ScalarFnKind kind;
101 std::optional<std::string> fnName;
102 std::optional<std::string> attrName;
103 std::optional<std::string> typeVar;
104 // NOTE: This must be of arity 1, but to break the self-referential cycle,
105 // we use a heap allocated vector.
106 std::vector<ScalarExpression> operands;
107};
108
109struct ScalarExpression {
110 std::optional<std::string> arg;
111 std::optional<std::string> constant;
112 std::optional<int64_t> index;
113 std::optional<ScalarFn> scalarFn;
114};
115
116struct ScalarAssign {
117 std::string arg;
118 ScalarExpression value;
119};
120
121struct LinalgStructuredOpConfig {
122 SmallVector<LinalgOperandDef> args;
123 LinalgIndexingMapsConfig indexingMaps;
124 SmallVector<LinalgIteratorTypeDef> iteratorTypes;
125 std::vector<ScalarAssign> assignments;
126};
127
128struct LinalgOpConfig {
129 std::optional<LinalgOpMetadata> metadata;
130 std::optional<LinalgStructuredOpConfig> structuredOp;
131};
132
133} // namespace
134
135//===----------------------------------------------------------------------===//
136// Mapping traits.
137//===----------------------------------------------------------------------===//
138
139LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgOperandDef)
140LLVM_YAML_IS_SEQUENCE_VECTOR(SerializedAffineMap)
141LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgIteratorTypeDef)
142LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarAssign)
143LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarExpression)
144LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(LinalgOpConfig)
145
146namespace llvm {
147namespace yaml {
148
149/// Top-level type containing op metadata and one of a concrete op type.
150/// Currently, the only defined op type is `structured_op` (maps to
151/// `LinalgStructuredOpConfig`).
152template <>
153struct MappingTraits<LinalgOpConfig> {
154 static void mapping(IO &io, LinalgOpConfig &info) {
155 io.mapOptional(Key: "metadata", Val&: info.metadata);
156 io.mapOptional(Key: "structured_op", Val&: info.structuredOp);
157 }
158};
159
160/// A structured op models (at most) a single contraction by modeling
161/// - A list of named arguments (`LinalgOperandDef`), which can be inputs,
162/// outputs, or index attributes.
163/// - List of indexing maps (see `LinalgIndexingMaps`).
164/// - Iterator types (see `LinalgIteratorTypeDef`).
165/// - List of scalar level assignment (see `ScalarAssign`).
166template <>
167struct MappingTraits<LinalgStructuredOpConfig> {
168 static void mapping(IO &io, LinalgStructuredOpConfig &info) {
169 io.mapRequired(Key: "args", Val&: info.args);
170 io.mapRequired(Key: "indexing_maps", Val&: info.indexingMaps);
171 io.mapRequired(Key: "iterator_types", Val&: info.iteratorTypes);
172 io.mapRequired(Key: "assignments", Val&: info.assignments);
173 }
174};
175
176/// Maps a named tensor, scalar or attribute argument to an operation,
177/// consisting of:
178/// - `name`: Must be unique within the operation.
179/// - `usage`: How the argument is used (input, output, attribute, etc).
180/// - `type_var`: The symbolic type variable that binds to the element or self
181/// type of the tensor or scalar argument, respectively.
182/// - `shape_map`: An optional AffineMap from all op symbols to the shape of
183/// the argument. Only tensor arguments have a `shape_map`. Each shape must
184/// be normalized over the same list of symbols and have no dimension
185/// inputs.
186/// - `index_attr_map`: An optional AffineMap from all op symbols to the
187/// index attribute symbols. During op creation these symbols are replaced
188/// by the corresponding `name` index attribue values. Only index attribute
189/// arguments have an `index_attr_map`.
190/// - `default_indices`: An optional default initialization for index
191/// attribute arguments.
192/// - `default_fn`: An optional default initialization for function attribute
193/// arguments.
194template <>
195struct MappingTraits<LinalgOperandDef> {
196 static void mapping(IO &io, LinalgOperandDef &info) {
197 io.mapRequired(Key: "name", Val&: info.name);
198 io.mapRequired(Key: "kind", Val&: info.kind);
199 io.mapOptional(Key: "type_var", Val&: info.typeVar);
200 io.mapOptional(Key: "shape_map", Val&: info.shapeMap);
201 io.mapOptional(Key: "index_attr_map", Val&: info.indexAttrMap);
202 io.mapOptional(Key: "default_indices", Val&: info.defaultIndices);
203 io.mapOptional(Key: "default_fn", Val&: info.defaultFn);
204 }
205};
206
207/// Usage enum for a named argument.
208template <>
209struct ScalarEnumerationTraits<LinalgOperandDefKind> {
210 static void enumeration(IO &io, LinalgOperandDefKind &value) {
211 io.enumCase(Val&: value, Str: "input_tensor", ConstVal: LinalgOperandDefKind::InputTensor);
212 io.enumCase(Val&: value, Str: "scalar", ConstVal: LinalgOperandDefKind::Scalar);
213 io.enumCase(Val&: value, Str: "output_tensor", ConstVal: LinalgOperandDefKind::OutputTensor);
214 io.enumCase(Val&: value, Str: "index_attr", ConstVal: LinalgOperandDefKind::IndexAttr);
215 io.enumCase(Val&: value, Str: "unary_fn_attr", ConstVal: LinalgOperandDefKind::UnaryFnAttr);
216 io.enumCase(Val&: value, Str: "binary_fn_attr", ConstVal: LinalgOperandDefKind::BinaryFnAttr);
217 io.enumCase(Val&: value, Str: "type_fn_attr", ConstVal: LinalgOperandDefKind::TypeFnAttr);
218 }
219};
220
221/// Iterator type enum.
222template <>
223struct ScalarEnumerationTraits<LinalgIteratorTypeDef> {
224 static void enumeration(IO &io, LinalgIteratorTypeDef &value) {
225 io.enumCase(Val&: value, Str: "parallel", ConstVal: LinalgIteratorTypeDef::parallel);
226 io.enumCase(Val&: value, Str: "reduction", ConstVal: LinalgIteratorTypeDef::reduction);
227 }
228};
229
230/// Metadata about the op (name, C++ name, and documentation).
231template <>
232struct MappingTraits<LinalgOpMetadata> {
233 static void mapping(IO &io, LinalgOpMetadata &info) {
234 io.mapRequired(Key: "name", Val&: info.name);
235 io.mapRequired(Key: "cpp_class_name", Val&: info.cppClassName);
236 io.mapOptional(Key: "doc", Val&: info.doc);
237 io.mapOptional(Key: "implements", Val&: info.implements);
238 io.mapOptional(Key: "defines", Val&: info.defines);
239 }
240};
241
242/// How the ops indexing maps are produced. Must be one of:
243/// - static_indexing_maps: A static list of AffineMaps, possibly with
244/// some symbols that bind to attributes of the op. Each indexing map must
245/// be normalized over the same list of dimensions, and its symbols must
246/// match the symbols for argument shapes.
247template <>
248struct MappingTraits<LinalgIndexingMapsConfig> {
249 static void mapping(IO &io, LinalgIndexingMapsConfig &info) {
250 io.mapOptional(Key: "static_indexing_maps", Val&: info.staticIndexingMaps);
251 }
252};
253
254/// Models an assignment to a named output.
255/// - The `arg` name must match a named output.
256/// - The `value` is a scalar expression for computing the value to
257/// assign (see `ScalarExpression`).
258template <>
259struct MappingTraits<ScalarAssign> {
260 static void mapping(IO &io, ScalarAssign &info) {
261 io.mapRequired(Key: "arg", Val&: info.arg);
262 io.mapRequired(Key: "value", Val&: info.value);
263 }
264};
265
266/// A scalar expression (RHS of an assignment). Must be one of:
267/// - `scalar_arg`: An operation argument.
268/// - `scalar_const`: A constant definition.
269/// - `scalar_index`: An iteration index.
270/// - `scalar_fn`: A named function (see `ScalarFn`).
271template <>
272struct MappingTraits<ScalarExpression> {
273 static void mapping(IO &io, ScalarExpression &info) {
274 io.mapOptional(Key: "scalar_arg", Val&: info.arg);
275 io.mapOptional(Key: "scalar_const", Val&: info.constant);
276 io.mapOptional(Key: "scalar_index", Val&: info.index);
277 io.mapOptional(Key: "scalar_fn", Val&: info.scalarFn);
278 }
279};
280
281/// Scalar function kind enum.
282template <>
283struct ScalarEnumerationTraits<ScalarFnKind> {
284 static void enumeration(IO &io, ScalarFnKind &value) {
285 io.enumCase(Val&: value, Str: "unary", ConstVal: ScalarFnKind::Unary);
286 io.enumCase(Val&: value, Str: "binary", ConstVal: ScalarFnKind::Binary);
287 io.enumCase(Val&: value, Str: "type", ConstVal: ScalarFnKind::Type);
288 }
289};
290
291/// A scalar expression that evaluates a named function.
292/// Functions are generally "math" level and type polymorphic. Builtin
293/// functions include:
294/// - `add(lhs, rhs)`
295/// - `mul(lhs, rhs)`
296template <>
297struct MappingTraits<ScalarFn> {
298 static void mapping(IO &io, ScalarFn &info) {
299 io.mapRequired(Key: "kind", Val&: info.kind);
300 io.mapOptional(Key: "fn_name", Val&: info.fnName);
301 io.mapOptional(Key: "attr_name", Val&: info.attrName);
302 io.mapOptional(Key: "type_var", Val&: info.typeVar);
303 io.mapRequired(Key: "operands", Val&: info.operands);
304 }
305};
306
307/// Helper mapping which accesses an AffineMapAttr as a serialized string of
308/// the same.
309template <>
310struct ScalarTraits<SerializedAffineMap> {
311 static void output(const SerializedAffineMap &value, void *rawYamlContext,
312 raw_ostream &out) {
313 assert(value.affineMapAttr);
314 value.affineMapAttr.print(os&: out);
315 }
316 static StringRef input(StringRef scalar, void *rawYamlContext,
317 SerializedAffineMap &value) {
318 assert(rawYamlContext);
319 auto *yamlContext = static_cast<LinalgYAMLContext *>(rawYamlContext);
320 if (auto attr = dyn_cast_or_null<AffineMapAttr>(
321 Val: mlir::parseAttribute(attrStr: scalar, context: yamlContext->mlirContext)))
322 value.affineMapAttr = attr;
323 else if (!value.affineMapAttr || !isa<AffineMapAttr>(Val: value.affineMapAttr))
324 return "could not parse as an affine map attribute";
325 return StringRef();
326 }
327 static QuotingType mustQuote(StringRef) { return QuotingType::None; }
328};
329
330} // namespace yaml
331} // namespace llvm
332
333namespace {
334
335//===----------------------------------------------------------------------===//
336// Generation utilities
337//===----------------------------------------------------------------------===//
338
339class GenerationContext {
340public:
341 GenerationContext(MLIRContext *context, raw_ostream *odsOut,
342 raw_ostream *defnOut)
343 : context(context), loc(UnknownLoc::get(context)), odsOut(odsOut),
344 defnOut(defnOut) {}
345
346 MLIRContext *getContext() { return context; }
347
348 void setLoc(Location loc) { this->loc = loc; }
349 Location getLoc() { return loc; }
350
351 bool shouldGenerateOds() { return odsOut; }
352 bool shouldGenerateDefns() { return defnOut; }
353
354 raw_ostream &odss() {
355 assert(odsOut && "ODS stream not defined");
356 return *odsOut;
357 }
358
359 raw_ostream &defns() {
360 assert(defnOut && "Definition stream not defined");
361 return *defnOut;
362 }
363
364private:
365 MLIRContext *context;
366 Location loc;
367 raw_ostream *odsOut;
368 raw_ostream *defnOut;
369};
370
371} // namespace
372
373static std::string generateCppExpression(SerializedAffineMap self,
374 StringRef contextName) {
375 std::string printedStr;
376 llvm::raw_string_ostream printedSs(printedStr);
377 self.affineMapAttr.print(os&: printedSs);
378 printedSs.flush();
379
380 static const char exprFormat[] =
381 R"FMT(llvm::cast<AffineMapAttr>(mlir::parseAttribute("{0}", {1})).getValue())FMT";
382 return llvm::formatv(Fmt: exprFormat, Vals&: printedStr, Vals&: contextName);
383}
384
385template <typename Container>
386static std::string interleaveToString(Container &container,
387 StringRef separator) {
388 std::string result;
389 llvm::raw_string_ostream ss(result);
390 llvm::interleave(container, ss, separator);
391 ss.flush();
392 return result;
393}
394
395static std::optional<int>
396findTensorDefArgIndex(StringRef name, SmallVectorImpl<LinalgOperandDef> &args) {
397 for (const auto &it : llvm::enumerate(First&: args)) {
398 if (it.value().name == name)
399 return it.index();
400 }
401 return std::nullopt;
402}
403
404// Try to map the TypeVar to a predefined or an argument type.
405static std::optional<std::string>
406findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgOperandDef> &args) {
407 // Handle all predefined types.
408 if (typeVar == "I32")
409 return std::string("helper.getIntegerType(32)");
410 if (typeVar == "I64")
411 return std::string("helper.getIntegerType(64)");
412 if (typeVar == "F32")
413 return std::string("helper.getFloat32Type()");
414 if (typeVar == "F64")
415 return std::string("helper.getFloat64Type()");
416
417 // Search all argument types.
418 for (const auto &it : llvm::enumerate(First&: args)) {
419 if (it.value().kind != LinalgOperandDefKind::InputTensor &&
420 it.value().kind != LinalgOperandDefKind::Scalar &&
421 it.value().kind != LinalgOperandDefKind::OutputTensor)
422 continue;
423 if (*it.value().typeVar == typeVar)
424 return llvm::formatv(Fmt: "block.getArgument({0}).getType()", Vals: it.index())
425 .str();
426 }
427
428 return std::nullopt;
429}
430
431static ScalarAssign *findAssignment(StringRef name,
432 std::vector<ScalarAssign> &assignments) {
433 for (auto &assign : assignments) {
434 if (assign.arg == name)
435 return &assign;
436 }
437 return nullptr;
438}
439
440// Return true if the operand is a function attribute.
441static bool isFunctionAttribute(LinalgOperandDefKind kind) {
442 return kind == LinalgOperandDefKind::UnaryFnAttr ||
443 kind == LinalgOperandDefKind::BinaryFnAttr ||
444 kind == LinalgOperandDefKind::TypeFnAttr;
445}
446
447// Return true if the operand is an attribute.
448static bool isAttribute(LinalgOperandDefKind kind) {
449 return kind == LinalgOperandDefKind::IndexAttr || isFunctionAttribute(kind);
450}
451
452// Get the enum name for the given operand kind.
453std::string convertOperandKindToEnumName(LinalgOperandDefKind kind) {
454 switch (kind) {
455 case LinalgOperandDefKind::UnaryFnAttr:
456 return std::string("UnaryFn");
457 case LinalgOperandDefKind::BinaryFnAttr:
458 return std::string("BinaryFn");
459 case LinalgOperandDefKind::TypeFnAttr:
460 return std::string("TypeFn");
461 default:
462 break;
463 }
464 llvm_unreachable("unsupported function attribute kind");
465}
466
467// Get the enum name for the given function kind.
468std::string convertFunctionKindToEnumName(ScalarFnKind kind) {
469 switch (kind) {
470 case ScalarFnKind::Unary:
471 return std::string("UnaryFn");
472 case ScalarFnKind::Binary:
473 return std::string("BinaryFn");
474 case ScalarFnKind::Type:
475 return std::string("TypeFn");
476 }
477 llvm_unreachable("unsupported function kind");
478}
479
480//===----------------------------------------------------------------------===//
481// Templates
482//===----------------------------------------------------------------------===//
483
// A single line banner format (a formatv template emitting a framed comment
// into generated C++). Parameters:
// {0}: Single line comment
static const char bannerFormat[] = R"FMT(
//===----------------------------------------------------------------------===//
// {0}
//===----------------------------------------------------------------------===//
)FMT";
491
492//===----------------------------------------------------------------------===//
493// Named generic op generation.
494// These ops map at most a single contraction that complies with the limitations
495// of a linalg.generic.
496//===----------------------------------------------------------------------===//
497
// Template for Linalg named ops' ODS definitions (a formatv format; `{{`
// produces a literal `{`). Parameters:
// {0}: ODS/C++ op name
// {1}: assembly op mnemonic
// {2}: op interface list
// {3}: documentation (summary + description)
// {4}: op attribute list
// {5}: builder methods taking standalone attribute parameters
// {6}: additional `let <name> = 1;` definitions (from the metadata `defines`)
// {7}: additional methods for attributes used by indexing maps
static const char structuredOpOdsHeaderFormat[] = R"FMT(
//===----------------------------------------------------------------------===//
// Op definition for {0}
//===----------------------------------------------------------------------===//

def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
 /*extraInterfaces=*/[{2}])> {
 {3}
 let arguments = (ins
 Variadic<AnyType>:$inputs,
 Variadic<AnyShaped>:$outputs{4}
 );
 let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
 let regions = (region AnyRegion:$region);

 let skipDefaultBuilders = 1;
 let builders = [
 OpBuilder<
 (ins "ValueRange":$inputs, "ValueRange":$outputs,
 CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
 [{{
 buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
 attributes, {0}::getRegionBuilder());
 }]>,
 OpBuilder<
 (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
 "ValueRange":$outputs,
 CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
 [{{
 buildStructuredOp($_builder, $_state, resultTensorTypes,
 inputs, outputs, attributes, {0}::getRegionBuilder());
 }]>,
 OpBuilder<
 (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
 CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
 [{{
 $_state.addOperands(operands);
 $_state.addAttributes(attributes);
 $_state.addTypes(resultTensorTypes);
 (void)$_state.addRegion();
 }]>
 {5}
 ];
 let hasCustomAssemblyFormat = 1;
 let hasFolder = 1;
 {6}

 let extraClassDeclaration = structuredOpsBaseDecls # [{{
 // Auto-generated.
 SmallVector<utils::IteratorType> getIteratorTypesArray();
 ArrayAttr getIndexingMaps();
 static void regionBuilder(ImplicitLocOpBuilder &b,
 Block &block, ArrayRef<NamedAttribute> attrs);
 static std::function<void(ImplicitLocOpBuilder &,
 Block &, ArrayRef<NamedAttribute>)>
 getRegionBuilder() {{
 return regionBuilder;
 }

 ::mlir::MutableOperandRange getDpsInitsMutable() {{
 return getOutputsMutable();
 }

 // Generic methods.
 static unsigned getNumRegionArgs();
 std::string getLibraryCallName();
 {7}
 }];
}
)FMT";
577
// Builder method taking attribute parameters. It is spliced into the builder
// list of `structuredOpOdsHeaderFormat`, hence the leading comma. Parameters:
// {0}: Class name
// {1}: Comma interleaved attribute parameters
// {2}: Attribute initialization
static const char structuredOpBuilderFormat[] = R"FMT(
 , OpBuilder<
 (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
 "ValueRange":$outputs, {1},
 CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
 [{{
 {2}
 buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
 attributes, {0}::getRegionBuilder());
 }]>
)FMT";
593
// The getIteratorTypesArray() method for structured ops; returns a fixed list
// of iterator types. Parameters:
// {0}: Class name
// {1}: Comma interleaved iterator type names.
static const char structuredOpIteratorTypesFormat[] =
    R"FMT(
SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
 return SmallVector<utils::IteratorType>{{ {1} };
}
)FMT";
603
// The getIteratorTypesArray() method for rank polymorphic structured ops: one
// parallel iterator per dimension of the first DPS init operand.
// Parameters:
// {0}: Class name
static const char rankPolyStructuredOpIteratorTypesFormat[] =
    R"FMT(
SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
 int64_t rank = getRank(getDpsInitOperand(0));
 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}
)FMT";
614
// The getIndexingMaps() method for structured ops. The generated method caches
// its result in the "linalg.memoized_indexing_maps" attribute of the op.
// Parameters:
// {0}: Class name
// {1}: Comma-separated list of dimension variable names (not referenced by
//      this template).
// {2}: Statements that populate `maps`
static const char structuredOpIndexingMapsFormat[] = R"FMT(
ArrayAttr {0}::getIndexingMaps() {{
 static const char memoizeAttr[] = "linalg.memoized_indexing_maps";
 ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
 if (cached)
 return cached;

 MLIRContext *context = getContext();
 auto symbolBindings = getSymbolBindings(*this);
 SmallVector<AffineMap> maps;
 {2}
 cached = Builder(context).getAffineMapArrayAttr(maps);
 getOperation()->setAttr(memoizeAttr, cached);
 return cached;
}
)FMT";
635
// The getIndexingMaps() method for rank polymorphic structured ops. Rank-0
// operands get a map with no results; all other operands get the identity
// map over the parallel loops. Parameters:
// {0}: Class name
static const char rankPolyStructuredOpIndexingMapsFormat[] = R"FMT(
ArrayAttr {0}::getIndexingMaps() {{
 MLIRContext *context = getContext();
 AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context);
 AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(
 getNumParallelLoops(), context);
 SmallVector<AffineMap> indexingMaps;
 for (OpOperand &opOperand : getOperation()->getOpOperands())
 indexingMaps.push_back(getRank(&opOperand) == 0 ? scalarMap : tensorMap);
 return Builder(getContext()).getAffineMapArrayAttr(indexingMaps);
}
)FMT";
650
// Implementations of fold and getEffects. fold() delegates to
// memref::foldMemRefCast. getEffects() reports nothing for ops with pure
// tensor semantics. Unlike its siblings this array is not declared `static`,
// but a namespace-scope `const` variable already has internal linkage in C++.
// Parameters:
// {0}: Class name
const char structuredOpFoldersFormat[] = R"FMT(
LogicalResult {0}::fold(FoldAdaptor,
 SmallVectorImpl<OpFoldResult> &) {{
 return memref::foldMemRefCast(*this);
}
void {0}::getEffects(SmallVectorImpl<
 SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
 if (hasPureTensorSemantics()) return;
 getGenericEffectsImpl(effects,
 getOperation()->getResults(), getDpsInputs(), getDpsInits());
}
)FMT";
666
// Implementation of parse/print, delegating to the shared named-structured-op
// helpers ::parseNamedStructuredOp and ::printNamedStructuredOp.
// Parameters:
// {0}: Class name
static const char structuredOpParserFormat[] = R"FMT(
ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{
 return ::parseNamedStructuredOp(parser, result,
 {0}::getNumRegionArgs(), {0}::getRegionBuilder());
}
void {0}::print(OpAsmPrinter &p) {{
 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs());
}
)FMT";
679
/// Emits the ODS (TableGen) definition of the named structured op described
/// by `opConfig` into the ODS stream of `genContext`. Returns success without
/// emitting anything when ODS output was not requested.
///
/// NOTE(review): `opConfig.metadata` and `opConfig.structuredOp` are
/// dereferenced without checks; presumably the caller guarantees both are
/// present — confirm.
static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
                                               GenerationContext &genContext) {
  if (!genContext.shouldGenerateOds())
    return success();

  raw_ostream &os = genContext.odss();

  // Pieces substituted into `structuredOpOdsHeaderFormat`.
  std::string interfaceNameList;
  std::string attrList;
  std::string attrMethods;
  std::string attrBuilder;

  // The doc string is split at the first blank line: the text before it
  // becomes the ODS summary, the remainder the description.
  std::string doc;
  if (opConfig.metadata->doc) {
    static const char structuredOpDocFmt[] = R"FMT(
 let summary = [{{{0}}];
 let description = [{{{1}}];
)FMT";
    StringRef summary, description;
    std::tie(summary, description) =
        StringRef(*opConfig.metadata->doc).trim().split("\n\n");

    doc = llvm::formatv(structuredOpDocFmt, summary.trim(), description.trim());
  }

  interfaceNameList = interleaveToString(opConfig.metadata->implements, ", ");

  // Each `defines` entry becomes a `let <name> = 1;` line in the ODS body.
  std::string definitionList;
  for (const std::string &definition : opConfig.metadata->defines) {
    static const char definitionFmt[] = "let {0} = 1;\n";
    definitionList.append(llvm::formatv(definitionFmt, definition));
  }

  // Attribute arguments contribute ODS attribute definitions and an extra
  // builder that takes them as parameters. Index attributes also add method
  // declarations.
  if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
        return isAttribute(arg.kind);
      })) {
    SmallVector<std::string> attrDefs;
    SmallVector<std::string> attrParams;
    SmallVector<std::string> attrStmts;
    for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
      static const char paramFmt[] = "\"Attribute\":${0}";
      static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});";
      // Add the type conversion attributes to the op definition and builders.
      if (isFunctionAttribute(arg.kind)) {
        assert(arg.defaultFn);
        std::string enumName = convertOperandKindToEnumName(arg.kind);
        static const char typeFmt[] = "{0}::{1}";
        static const char defFmt[] =
            "DefaultValuedOptionalAttr<{0}, \"{1}\">:${2}";
        attrDefs.push_back(llvm::formatv(
            defFmt, llvm::formatv("{0}Attr", enumName),
            llvm::formatv(typeFmt, enumName, arg.defaultFn), arg.name));
        attrParams.push_back(llvm::formatv(paramFmt, arg.name));
        attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
      }
      // Add the index attributes to the op definition and builders.
      if (arg.kind == LinalgOperandDefKind::IndexAttr) {
        assert(arg.indexAttrMap.has_value());
        assert(arg.defaultIndices.has_value());
        size_t size = arg.indexAttrMap->affineMap().getNumResults();
        assert(arg.defaultIndices->size() == size);
        static const char typeFmt[] = "RankedI64ElementsAttr<[{0}]>";
        // The outer braces around {1} are meant to be emitted literally,
        // yielding a default value of the form "{ v0, v1, ... }".
        static const char defFmt[] =
            "DefaultValuedOptionalAttr<{0}, \"{ {1} }\">:${2}";
        // Render each default index as `static_cast<int64_t>(v)`.
        std::string defaultVals;
        llvm::raw_string_ostream ss(defaultVals);
        llvm::interleave(
            *arg.defaultIndices, ss,
            [&](int64_t val) { ss << "static_cast<int64_t>(" << val << ")"; },
            ", ");
        attrDefs.push_back(llvm::formatv(defFmt, llvm::formatv(typeFmt, size),
                                         ss.str(), arg.name));
        attrParams.push_back(llvm::formatv(paramFmt, arg.name));
        attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
      }
    }
    if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
          return arg.kind == LinalgOperandDefKind::IndexAttr;
        })) {
      attrMethods = R"(
 bool hasDynamicIndexingMaps();
 LogicalResult verifyIndexingMapRequiredAttributes();
 )";
    }
    attrList = ",\n" + llvm::join(attrDefs, ",\n");
    attrBuilder = llvm::formatv(
        structuredOpBuilderFormat, opConfig.metadata->cppClassName,
        llvm::join(attrParams, ", "), llvm::join(attrStmts, "\n"));
  }

  // Argument order matches parameters {0}..{7} of the header template.
  os << llvm::formatv(structuredOpOdsHeaderFormat,
                      opConfig.metadata->cppClassName, opConfig.metadata->name,
                      interfaceNameList, doc, attrList, attrBuilder,
                      definitionList, attrMethods);

  return success();
}
777
778static LogicalResult
779generateNamedGenericOpDefns(LinalgOpConfig &opConfig,
780 GenerationContext &genContext) {
781 if (!genContext.shouldGenerateDefns())
782 return success();
783
784 raw_ostream &os = genContext.defns();
785 StringRef className = opConfig.metadata->cppClassName;
786
787 // Implementation banner.
788 std::string bannerComment = llvm::formatv(Fmt: "Implementation of {0}", Vals&: className);
789 os << llvm::formatv(Fmt: bannerFormat, Vals&: bannerComment);
790
791 // Compute the number of scalar and tensor arguments.
792 int64_t numOfArgs =
793 llvm::count_if(Range&: opConfig.structuredOp->args, P: [](LinalgOperandDef &arg) {
794 return arg.kind == LinalgOperandDefKind::InputTensor ||
795 arg.kind == LinalgOperandDefKind::Scalar ||
796 arg.kind == LinalgOperandDefKind::OutputTensor;
797 });
798
799 // An operation that accesses only scalars and scalar/rank zero tensors is
800 // rank polymorhpic. We implement rank polymorphism by generating different
801 // indexing maps and iterators that match the rank of the first output tensor.
802 // An operation is rank polymorphic if the iteration domain has rank zero.
803 bool isRankPolymorphic = opConfig.structuredOp->iteratorTypes.empty();
804
805 // Generate the iterator_types() method.
806 if (!isRankPolymorphic) {
807 std::string iteratorsStr;
808 llvm::raw_string_ostream ss(iteratorsStr);
809 llvm::interleaveComma(c: opConfig.structuredOp->iteratorTypes, os&: ss,
810 each_fn: [&](LinalgIteratorTypeDef it) {
811 switch (it) {
812 case LinalgIteratorTypeDef::parallel:
813 ss << "utils::IteratorType::parallel";
814 break;
815 case LinalgIteratorTypeDef::reduction:
816 ss << "utils::IteratorType::reduction";
817 break;
818 }
819 });
820 ss.flush();
821 os << llvm::formatv(Fmt: structuredOpIteratorTypesFormat, Vals&: className,
822 Vals&: iteratorsStr);
823 } else {
824 os << llvm::formatv(Fmt: rankPolyStructuredOpIteratorTypesFormat, Vals&: className);
825 }
826
827 // Generating the getIndexingMaps() method.
828 if (auto &staticMaps =
829 opConfig.structuredOp->indexingMaps.staticIndexingMaps) {
830 if (staticMaps->empty())
831 return emitError(loc: genContext.getLoc()) << "op has no indexing maps";
832 if (!isRankPolymorphic) {
833 AffineMap firstMap = staticMaps->front().affineMap();
834
835 // Symbol bindings.
836 {
837 // For each symbol, generate a declaration for it, either with an
838 // AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from
839 // an attribute).
840 // TODO: Possibly lift into a top-level method.
841 static const char structuredOpSymbolBindingsFormat[] = R"FMT(
842static SmallVector<AffineExpr> getSymbolBindings({0} self) {
843 MLIRContext *context = self.getContext();
844 SmallVector<AffineExpr> exprs;
845{1}
846 return exprs;
847}
848)FMT";
849
850 unsigned symbolCount = firstMap.getNumSymbols();
851 SmallVector<std::string> symbolBindings;
852 for (unsigned i = 0; i < symbolCount; ++i) {
853 symbolBindings.push_back(Elt: llvm::formatv(
854 Fmt: " exprs.push_back(getAffineSymbolExpr({0}, context));", Vals&: i));
855 }
856
857 // Access an index attribute. Parameters:
858 // {0}: Attribute name
859 // {1}: Symbol position
860 // {2}: Attribute index
861 static const char structuredOpAccessAttrFormat[] = R"FMT(
862int64_t cst{1} = self.get{0}().getValues<int64_t>()[{2}];
863exprs.push_back(getAffineConstantExpr(cst{1}, context));
864)FMT";
865 // Update all symbol bindings mapped to an attribute.
866 for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
867 if (arg.kind != LinalgOperandDefKind::IndexAttr)
868 continue;
869 assert(arg.indexAttrMap);
870 for (auto [idx, result] :
871 llvm::enumerate(First: arg.indexAttrMap->affineMap().getResults())) {
872 if (auto symbol = dyn_cast<AffineSymbolExpr>(Val: result)) {
873 std::string argName = arg.name;
874 argName[0] = toupper(c: argName[0]);
875 symbolBindings[symbol.getPosition()] =
876 llvm::formatv(Fmt: structuredOpAccessAttrFormat, Vals&: argName,
877 Vals: symbol.getPosition(), Vals&: idx);
878 }
879 }
880 }
881
882 std::string symbolBindingsStr;
883 llvm::raw_string_ostream symbolBindingsSs(symbolBindingsStr);
884 llvm::interleave(c: symbolBindings, os&: symbolBindingsSs, separator: "\n");
885 symbolBindingsSs.flush();
886
887 os << llvm::formatv(Fmt: structuredOpSymbolBindingsFormat, Vals&: className,
888 Vals&: symbolBindingsStr);
889 }
890
891 // Indexing maps.
892 {
893 unsigned dimCount = firstMap.getNumDims();
894
895 // Generate a comma-separated list of dim identifiers to be passed to
896 // bindDims, ensuring tht AffineExpr identifiers are bound in the right
897 // order to the proper AffineDimExpr.
898 // This results in vars in scope like: d0, d1, d2...
899 SmallVector<unsigned> dimIndices;
900 for (unsigned i = 0; i < dimCount; ++i)
901 dimIndices.push_back(Elt: i);
902 std::string dimIdentsStr;
903 llvm::raw_string_ostream dimIdentsSs(dimIdentsStr);
904 llvm::interleaveComma(c: dimIndices, os&: dimIdentsSs,
905 each_fn: [&](unsigned i) { dimIdentsSs << "d" << i; });
906 dimIdentsSs.flush();
907
908 // Statements to add and simplify each affine map.
909 SmallVector<std::string> stmts;
910 for (auto &indexingMap : *staticMaps) {
911 // TODO: Assert that dim and symbol count match the first.
912 stmts.push_back(
913 Elt: llvm::formatv(Fmt: "maps.push_back({0});",
914 Vals: generateCppExpression(self: indexingMap, contextName: "context")));
915 stmts.push_back(Elt: llvm::formatv(
916 Fmt: "maps.back() = "
917 "simplifyAffineMap(maps.back().replaceDimsAndSymbols({{}, "
918 "symbolBindings, {0}, 0));",
919 Vals&: dimCount));
920 }
921
922 // TODO: This needs to be memoized and/or converted to non-parser based
923 // C++ codegen prior to real use.
924 os << llvm::formatv(Fmt: structuredOpIndexingMapsFormat, Vals&: className,
925 Vals&: dimIdentsStr, Vals: interleaveToString(container&: stmts, separator: "\n "));
926 }
927 } else {
928 os << llvm::formatv(Fmt: rankPolyStructuredOpIndexingMapsFormat, Vals&: className);
929 }
930 } else {
931 return emitError(loc: genContext.getLoc())
932 << "generating code for non static indexing maps not currently "
933 "supported";
934 }
935
936 // getNumRegionArgs()
937 {
938 // Generates a getNumRegionArgs() method. Parameters:
939 // {0}: Class name
940 // {1}: Number of region args
941 static const char structuredOpGetNumRegionArgsFormat[] = R"FMT(
942unsigned {0}::getNumRegionArgs() {{ return {1}; }
943)FMT";
944 os << llvm::formatv(Fmt: structuredOpGetNumRegionArgsFormat, Vals&: className,
945 Vals&: numOfArgs);
946 }
947
948 // getLibraryCallName()
949 {
950 // Generates a getLibraryCallName method. Parameters:
951 // {0}: Class name
952 static const char structuredOpGetLibraryCallFormat[] = R"FMT(
953std::string {0}::getLibraryCallName() {{
954 return generateLibraryCallName(getOperation());
955}
956)FMT";
957 os << llvm::formatv(Fmt: structuredOpGetLibraryCallFormat, Vals&: className);
958 }
959
960 // hasDynamicIndexingMaps() and verifyIndexingMapRequiredAttributes()
961 if (llvm::any_of(Range&: opConfig.structuredOp->args, P: [](LinalgOperandDef &arg) {
962 return arg.kind == LinalgOperandDefKind::IndexAttr;
963 })) {
964 std::vector<std::string> attrVerifications;
965 for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
966 if (arg.kind != LinalgOperandDefKind::IndexAttr)
967 continue;
968 assert(arg.indexAttrMap);
969 // Verify index attribute. Paramters:
970 // {0}: Attribute name
971 // {1}: Attribute size
972 static const char attrFmt[] = R"FMT(
973if (auto attr = op->getAttrOfType<DenseElementsAttr>("{0}")) {{
974 if (!attr.getType().getElementType().isInteger(64))
975 return op->emitError("incorrect element type for index attribute '{0}'");
976 if (attr.getType().getShape() != ArrayRef<int64_t>{{ {1} })
977 return op->emitError("incorrect shape for index attribute '{0}'");
978}
979)FMT";
980 attrVerifications.push_back(x: llvm::formatv(
981 Fmt: attrFmt, Vals&: arg.name, Vals: arg.indexAttrMap->affineMap().getNumResults()));
982 }
983
984 // Generates the verifyIndexingMapRequiredAttributes method. Parameters:
985 // {0}: Class name
986 // {1}: Attribute verification
987 static const char structuredOpVerifyIndexingMapRequiredAttributes[] = R"FMT(
988bool {0}::hasDynamicIndexingMaps() {{ return true; }
989LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{
990 Operation *op = getOperation();
991 {1}
992 return success();
993}
994)FMT";
995 os << llvm::formatv(Fmt: structuredOpVerifyIndexingMapRequiredAttributes,
996 Vals&: className, Vals: llvm::join(R&: attrVerifications, Separator: "\n"));
997 }
998
999 // regionBuilder()
1000 {
1001 // Generates a regionBuilder method. Parameters.
1002 // {0}: Class name
1003 // {1}: Number of args
1004 // {2}: Attributes
1005 // {3}: Statements
1006 static const char structuredOpRegionBuilderFormat[] = R"FMT(
1007void {0}::regionBuilder(ImplicitLocOpBuilder &b,
1008 Block &block, ArrayRef<NamedAttribute> attrs) {{
1009 assert({1} > 0 && block.getNumArguments() == {1} &&
1010 "{0} regionBuilder expects {1} (>=0) args");
1011 RegionBuilderHelper helper(b, block);
1012 SmallVector<Value> yields;
1013 {2}
1014 {3}
1015 helper.yieldOutputs(yields);
1016}
1017)FMT";
1018 auto &args = opConfig.structuredOp->args;
1019 auto &assignments = opConfig.structuredOp->assignments;
1020 size_t generatedAssignmentCount = 0;
1021 int localCounter = 0;
1022 SmallVector<std::string> attrs;
1023 SmallVector<std::string> stmts;
1024 for (LinalgOperandDef &arg : args) {
1025 if (!isFunctionAttribute(kind: arg.kind))
1026 continue;
1027 // Obtain the type function attribute values. Parameters.
1028 // {0}: enum name
1029 // {1}: attribute name
1030 // {2}: default type function name
1031 static const char attrDef[] = R"FMT(
1032 {0} {1}Val = {0}::{2};
1033 auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
1034 return attr.getName() == "{1}"; });
1035 if ({1}Iter != attrs.end()) {{
1036 if (auto attr = llvm::dyn_cast<{0}Attr>({1}Iter->getValue()))
1037 {1}Val = attr.getValue();
1038 }
1039)FMT";
1040 std::string enumName = convertOperandKindToEnumName(kind: arg.kind);
1041 attrs.push_back(
1042 Elt: llvm::formatv(Fmt: attrDef, Vals&: enumName, Vals&: arg.name, Vals&: arg.defaultFn));
1043 }
1044 for (LinalgOperandDef &arg : args) {
1045 if (arg.kind != LinalgOperandDefKind::OutputTensor)
1046 continue;
1047
1048 // Find the assignment that correlates with the argument.
1049 ScalarAssign *assignment = findAssignment(name: arg.name, assignments);
1050 if (!assignment)
1051 return emitError(loc: genContext.getLoc())
1052 << "no assignment found for output argument " << arg.name;
1053 ++generatedAssignmentCount;
1054
1055 // Recursively generate the expression.
1056 std::function<std::optional<std::string>(ScalarExpression &)>
1057 generateExpression =
1058 [&](ScalarExpression &expression) -> std::optional<std::string> {
1059 if (expression.arg) {
1060 // Argument reference.
1061 std::optional<int> argIndex =
1062 findTensorDefArgIndex(name: *expression.arg, args);
1063 if (!argIndex) {
1064 emitError(loc: genContext.getLoc())
1065 << "scalar argument not defined on the op: " << *expression.arg;
1066 return std::nullopt;
1067 }
1068 return std::string(
1069 llvm::formatv(Fmt: "block.getArgument({0})", Vals&: *argIndex));
1070 }
1071 if (expression.constant) {
1072 std::string cppIdent = llvm::formatv(Fmt: "value{0}", Vals&: ++localCounter);
1073 stmts.push_back(
1074 Elt: llvm::formatv(Fmt: R"FMT(Value {0} = helper.constant("{1}");)FMT",
1075 Vals&: cppIdent, Vals&: expression.constant));
1076 return cppIdent;
1077 }
1078 if (expression.index) {
1079 // Access an iteration index.
1080 std::string cppIdent = llvm::formatv(Fmt: "value{0}", Vals&: ++localCounter);
1081 stmts.push_back(Elt: llvm::formatv(Fmt: "Value {0} = helper.index({1});",
1082 Vals&: cppIdent, Vals&: *expression.index));
1083 return cppIdent;
1084 }
1085 if (expression.scalarFn) {
1086 std::string enumName =
1087 convertFunctionKindToEnumName(kind: expression.scalarFn->kind);
1088
1089 // Get the function or attribute name.
1090 assert(expression.scalarFn->fnName || expression.scalarFn->attrName);
1091 std::string funcType;
1092 if (expression.scalarFn->fnName) {
1093 funcType = llvm::formatv(Fmt: "{0}::{1}", Vals&: enumName,
1094 Vals&: *expression.scalarFn->fnName);
1095 }
1096 if (expression.scalarFn->attrName) {
1097 if (llvm::none_of(Range&: args, P: [&](LinalgOperandDef &arg) {
1098 return isFunctionAttribute(kind: arg.kind) &&
1099 arg.name == *expression.scalarFn->attrName;
1100 })) {
1101 emitError(loc: genContext.getLoc()) << "missing function attribute "
1102 << *expression.scalarFn->attrName;
1103 }
1104 funcType = llvm::formatv(Fmt: "{0}Val", Vals&: *expression.scalarFn->attrName);
1105 }
1106 assert(!funcType.empty());
1107
1108 // Add the optional type parameter to the operands.
1109 SmallVector<std::string> operandCppValues;
1110 if (expression.scalarFn->kind == ScalarFnKind::Type) {
1111 assert(expression.scalarFn->typeVar.has_value());
1112 std::optional<std::string> typeCppValue =
1113 findTypeValue(typeVar: *expression.scalarFn->typeVar, args);
1114 if (!typeCppValue) {
1115 emitError(loc: genContext.getLoc())
1116 << "type variable " << *expression.scalarFn->typeVar
1117 << ", used in a type conversion, must map to a predefined or "
1118 << "an argument type but it does not";
1119 return std::nullopt;
1120 }
1121 operandCppValues.push_back(Elt: *typeCppValue);
1122 }
1123
1124 // Collect the scalar operands.
1125 for (ScalarExpression &operand : expression.scalarFn->operands) {
1126 auto operandCppValue = generateExpression(operand);
1127 if (!operandCppValue)
1128 return std::nullopt;
1129 operandCppValues.push_back(Elt: *operandCppValue);
1130 }
1131
1132 // Call the function builder.
1133 std::string cppIdent = llvm::formatv(Fmt: "value{0}", Vals&: ++localCounter);
1134 stmts.push_back(Elt: llvm::formatv(
1135 Fmt: "Value {0} = helper.build{1}({2}, {3});", Vals&: cppIdent, Vals&: enumName,
1136 Vals&: funcType, Vals: interleaveToString(container&: operandCppValues, separator: ", ")));
1137 return cppIdent;
1138 }
1139 emitError(loc: genContext.getLoc()) << "unknown ScalarExpression type";
1140 return std::nullopt;
1141 };
1142 std::optional<std::string> cppValue =
1143 generateExpression(assignment->value);
1144 if (!cppValue)
1145 return failure();
1146 stmts.push_back(Elt: llvm::formatv(Fmt: "yields.push_back({0});", Vals&: *cppValue));
1147 }
1148
1149 if (generatedAssignmentCount != assignments.size())
1150 return emitError(loc: genContext.getLoc())
1151 << "mismatched number of assignments vs output arguments";
1152
1153 os << llvm::formatv(Fmt: structuredOpRegionBuilderFormat, Vals&: className, Vals&: numOfArgs,
1154 Vals: interleaveToString(container&: attrs, separator: "\n "),
1155 Vals: interleaveToString(container&: stmts, separator: "\n "));
1156 }
1157
1158 // Parser and printer.
1159 os << llvm::formatv(Fmt: structuredOpParserFormat, Vals&: className);
1160
1161 // Canonicalizers and folders.
1162 os << llvm::formatv(Fmt: structuredOpFoldersFormat, Vals&: className);
1163
1164 return success();
1165}
1166
1167static LogicalResult generateOp(LinalgOpConfig &opConfig,
1168 GenerationContext &genContext) {
1169 // Switch on op type being generated.
1170 if (opConfig.structuredOp) {
1171 return success(
1172 isSuccess: succeeded(result: generateNamedGenericOpOds(opConfig, genContext)) &&
1173 succeeded(result: generateNamedGenericOpDefns(opConfig, genContext)));
1174 }
1175 return emitError(loc: genContext.getLoc()) << "unsupported operation type";
1176}
1177
1178//===----------------------------------------------------------------------===//
1179// Command line options and main
1180//===----------------------------------------------------------------------===//
1181
1182static llvm::cl::opt<std::string>
1183 inputFilename(llvm::cl::Positional, llvm::cl::desc("<input file>"),
1184 llvm::cl::init(Val: "-"), llvm::cl::value_desc("YAML filename"));
1185
1186static llvm::cl::opt<std::string>
1187 outputOdsDeclFilename("o-ods-decl", llvm::cl::desc("ODS output filename"),
1188 llvm::cl::value_desc("filename"), llvm::cl::init(Val: ""));
1189
1190static llvm::cl::opt<std::string>
1191 outputCppImplFilename("o-impl",
1192 llvm::cl::desc("C++ implementation file name"),
1193 llvm::cl::value_desc("filename"), llvm::cl::init(Val: ""));
1194
1195int main(int argc, char **argv) {
1196 llvm::cl::ParseCommandLineOptions(argc, argv, Overview: "Linalg ODS Gen from YAML");
1197
1198 // Set up the input file.
1199 std::string errorMessage;
1200 std::unique_ptr<llvm::MemoryBuffer> file =
1201 mlir::openInputFile(inputFilename, errorMessage: &errorMessage);
1202 if (!file) {
1203 llvm::errs() << errorMessage << "\n";
1204 return 1;
1205 }
1206
1207 MLIRContext mlirContext;
1208 LinalgYAMLContext yamlContext{.mlirContext: &mlirContext};
1209
1210 std::vector<LinalgOpConfig> opConfigs;
1211
1212 // Parse input.
1213 Input yin(file->getBuffer(), &yamlContext);
1214 yin >> opConfigs;
1215
1216 if (yin.error())
1217 return 1;
1218
1219 // Open output files.
1220 std::unique_ptr<llvm::ToolOutputFile> outputOdsDecl;
1221 if (!outputOdsDeclFilename.empty()) {
1222 outputOdsDecl = openOutputFile(outputFilename: outputOdsDeclFilename, errorMessage: &errorMessage);
1223 if (!outputOdsDecl) {
1224 llvm::errs() << errorMessage << "\n";
1225 return 1;
1226 }
1227 }
1228
1229 std::unique_ptr<llvm::ToolOutputFile> outputCppImpl;
1230 if (!outputCppImplFilename.empty()) {
1231 outputCppImpl = openOutputFile(outputFilename: outputCppImplFilename, errorMessage: &errorMessage);
1232 if (!outputCppImpl) {
1233 llvm::errs() << errorMessage << "\n";
1234 return 1;
1235 }
1236 }
1237
1238 if (!outputOdsDecl && !outputCppImpl) {
1239 llvm::errs() << "error: No output files specified\n";
1240 return 1;
1241 }
1242
1243 // Generate.
1244 GenerationContext genContext(&mlirContext,
1245 outputOdsDecl ? &outputOdsDecl->os() : nullptr,
1246 outputCppImpl ? &outputCppImpl->os() : nullptr);
1247
1248 for (auto &opConfig : opConfigs) {
1249 if (!opConfig.metadata) {
1250 emitError(loc: genContext.getLoc())
1251 << "missing operation metadata on subsequent op";
1252 return 1;
1253 }
1254
1255 genContext.setLoc(NameLoc::get(
1256 StringAttr::get(&mlirContext, opConfig.metadata->cppClassName)));
1257 if (failed(result: generateOp(opConfig, genContext))) {
1258 return 1;
1259 }
1260 }
1261
1262 if (outputOdsDecl)
1263 outputOdsDecl->keep();
1264 if (outputCppImpl)
1265 outputCppImpl->keep();
1266
1267 return 0;
1268}
1269

source code of mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp