1//===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
10#include "mlir/IR/Builders.h"
11#include "mlir/Interfaces/FunctionImplementation.h"
12
13using namespace mlir;
14using namespace mlir::ml_program;
15
16//===----------------------------------------------------------------------===//
17// Custom asm helpers
18//===----------------------------------------------------------------------===//
19
20/// Parse and print an ordering clause for a variadic of consuming tokens
21/// and an producing token.
22///
23/// Syntax:
24/// ordering(%0, %1 -> !ml_program.token)
25/// ordering(() -> !ml_program.token)
26///
27/// If both the consuming and producing token are not present on the op, then
28/// the clause prints nothing.
29static ParseResult parseTokenOrdering(
30 OpAsmParser &parser,
31 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &consumeTokens,
32 Type &produceTokenType) {
33 if (failed(result: parser.parseOptionalKeyword(keyword: "ordering")) ||
34 failed(result: parser.parseLParen()))
35 return success();
36
37 // Parse consuming token list. If there are no consuming tokens, the
38 // '()' null list represents this.
39 if (succeeded(result: parser.parseOptionalLParen())) {
40 if (failed(result: parser.parseRParen()))
41 return failure();
42 } else {
43 if (failed(result: parser.parseOperandList(result&: consumeTokens,
44 /*requiredOperandCount=*/-1)))
45 return failure();
46 }
47
48 // Parse producer token.
49 if (failed(result: parser.parseArrow()))
50 return failure();
51 if (failed(result: parser.parseType(result&: produceTokenType)))
52 return failure();
53
54 if (failed(result: parser.parseRParen()))
55 return failure();
56
57 return success();
58}
59
60static void printTokenOrdering(OpAsmPrinter &p, Operation *op,
61 OperandRange consumeTokens,
62 Type produceTokenType) {
63 if (consumeTokens.empty() && !produceTokenType)
64 return;
65
66 p << " ordering(";
67 if (consumeTokens.empty())
68 p << "()";
69 else
70 p.printOperands(container: consumeTokens);
71 if (produceTokenType) {
72 p << " -> ";
73 p.printType(type: produceTokenType);
74 }
75 p << ")";
76}
77
78/// some.op custom<TypeOrAttr>($type, $attr)
79///
80/// Uninitialized:
81/// some.op : tensor<3xi32>
82/// Initialized to narrower type than op:
83/// some.op (dense<0> : tensor<3xi32>) : tensor<?xi32>
84static ParseResult parseTypedInitialValue(OpAsmParser &parser,
85 TypeAttr &typeAttr, Attribute &attr) {
86 if (succeeded(result: parser.parseOptionalLParen())) {
87 if (failed(result: parser.parseAttribute(result&: attr)))
88 return failure();
89 if (failed(result: parser.parseRParen()))
90 return failure();
91 }
92
93 Type type;
94 if (failed(result: parser.parseColonType(result&: type)))
95 return failure();
96 typeAttr = TypeAttr::get(type);
97 return success();
98}
99
100static void printTypedInitialValue(OpAsmPrinter &p, Operation *op,
101 TypeAttr type, Attribute attr) {
102 if (attr) {
103 p << "(";
104 p.printAttribute(attr);
105 p << ")";
106 }
107
108 p << " : ";
109 p.printAttribute(attr: type);
110}
111
112/// some.op custom<SymbolVisibility>($sym_visibility) $sym_name
113/// ->
114/// some.op public @foo
115/// some.op private @foo
116static ParseResult parseSymbolVisibility(OpAsmParser &parser,
117 StringAttr &symVisibilityAttr) {
118 StringRef symVisibility;
119 (void)parser.parseOptionalKeyword(keyword: &symVisibility,
120 allowedValues: {"public", "private", "nested"});
121 if (symVisibility.empty())
122 return parser.emitError(loc: parser.getCurrentLocation())
123 << "expected 'public', 'private', or 'nested'";
124 if (!symVisibility.empty())
125 symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility);
126 return success();
127}
128
129static void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
130 StringAttr symVisibilityAttr) {
131 if (!symVisibilityAttr)
132 p << "public";
133 else
134 p << symVisibilityAttr.getValue();
135}
136
137//===----------------------------------------------------------------------===//
138// TableGen'd op method definitions
139//===----------------------------------------------------------------------===//
140
141#define GET_OP_CLASSES
142#include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
143
144//===----------------------------------------------------------------------===//
145// FuncOp
146//===----------------------------------------------------------------------===//
147
148ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
149 auto buildFuncType =
150 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
151 function_interface_impl::VariadicFlag,
152 std::string &) { return builder.getFunctionType(argTypes, results); };
153
154 return function_interface_impl::parseFunctionOp(
155 parser, result, /*allowVariadic=*/false,
156 getFunctionTypeAttrName(result.name), buildFuncType,
157 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
158}
159
160void FuncOp::print(OpAsmPrinter &p) {
161 function_interface_impl::printFunctionOp(
162 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
163 getArgAttrsAttrName(), getResAttrsAttrName());
164}
165
166//===----------------------------------------------------------------------===//
167// GlobalOp
168//===----------------------------------------------------------------------===//
169
170LogicalResult GlobalOp::verify() {
171 if (!getIsMutable() && !getValue())
172 return emitOpError() << "immutable global must have an initial value";
173 return success();
174}
175
176//===----------------------------------------------------------------------===//
177// GlobalLoadOp
178//===----------------------------------------------------------------------===//
179
180GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) {
181 for (auto *parent = getOperation()->getParentOp(); parent;
182 parent = parent->getParentOp()) {
183 if (auto nearest = symbolTable.lookupNearestSymbolFrom<GlobalOp>(
184 parent, getGlobalAttr())) {
185 return nearest;
186 }
187 }
188 return {};
189}
190
191LogicalResult
192GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
193 GlobalOp referrent = getGlobalOp(symbolTable);
194 if (!referrent)
195 return emitOpError() << "undefined global: " << getGlobal();
196
197 if (referrent.getType() != getResult().getType()) {
198 return emitOpError() << "cannot load from global typed "
199 << referrent.getType() << " as "
200 << getResult().getType();
201 }
202
203 return success();
204}
205
206//===----------------------------------------------------------------------===//
207// GlobalLoadConstOp
208//===----------------------------------------------------------------------===//
209
210GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) {
211 return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
212 getOperation()->getParentOp(), getGlobalAttr());
213}
214
215LogicalResult
216GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
217 GlobalOp referrent = getGlobalOp(symbolTable);
218 if (!referrent)
219 return emitOpError() << "undefined global: " << getGlobal();
220
221 if (referrent.getIsMutable())
222 return emitOpError() << "cannot load as const from mutable global "
223 << getGlobal();
224
225 if (referrent.getType() != getResult().getType())
226 return emitOpError() << "cannot load from global typed "
227 << referrent.getType() << " as "
228 << getResult().getType();
229
230 return success();
231}
232
233//===----------------------------------------------------------------------===//
234// GlobalLoadGraphOp
235//===----------------------------------------------------------------------===//
236
237GlobalOp GlobalLoadGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
238 return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
239 getOperation()->getParentOp(), getGlobalAttr());
240}
241
242LogicalResult
243GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
244 GlobalOp referrent = getGlobalOp(symbolTable);
245 if (!referrent)
246 return emitOpError() << "undefined global: " << getGlobal();
247
248 if (referrent.getType() != getResult().getType()) {
249 return emitOpError() << "cannot load from global typed "
250 << referrent.getType() << " as "
251 << getResult().getType();
252 }
253
254 return success();
255}
256
257//===----------------------------------------------------------------------===//
258// GlobalStoreOp
259//===----------------------------------------------------------------------===//
260
261GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) {
262 for (auto *parent = getOperation()->getParentOp(); parent;) {
263 if (auto nearest = symbolTable.lookupNearestSymbolFrom<GlobalOp>(
264 parent, getGlobalAttr())) {
265 return nearest;
266 }
267 parent = parent->getParentOp();
268 }
269 return {};
270}
271
272LogicalResult
273GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
274 GlobalOp referrent = getGlobalOp(symbolTable);
275 if (!referrent)
276 return emitOpError() << "undefined global: " << getGlobal();
277
278 if (!referrent.getIsMutable()) {
279 return emitOpError() << "cannot store to an immutable global "
280 << getGlobal();
281 }
282
283 if (referrent.getType() != getValue().getType()) {
284 return emitOpError() << "cannot store to a global typed "
285 << referrent.getType() << " from "
286 << getValue().getType();
287 }
288
289 return success();
290}
291
292//===----------------------------------------------------------------------===//
293// GlobalStoreGraphOp
294//===----------------------------------------------------------------------===//
295
296GlobalOp GlobalStoreGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
297 return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
298 getOperation()->getParentOp(), getGlobalAttr());
299}
300
301LogicalResult
302GlobalStoreGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
303 GlobalOp referrent = getGlobalOp(symbolTable);
304 if (!referrent)
305 return emitOpError() << "undefined global: " << getGlobal();
306
307 if (!referrent.getIsMutable()) {
308 return emitOpError() << "cannot store to an immutable global "
309 << getGlobal();
310 }
311
312 if (referrent.getType() != getValue().getType()) {
313 return emitOpError() << "cannot store to a global typed "
314 << referrent.getType() << " from "
315 << getValue().getType();
316 }
317
318 return success();
319}
320
321//===----------------------------------------------------------------------===//
322// SubgraphOp
323//===----------------------------------------------------------------------===//
324
325ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
326 auto buildFuncType =
327 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
328 function_interface_impl::VariadicFlag,
329 std::string &) { return builder.getFunctionType(argTypes, results); };
330
331 return function_interface_impl::parseFunctionOp(
332 parser, result, /*allowVariadic=*/false,
333 getFunctionTypeAttrName(result.name), buildFuncType,
334 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
335}
336
337void SubgraphOp::print(OpAsmPrinter &p) {
338 function_interface_impl::printFunctionOp(
339 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
340 getArgAttrsAttrName(), getResAttrsAttrName());
341}
342
343//===----------------------------------------------------------------------===//
344// OutputOp
345//===----------------------------------------------------------------------===//
346
347LogicalResult OutputOp::verify() {
348 auto function = cast<SubgraphOp>((*this)->getParentOp());
349
350 // The operand number and types must match the function signature.
351 const auto &results = function.getFunctionType().getResults();
352 if (getNumOperands() != results.size())
353 return emitOpError("has ")
354 << getNumOperands() << " operands, but enclosing function (@"
355 << function.getName() << ") outputs " << results.size();
356
357 for (unsigned i = 0, e = results.size(); i != e; ++i)
358 if (getOperand(i).getType() != results[i])
359 return emitError() << "type of output operand " << i << " ("
360 << getOperand(i).getType()
361 << ") doesn't match function result type ("
362 << results[i] << ")"
363 << " in function @" << function.getName();
364
365 return success();
366}
367
368//===----------------------------------------------------------------------===//
369// ReturnOp
370//===----------------------------------------------------------------------===//
371
372LogicalResult ReturnOp::verify() {
373 auto function = cast<FuncOp>((*this)->getParentOp());
374
375 // The operand number and types must match the function signature.
376 const auto &results = function.getFunctionType().getResults();
377 if (getNumOperands() != results.size())
378 return emitOpError("has ")
379 << getNumOperands() << " operands, but enclosing function (@"
380 << function.getName() << ") returns " << results.size();
381
382 for (unsigned i = 0, e = results.size(); i != e; ++i)
383 if (getOperand(i).getType() != results[i])
384 return emitError() << "type of return operand " << i << " ("
385 << getOperand(i).getType()
386 << ") doesn't match function result type ("
387 << results[i] << ")"
388 << " in function @" << function.getName();
389
390 return success();
391}
392

source code of mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp