1 | //===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/MLProgram/IR/MLProgram.h" |
10 | #include "mlir/IR/Builders.h" |
11 | #include "mlir/Interfaces/FunctionImplementation.h" |
12 | |
13 | using namespace mlir; |
14 | using namespace mlir::ml_program; |
15 | |
16 | //===----------------------------------------------------------------------===// |
17 | // Custom asm helpers |
18 | //===----------------------------------------------------------------------===// |
19 | |
20 | /// Parse and print an ordering clause for a variadic of consuming tokens |
21 | /// and an producing token. |
22 | /// |
23 | /// Syntax: |
24 | /// ordering(%0, %1 -> !ml_program.token) |
25 | /// ordering(() -> !ml_program.token) |
26 | /// |
27 | /// If both the consuming and producing token are not present on the op, then |
28 | /// the clause prints nothing. |
29 | static ParseResult parseTokenOrdering( |
30 | OpAsmParser &parser, |
31 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &consumeTokens, |
32 | Type &produceTokenType) { |
33 | if (failed(result: parser.parseOptionalKeyword(keyword: "ordering" )) || |
34 | failed(result: parser.parseLParen())) |
35 | return success(); |
36 | |
37 | // Parse consuming token list. If there are no consuming tokens, the |
38 | // '()' null list represents this. |
39 | if (succeeded(result: parser.parseOptionalLParen())) { |
40 | if (failed(result: parser.parseRParen())) |
41 | return failure(); |
42 | } else { |
43 | if (failed(result: parser.parseOperandList(result&: consumeTokens, |
44 | /*requiredOperandCount=*/-1))) |
45 | return failure(); |
46 | } |
47 | |
48 | // Parse producer token. |
49 | if (failed(result: parser.parseArrow())) |
50 | return failure(); |
51 | if (failed(result: parser.parseType(result&: produceTokenType))) |
52 | return failure(); |
53 | |
54 | if (failed(result: parser.parseRParen())) |
55 | return failure(); |
56 | |
57 | return success(); |
58 | } |
59 | |
60 | static void printTokenOrdering(OpAsmPrinter &p, Operation *op, |
61 | OperandRange consumeTokens, |
62 | Type produceTokenType) { |
63 | if (consumeTokens.empty() && !produceTokenType) |
64 | return; |
65 | |
66 | p << " ordering(" ; |
67 | if (consumeTokens.empty()) |
68 | p << "()" ; |
69 | else |
70 | p.printOperands(container: consumeTokens); |
71 | if (produceTokenType) { |
72 | p << " -> " ; |
73 | p.printType(type: produceTokenType); |
74 | } |
75 | p << ")" ; |
76 | } |
77 | |
78 | /// some.op custom<TypeOrAttr>($type, $attr) |
79 | /// |
80 | /// Uninitialized: |
81 | /// some.op : tensor<3xi32> |
82 | /// Initialized to narrower type than op: |
83 | /// some.op (dense<0> : tensor<3xi32>) : tensor<?xi32> |
84 | static ParseResult parseTypedInitialValue(OpAsmParser &parser, |
85 | TypeAttr &typeAttr, Attribute &attr) { |
86 | if (succeeded(result: parser.parseOptionalLParen())) { |
87 | if (failed(result: parser.parseAttribute(result&: attr))) |
88 | return failure(); |
89 | if (failed(result: parser.parseRParen())) |
90 | return failure(); |
91 | } |
92 | |
93 | Type type; |
94 | if (failed(result: parser.parseColonType(result&: type))) |
95 | return failure(); |
96 | typeAttr = TypeAttr::get(type); |
97 | return success(); |
98 | } |
99 | |
100 | static void printTypedInitialValue(OpAsmPrinter &p, Operation *op, |
101 | TypeAttr type, Attribute attr) { |
102 | if (attr) { |
103 | p << "(" ; |
104 | p.printAttribute(attr); |
105 | p << ")" ; |
106 | } |
107 | |
108 | p << " : " ; |
109 | p.printAttribute(attr: type); |
110 | } |
111 | |
112 | /// some.op custom<SymbolVisibility>($sym_visibility) $sym_name |
113 | /// -> |
114 | /// some.op public @foo |
115 | /// some.op private @foo |
116 | static ParseResult parseSymbolVisibility(OpAsmParser &parser, |
117 | StringAttr &symVisibilityAttr) { |
118 | StringRef symVisibility; |
119 | (void)parser.parseOptionalKeyword(keyword: &symVisibility, |
120 | allowedValues: {"public" , "private" , "nested" }); |
121 | if (symVisibility.empty()) |
122 | return parser.emitError(loc: parser.getCurrentLocation()) |
123 | << "expected 'public', 'private', or 'nested'" ; |
124 | if (!symVisibility.empty()) |
125 | symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility); |
126 | return success(); |
127 | } |
128 | |
129 | static void printSymbolVisibility(OpAsmPrinter &p, Operation *op, |
130 | StringAttr symVisibilityAttr) { |
131 | if (!symVisibilityAttr) |
132 | p << "public" ; |
133 | else |
134 | p << symVisibilityAttr.getValue(); |
135 | } |
136 | |
137 | //===----------------------------------------------------------------------===// |
138 | // TableGen'd op method definitions |
139 | //===----------------------------------------------------------------------===// |
140 | |
141 | #define GET_OP_CLASSES |
142 | #include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc" |
143 | |
144 | //===----------------------------------------------------------------------===// |
145 | // FuncOp |
146 | //===----------------------------------------------------------------------===// |
147 | |
148 | ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { |
149 | auto buildFuncType = |
150 | [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, |
151 | function_interface_impl::VariadicFlag, |
152 | std::string &) { return builder.getFunctionType(argTypes, results); }; |
153 | |
154 | return function_interface_impl::parseFunctionOp( |
155 | parser, result, /*allowVariadic=*/false, |
156 | getFunctionTypeAttrName(result.name), buildFuncType, |
157 | getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); |
158 | } |
159 | |
160 | void FuncOp::print(OpAsmPrinter &p) { |
161 | function_interface_impl::printFunctionOp( |
162 | p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), |
163 | getArgAttrsAttrName(), getResAttrsAttrName()); |
164 | } |
165 | |
166 | //===----------------------------------------------------------------------===// |
167 | // GlobalOp |
168 | //===----------------------------------------------------------------------===// |
169 | |
170 | LogicalResult GlobalOp::verify() { |
171 | if (!getIsMutable() && !getValue()) |
172 | return emitOpError() << "immutable global must have an initial value" ; |
173 | return success(); |
174 | } |
175 | |
176 | //===----------------------------------------------------------------------===// |
177 | // GlobalLoadOp |
178 | //===----------------------------------------------------------------------===// |
179 | |
180 | GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) { |
181 | for (auto *parent = getOperation()->getParentOp(); parent; |
182 | parent = parent->getParentOp()) { |
183 | if (auto nearest = symbolTable.lookupNearestSymbolFrom<GlobalOp>( |
184 | parent, getGlobalAttr())) { |
185 | return nearest; |
186 | } |
187 | } |
188 | return {}; |
189 | } |
190 | |
191 | LogicalResult |
192 | GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
193 | GlobalOp referrent = getGlobalOp(symbolTable); |
194 | if (!referrent) |
195 | return emitOpError() << "undefined global: " << getGlobal(); |
196 | |
197 | if (referrent.getType() != getResult().getType()) { |
198 | return emitOpError() << "cannot load from global typed " |
199 | << referrent.getType() << " as " |
200 | << getResult().getType(); |
201 | } |
202 | |
203 | return success(); |
204 | } |
205 | |
206 | //===----------------------------------------------------------------------===// |
207 | // GlobalLoadConstOp |
208 | //===----------------------------------------------------------------------===// |
209 | |
210 | GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) { |
211 | return symbolTable.lookupNearestSymbolFrom<GlobalOp>( |
212 | getOperation()->getParentOp(), getGlobalAttr()); |
213 | } |
214 | |
215 | LogicalResult |
216 | GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
217 | GlobalOp referrent = getGlobalOp(symbolTable); |
218 | if (!referrent) |
219 | return emitOpError() << "undefined global: " << getGlobal(); |
220 | |
221 | if (referrent.getIsMutable()) |
222 | return emitOpError() << "cannot load as const from mutable global " |
223 | << getGlobal(); |
224 | |
225 | if (referrent.getType() != getResult().getType()) |
226 | return emitOpError() << "cannot load from global typed " |
227 | << referrent.getType() << " as " |
228 | << getResult().getType(); |
229 | |
230 | return success(); |
231 | } |
232 | |
233 | //===----------------------------------------------------------------------===// |
234 | // GlobalLoadGraphOp |
235 | //===----------------------------------------------------------------------===// |
236 | |
237 | GlobalOp GlobalLoadGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) { |
238 | return symbolTable.lookupNearestSymbolFrom<GlobalOp>( |
239 | getOperation()->getParentOp(), getGlobalAttr()); |
240 | } |
241 | |
242 | LogicalResult |
243 | GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
244 | GlobalOp referrent = getGlobalOp(symbolTable); |
245 | if (!referrent) |
246 | return emitOpError() << "undefined global: " << getGlobal(); |
247 | |
248 | if (referrent.getType() != getResult().getType()) { |
249 | return emitOpError() << "cannot load from global typed " |
250 | << referrent.getType() << " as " |
251 | << getResult().getType(); |
252 | } |
253 | |
254 | return success(); |
255 | } |
256 | |
257 | //===----------------------------------------------------------------------===// |
258 | // GlobalStoreOp |
259 | //===----------------------------------------------------------------------===// |
260 | |
261 | GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) { |
262 | for (auto *parent = getOperation()->getParentOp(); parent;) { |
263 | if (auto nearest = symbolTable.lookupNearestSymbolFrom<GlobalOp>( |
264 | parent, getGlobalAttr())) { |
265 | return nearest; |
266 | } |
267 | parent = parent->getParentOp(); |
268 | } |
269 | return {}; |
270 | } |
271 | |
272 | LogicalResult |
273 | GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
274 | GlobalOp referrent = getGlobalOp(symbolTable); |
275 | if (!referrent) |
276 | return emitOpError() << "undefined global: " << getGlobal(); |
277 | |
278 | if (!referrent.getIsMutable()) { |
279 | return emitOpError() << "cannot store to an immutable global " |
280 | << getGlobal(); |
281 | } |
282 | |
283 | if (referrent.getType() != getValue().getType()) { |
284 | return emitOpError() << "cannot store to a global typed " |
285 | << referrent.getType() << " from " |
286 | << getValue().getType(); |
287 | } |
288 | |
289 | return success(); |
290 | } |
291 | |
292 | //===----------------------------------------------------------------------===// |
293 | // GlobalStoreGraphOp |
294 | //===----------------------------------------------------------------------===// |
295 | |
296 | GlobalOp GlobalStoreGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) { |
297 | return symbolTable.lookupNearestSymbolFrom<GlobalOp>( |
298 | getOperation()->getParentOp(), getGlobalAttr()); |
299 | } |
300 | |
301 | LogicalResult |
302 | GlobalStoreGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
303 | GlobalOp referrent = getGlobalOp(symbolTable); |
304 | if (!referrent) |
305 | return emitOpError() << "undefined global: " << getGlobal(); |
306 | |
307 | if (!referrent.getIsMutable()) { |
308 | return emitOpError() << "cannot store to an immutable global " |
309 | << getGlobal(); |
310 | } |
311 | |
312 | if (referrent.getType() != getValue().getType()) { |
313 | return emitOpError() << "cannot store to a global typed " |
314 | << referrent.getType() << " from " |
315 | << getValue().getType(); |
316 | } |
317 | |
318 | return success(); |
319 | } |
320 | |
321 | //===----------------------------------------------------------------------===// |
322 | // SubgraphOp |
323 | //===----------------------------------------------------------------------===// |
324 | |
325 | ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) { |
326 | auto buildFuncType = |
327 | [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, |
328 | function_interface_impl::VariadicFlag, |
329 | std::string &) { return builder.getFunctionType(argTypes, results); }; |
330 | |
331 | return function_interface_impl::parseFunctionOp( |
332 | parser, result, /*allowVariadic=*/false, |
333 | getFunctionTypeAttrName(result.name), buildFuncType, |
334 | getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); |
335 | } |
336 | |
337 | void SubgraphOp::print(OpAsmPrinter &p) { |
338 | function_interface_impl::printFunctionOp( |
339 | p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), |
340 | getArgAttrsAttrName(), getResAttrsAttrName()); |
341 | } |
342 | |
343 | //===----------------------------------------------------------------------===// |
344 | // OutputOp |
345 | //===----------------------------------------------------------------------===// |
346 | |
347 | LogicalResult OutputOp::verify() { |
348 | auto function = cast<SubgraphOp>((*this)->getParentOp()); |
349 | |
350 | // The operand number and types must match the function signature. |
351 | const auto &results = function.getFunctionType().getResults(); |
352 | if (getNumOperands() != results.size()) |
353 | return emitOpError("has " ) |
354 | << getNumOperands() << " operands, but enclosing function (@" |
355 | << function.getName() << ") outputs " << results.size(); |
356 | |
357 | for (unsigned i = 0, e = results.size(); i != e; ++i) |
358 | if (getOperand(i).getType() != results[i]) |
359 | return emitError() << "type of output operand " << i << " (" |
360 | << getOperand(i).getType() |
361 | << ") doesn't match function result type (" |
362 | << results[i] << ")" |
363 | << " in function @" << function.getName(); |
364 | |
365 | return success(); |
366 | } |
367 | |
368 | //===----------------------------------------------------------------------===// |
369 | // ReturnOp |
370 | //===----------------------------------------------------------------------===// |
371 | |
372 | LogicalResult ReturnOp::verify() { |
373 | auto function = cast<FuncOp>((*this)->getParentOp()); |
374 | |
375 | // The operand number and types must match the function signature. |
376 | const auto &results = function.getFunctionType().getResults(); |
377 | if (getNumOperands() != results.size()) |
378 | return emitOpError("has " ) |
379 | << getNumOperands() << " operands, but enclosing function (@" |
380 | << function.getName() << ") returns " << results.size(); |
381 | |
382 | for (unsigned i = 0, e = results.size(); i != e; ++i) |
383 | if (getOperand(i).getType() != results[i]) |
384 | return emitError() << "type of return operand " << i << " (" |
385 | << getOperand(i).getType() |
386 | << ") doesn't match function result type (" |
387 | << results[i] << ")" |
388 | << " in function @" << function.getName(); |
389 | |
390 | return success(); |
391 | } |
392 | |