1//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a simple IR generation targeting MLIR from a Module AST
10// for the Toy language.
11//
12//===----------------------------------------------------------------------===//
13
14#include "toy/MLIRGen.h"
15#include "mlir/IR/Block.h"
16#include "mlir/IR/Diagnostics.h"
17#include "mlir/IR/Value.h"
18#include "toy/AST.h"
19#include "toy/Dialect.h"
20
21#include "mlir/IR/Builders.h"
22#include "mlir/IR/BuiltinOps.h"
23#include "mlir/IR/BuiltinTypes.h"
24#include "mlir/IR/MLIRContext.h"
25#include "mlir/IR/Verifier.h"
26#include "toy/Lexer.h"
27
28#include "llvm/ADT/STLExtras.h"
29#include "llvm/ADT/ScopedHashTable.h"
30#include "llvm/ADT/SmallVector.h"
31#include "llvm/ADT/StringRef.h"
32#include "llvm/ADT/Twine.h"
33#include <cassert>
34#include <cstdint>
35#include <functional>
36#include <numeric>
37#include <optional>
38#include <vector>
39
40using namespace mlir::toy;
41using namespace toy;
42
43using llvm::ArrayRef;
44using llvm::cast;
45using llvm::dyn_cast;
46using llvm::isa;
47using llvm::ScopedHashTableScope;
48using llvm::SmallVector;
49using llvm::StringRef;
50using llvm::Twine;
51
52namespace {
53
54/// Implementation of a simple MLIR emission from the Toy AST.
55///
56/// This will emit operations that are specific to the Toy language, preserving
57/// the semantics of the language and (hopefully) allow to perform accurate
58/// analysis and transformation based on these high level semantics.
59class MLIRGenImpl {
60public:
61 MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {}
62
63 /// Public API: convert the AST for a Toy module (source file) to an MLIR
64 /// Module operation.
65 mlir::ModuleOp mlirGen(ModuleAST &moduleAST) {
66 // We create an empty MLIR module and codegen functions one at a time and
67 // add them to the module.
68 theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
69
70 for (FunctionAST &f : moduleAST)
71 mlirGen(f);
72
73 // Verify the module after we have finished constructing it, this will check
74 // the structural properties of the IR and invoke any specific verifiers we
75 // have on the Toy operations.
76 if (failed(mlir::verify(theModule))) {
77 theModule.emitError("module verification error");
78 return nullptr;
79 }
80
81 return theModule;
82 }
83
84private:
85 /// A "module" matches a Toy source file: containing a list of functions.
86 mlir::ModuleOp theModule;
87
88 /// The builder is a helper class to create IR inside a function. The builder
89 /// is stateful, in particular it keeps an "insertion point": this is where
90 /// the next operations will be introduced.
91 mlir::OpBuilder builder;
92
93 /// The symbol table maps a variable name to a value in the current scope.
94 /// Entering a function creates a new scope, and the function arguments are
95 /// added to the mapping. When the processing of a function is terminated, the
96 /// scope is destroyed and the mappings created in this scope are dropped.
97 llvm::ScopedHashTable<StringRef, mlir::Value> symbolTable;
98
99 /// Helper conversion for a Toy AST location to an MLIR location.
100 mlir::Location loc(const Location &loc) {
101 return mlir::FileLineColLoc::get(builder.getStringAttr(*loc.file), loc.line,
102 loc.col);
103 }
104
105 /// Declare a variable in the current scope, return success if the variable
106 /// wasn't declared yet.
107 llvm::LogicalResult declare(llvm::StringRef var, mlir::Value value) {
108 if (symbolTable.count(Key: var))
109 return mlir::failure();
110 symbolTable.insert(Key: var, Val: value);
111 return mlir::success();
112 }
113
114 /// Create the prototype for an MLIR function with as many arguments as the
115 /// provided Toy AST prototype.
116 mlir::toy::FuncOp mlirGen(PrototypeAST &proto) {
117 auto location = loc(loc: proto.loc());
118
119 // This is a generic function, the return type will be inferred later.
120 // Arguments type are uniformly unranked tensors.
121 llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
122 getType(type: VarType{}));
123 auto funcType = builder.getFunctionType(argTypes, std::nullopt);
124 return builder.create<mlir::toy::FuncOp>(location, proto.getName(),
125 funcType);
126 }
127
128 /// Emit a new function and add it to the MLIR module.
129 mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
130 // Create a scope in the symbol table to hold variable declarations.
131 ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);
132
133 // Create an MLIR function for the given prototype.
134 builder.setInsertionPointToEnd(theModule.getBody());
135 mlir::toy::FuncOp function = mlirGen(*funcAST.getProto());
136 if (!function)
137 return nullptr;
138
139 // Let's start the body of the function now!
140 mlir::Block &entryBlock = function.front();
141 auto protoArgs = funcAST.getProto()->getArgs();
142
143 // Declare all the function arguments in the symbol table.
144 for (const auto nameValue :
145 llvm::zip(protoArgs, entryBlock.getArguments())) {
146 if (failed(declare(std::get<0>(nameValue)->getName(),
147 std::get<1>(nameValue))))
148 return nullptr;
149 }
150
151 // Set the insertion point in the builder to the beginning of the function
152 // body, it will be used throughout the codegen to create operations in this
153 // function.
154 builder.setInsertionPointToStart(&entryBlock);
155
156 // Emit the body of the function.
157 if (mlir::failed(Result: mlirGen(blockAST&: *funcAST.getBody()))) {
158 function.erase();
159 return nullptr;
160 }
161
162 // Implicitly return void if no return statement was emitted.
163 // FIXME: we may fix the parser instead to always return the last expression
164 // (this would possibly help the REPL case later)
165 ReturnOp returnOp;
166 if (!entryBlock.empty())
167 returnOp = dyn_cast<ReturnOp>(entryBlock.back());
168 if (!returnOp) {
169 builder.create<ReturnOp>(loc(funcAST.getProto()->loc()));
170 } else if (returnOp.hasOperand()) {
171 // Otherwise, if this return operation has an operand then add a result to
172 // the function.
173 function.setType(builder.getFunctionType(
174 inputs: function.getFunctionType().getInputs(), results: getType(type: VarType{})));
175 }
176
177 return function;
178 }
179
180 /// Emit a binary operation
181 mlir::Value mlirGen(BinaryExprAST &binop) {
182 // First emit the operations for each side of the operation before emitting
183 // the operation itself. For example if the expression is `a + foo(a)`
184 // 1) First it will visiting the LHS, which will return a reference to the
185 // value holding `a`. This value should have been emitted at declaration
186 // time and registered in the symbol table, so nothing would be
187 // codegen'd. If the value is not in the symbol table, an error has been
188 // emitted and nullptr is returned.
189 // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted
190 // and the result value is returned. If an error occurs we get a nullptr
191 // and propagate.
192 //
193 mlir::Value lhs = mlirGen(expr&: *binop.getLHS());
194 if (!lhs)
195 return nullptr;
196 mlir::Value rhs = mlirGen(expr&: *binop.getRHS());
197 if (!rhs)
198 return nullptr;
199 auto location = loc(loc: binop.loc());
200
201 // Derive the operation name from the binary operator. At the moment we only
202 // support '+' and '*'.
203 switch (binop.getOp()) {
204 case '+':
205 return builder.create<AddOp>(location, lhs, rhs);
206 case '*':
207 return builder.create<MulOp>(location, lhs, rhs);
208 }
209
210 emitError(loc: location, message: "invalid binary operator '") << binop.getOp() << "'";
211 return nullptr;
212 }
213
214 /// This is a reference to a variable in an expression. The variable is
215 /// expected to have been declared and so should have a value in the symbol
216 /// table, otherwise emit an error and return nullptr.
217 mlir::Value mlirGen(VariableExprAST &expr) {
218 if (auto variable = symbolTable.lookup(Key: expr.getName()))
219 return variable;
220
221 emitError(loc: loc(loc: expr.loc()), message: "error: unknown variable '")
222 << expr.getName() << "'";
223 return nullptr;
224 }
225
226 /// Emit a return operation. This will return failure if any generation fails.
227 llvm::LogicalResult mlirGen(ReturnExprAST &ret) {
228 auto location = loc(loc: ret.loc());
229
230 // 'return' takes an optional expression, handle that case here.
231 mlir::Value expr = nullptr;
232 if (ret.getExpr().has_value()) {
233 if (!(expr = mlirGen(expr&: **ret.getExpr())))
234 return mlir::failure();
235 }
236
237 // Otherwise, this return operation has zero operands.
238 builder.create<ReturnOp>(location,
239 expr ? ArrayRef(expr) : ArrayRef<mlir::Value>());
240 return mlir::success();
241 }
242
243 /// Emit a literal/constant array. It will be emitted as a flattened array of
244 /// data in an Attribute attached to a `toy.constant` operation.
245 /// See documentation on [Attributes](LangRef.md#attributes) for more details.
246 /// Here is an excerpt:
247 ///
248 /// Attributes are the mechanism for specifying constant data in MLIR in
249 /// places where a variable is never allowed [...]. They consist of a name
250 /// and a concrete attribute value. The set of expected attributes, their
251 /// structure, and their interpretation are all contextually dependent on
252 /// what they are attached to.
253 ///
254 /// Example, the source level statement:
255 /// var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
256 /// will be converted to:
257 /// %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
258 /// [[1.000000e+00, 2.000000e+00, 3.000000e+00],
259 /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64>
260 ///
261 mlir::Value mlirGen(LiteralExprAST &lit) {
262 auto type = getType(shape: lit.getDims());
263
264 // The attribute is a vector with a floating point value per element
265 // (number) in the array, see `collectData()` below for more details.
266 std::vector<double> data;
267 data.reserve(n: std::accumulate(first: lit.getDims().begin(), last: lit.getDims().end(), init: 1,
268 binary_op: std::multiplies<int>()));
269 collectData(expr&: lit, data);
270
271 // The type of this attribute is tensor of 64-bit floating-point with the
272 // shape of the literal.
273 mlir::Type elementType = builder.getF64Type();
274 auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType);
275
276 // This is the actual attribute that holds the list of values for this
277 // tensor literal.
278 auto dataAttribute =
279 mlir::DenseElementsAttr::get(dataType, llvm::ArrayRef(data));
280
281 // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build`
282 // method.
283 return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute);
284 }
285
286 /// Recursive helper function to accumulate the data that compose an array
287 /// literal. It flattens the nested structure in the supplied vector. For
288 /// example with this array:
289 /// [[1, 2], [3, 4]]
290 /// we will generate:
291 /// [ 1, 2, 3, 4 ]
292 /// Individual numbers are represented as doubles.
293 /// Attributes are the way MLIR attaches constant to operations.
294 void collectData(ExprAST &expr, std::vector<double> &data) {
295 if (auto *lit = dyn_cast<LiteralExprAST>(Val: &expr)) {
296 for (auto &value : lit->getValues())
297 collectData(expr&: *value, data);
298 return;
299 }
300
301 assert(isa<NumberExprAST>(expr) && "expected literal or number expr");
302 data.push_back(x: cast<NumberExprAST>(Val&: expr).getValue());
303 }
304
305 /// Emit a call expression. It emits specific operations for the `transpose`
306 /// builtin. Other identifiers are assumed to be user-defined functions.
307 mlir::Value mlirGen(CallExprAST &call) {
308 llvm::StringRef callee = call.getCallee();
309 auto location = loc(loc: call.loc());
310
311 // Codegen the operands first.
312 SmallVector<mlir::Value, 4> operands;
313 for (auto &expr : call.getArgs()) {
314 auto arg = mlirGen(expr&: *expr);
315 if (!arg)
316 return nullptr;
317 operands.push_back(Elt: arg);
318 }
319
320 // Builtin calls have their custom operation, meaning this is a
321 // straightforward emission.
322 if (callee == "transpose") {
323 if (call.getArgs().size() != 1) {
324 emitError(loc: location, message: "MLIR codegen encountered an error: toy.transpose "
325 "does not accept multiple arguments");
326 return nullptr;
327 }
328 return builder.create<TransposeOp>(location, operands[0]);
329 }
330
331 // Otherwise this is a call to a user-defined function. Calls to
332 // user-defined functions are mapped to a custom call that takes the callee
333 // name as an attribute.
334 return builder.create<GenericCallOp>(location, callee, operands);
335 }
336
337 /// Emit a print expression. It emits specific operations for two builtins:
338 /// transpose(x) and print(x).
339 llvm::LogicalResult mlirGen(PrintExprAST &call) {
340 auto arg = mlirGen(expr&: *call.getArg());
341 if (!arg)
342 return mlir::failure();
343
344 builder.create<PrintOp>(loc(call.loc()), arg);
345 return mlir::success();
346 }
347
348 /// Emit a constant for a single number (FIXME: semantic? broadcast?)
349 mlir::Value mlirGen(NumberExprAST &num) {
350 return builder.create<ConstantOp>(loc(num.loc()), num.getValue());
351 }
352
353 /// Dispatch codegen for the right expression subclass using RTTI.
354 mlir::Value mlirGen(ExprAST &expr) {
355 switch (expr.getKind()) {
356 case toy::ExprAST::Expr_BinOp:
357 return mlirGen(binop&: cast<BinaryExprAST>(Val&: expr));
358 case toy::ExprAST::Expr_Var:
359 return mlirGen(expr&: cast<VariableExprAST>(Val&: expr));
360 case toy::ExprAST::Expr_Literal:
361 return mlirGen(lit&: cast<LiteralExprAST>(Val&: expr));
362 case toy::ExprAST::Expr_Call:
363 return mlirGen(call&: cast<CallExprAST>(Val&: expr));
364 case toy::ExprAST::Expr_Num:
365 return mlirGen(num&: cast<NumberExprAST>(Val&: expr));
366 default:
367 emitError(loc: loc(loc: expr.loc()))
368 << "MLIR codegen encountered an unhandled expr kind '"
369 << Twine(expr.getKind()) << "'";
370 return nullptr;
371 }
372 }
373
374 /// Handle a variable declaration, we'll codegen the expression that forms the
375 /// initializer and record the value in the symbol table before returning it.
376 /// Future expressions will be able to reference this variable through symbol
377 /// table lookup.
378 mlir::Value mlirGen(VarDeclExprAST &vardecl) {
379 auto *init = vardecl.getInitVal();
380 if (!init) {
381 emitError(loc: loc(loc: vardecl.loc()),
382 message: "missing initializer in variable declaration");
383 return nullptr;
384 }
385
386 mlir::Value value = mlirGen(expr&: *init);
387 if (!value)
388 return nullptr;
389
390 // We have the initializer value, but in case the variable was declared
391 // with specific shape, we emit a "reshape" operation. It will get
392 // optimized out later as needed.
393 if (!vardecl.getType().shape.empty()) {
394 value = builder.create<ReshapeOp>(loc(vardecl.loc()),
395 getType(vardecl.getType()), value);
396 }
397
398 // Register the value in the symbol table.
399 if (failed(Result: declare(var: vardecl.getName(), value)))
400 return nullptr;
401 return value;
402 }
403
404 /// Codegen a list of expression, return failure if one of them hit an error.
405 llvm::LogicalResult mlirGen(ExprASTList &blockAST) {
406 ScopedHashTableScope<StringRef, mlir::Value> varScope(symbolTable);
407 for (auto &expr : blockAST) {
408 // Specific handling for variable declarations, return statement, and
409 // print. These can only appear in block list and not in nested
410 // expressions.
411 if (auto *vardecl = dyn_cast<VarDeclExprAST>(Val: expr.get())) {
412 if (!mlirGen(vardecl&: *vardecl))
413 return mlir::failure();
414 continue;
415 }
416 if (auto *ret = dyn_cast<ReturnExprAST>(Val: expr.get()))
417 return mlirGen(ret&: *ret);
418 if (auto *print = dyn_cast<PrintExprAST>(Val: expr.get())) {
419 if (mlir::failed(Result: mlirGen(call&: *print)))
420 return mlir::success();
421 continue;
422 }
423
424 // Generic expression dispatch codegen.
425 if (!mlirGen(expr&: *expr))
426 return mlir::failure();
427 }
428 return mlir::success();
429 }
430
431 /// Build a tensor type from a list of shape dimensions.
432 mlir::Type getType(ArrayRef<int64_t> shape) {
433 // If the shape is empty, then this type is unranked.
434 if (shape.empty())
435 return mlir::UnrankedTensorType::get(builder.getF64Type());
436
437 // Otherwise, we use the given shape.
438 return mlir::RankedTensorType::get(shape, builder.getF64Type());
439 }
440
441 /// Build an MLIR type from a Toy AST variable type (forward to the generic
442 /// getType above).
443 mlir::Type getType(const VarType &type) { return getType(shape: type.shape); }
444};
445
446} // namespace
447
448namespace toy {
449
450// The public API for codegen.
451mlir::OwningOpRef<mlir::ModuleOp> mlirGen(mlir::MLIRContext &context,
452 ModuleAST &moduleAST) {
453 return MLIRGenImpl(context).mlirGen(moduleAST);
454}
455
456} // namespace toy
457

// Provided by KDAB
//
// Privacy Policy
// Improve your Profiling and Debugging skills
// Find out more
//
// source code of mlir/examples/toy/Ch3/mlir/MLIRGen.cpp