1 | //===-- CUFOps.cpp --------------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "flang/Optimizer/Dialect/CUF/CUFOps.h" |
14 | #include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h" |
15 | #include "flang/Optimizer/Dialect/CUF/CUFDialect.h" |
16 | #include "flang/Optimizer/Dialect/FIRAttr.h" |
17 | #include "flang/Optimizer/Dialect/FIRType.h" |
18 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
19 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
20 | #include "mlir/IR/Attributes.h" |
21 | #include "mlir/IR/BuiltinAttributes.h" |
22 | #include "mlir/IR/BuiltinOps.h" |
23 | #include "mlir/IR/Diagnostics.h" |
24 | #include "mlir/IR/Matchers.h" |
25 | #include "mlir/IR/OpDefinition.h" |
26 | #include "mlir/IR/PatternMatch.h" |
27 | #include "llvm/ADT/SmallVector.h" |
28 | |
29 | //===----------------------------------------------------------------------===// |
30 | // AllocOp |
31 | //===----------------------------------------------------------------------===// |
32 | |
33 | static mlir::Type wrapAllocaResultType(mlir::Type intype) { |
34 | if (mlir::isa<fir::ReferenceType>(intype)) |
35 | return {}; |
36 | return fir::ReferenceType::get(intype); |
37 | } |
38 | |
39 | void cuf::AllocOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
40 | mlir::Type inType, llvm::StringRef uniqName, |
41 | llvm::StringRef bindcName, |
42 | cuf::DataAttributeAttr cudaAttr, |
43 | mlir::ValueRange typeparams, mlir::ValueRange shape, |
44 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
45 | mlir::StringAttr nameAttr = |
46 | uniqName.empty() ? mlir::StringAttr{} : builder.getStringAttr(uniqName); |
47 | mlir::StringAttr bindcAttr = |
48 | bindcName.empty() ? mlir::StringAttr{} : builder.getStringAttr(bindcName); |
49 | build(builder, result, wrapAllocaResultType(inType), |
50 | mlir::TypeAttr::get(inType), nameAttr, bindcAttr, typeparams, shape, |
51 | cudaAttr); |
52 | result.addAttributes(attributes); |
53 | } |
54 | |
55 | template <typename Op> |
56 | static llvm::LogicalResult checkCudaAttr(Op op) { |
57 | if (op.getDataAttr() == cuf::DataAttribute::Device || |
58 | op.getDataAttr() == cuf::DataAttribute::Managed || |
59 | op.getDataAttr() == cuf::DataAttribute::Unified || |
60 | op.getDataAttr() == cuf::DataAttribute::Pinned || |
61 | op.getDataAttr() == cuf::DataAttribute::Shared) |
62 | return mlir::success(); |
63 | return op.emitOpError() |
64 | << "expect device, managed, pinned or unified cuda attribute" ; |
65 | } |
66 | |
67 | llvm::LogicalResult cuf::AllocOp::verify() { return checkCudaAttr(*this); } |
68 | |
69 | //===----------------------------------------------------------------------===// |
70 | // FreeOp |
71 | //===----------------------------------------------------------------------===// |
72 | |
73 | llvm::LogicalResult cuf::FreeOp::verify() { return checkCudaAttr(*this); } |
74 | |
75 | //===----------------------------------------------------------------------===// |
76 | // AllocateOp |
77 | //===----------------------------------------------------------------------===// |
78 | |
79 | template <typename OpTy> |
80 | static llvm::LogicalResult checkStreamType(OpTy op) { |
81 | if (!op.getStream()) |
82 | return mlir::success(); |
83 | if (auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getStream().getType())) |
84 | if (!refTy.getEleTy().isInteger(64)) |
85 | return op.emitOpError("stream is expected to be an i64 reference" ); |
86 | return mlir::success(); |
87 | } |
88 | |
89 | llvm::LogicalResult cuf::AllocateOp::verify() { |
90 | if (getPinned() && getStream()) |
91 | return emitOpError("pinned and stream cannot appears at the same time" ); |
92 | if (!mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(getBox().getType()))) |
93 | return emitOpError( |
94 | "expect box to be a reference to a class or box type value" ); |
95 | if (getSource() && |
96 | !mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(getSource().getType()))) |
97 | return emitOpError( |
98 | "expect source to be a reference to/or a class or box type value" ); |
99 | if (getErrmsg() && |
100 | !mlir::isa<fir::BoxType>(fir::unwrapRefType(getErrmsg().getType()))) |
101 | return emitOpError( |
102 | "expect errmsg to be a reference to/or a box type value" ); |
103 | if (getErrmsg() && !getHasStat()) |
104 | return emitOpError("expect stat attribute when errmsg is provided" ); |
105 | return checkStreamType(*this); |
106 | } |
107 | |
108 | //===----------------------------------------------------------------------===// |
109 | // DataTransferOp |
110 | //===----------------------------------------------------------------------===// |
111 | |
112 | llvm::LogicalResult cuf::DataTransferOp::verify() { |
113 | mlir::Type srcTy = getSrc().getType(); |
114 | mlir::Type dstTy = getDst().getType(); |
115 | if (getShape()) { |
116 | if (!fir::isa_ref_type(srcTy) && !fir::isa_ref_type(dstTy)) |
117 | return emitOpError() |
118 | << "shape can only be specified on data transfer with references" ; |
119 | } |
120 | if ((fir::isa_ref_type(srcTy) && fir::isa_ref_type(dstTy)) || |
121 | (fir::isa_box_type(srcTy) && fir::isa_box_type(dstTy)) || |
122 | (fir::isa_ref_type(srcTy) && fir::isa_box_type(dstTy)) || |
123 | (fir::isa_box_type(srcTy) && fir::isa_ref_type(dstTy))) |
124 | return mlir::success(); |
125 | if (fir::isa_trivial(srcTy) && |
126 | matchPattern(getSrc().getDefiningOp(), mlir::m_Constant())) |
127 | return mlir::success(); |
128 | |
129 | return emitOpError() |
130 | << "expect src and dst to be references or descriptors or src to " |
131 | "be a constant: " |
132 | << srcTy << " - " << dstTy; |
133 | } |
134 | |
135 | //===----------------------------------------------------------------------===// |
136 | // DeallocateOp |
137 | //===----------------------------------------------------------------------===// |
138 | |
139 | llvm::LogicalResult cuf::DeallocateOp::verify() { |
140 | if (!mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(getBox().getType()))) |
141 | return emitOpError( |
142 | "expect box to be a reference to class or box type value" ); |
143 | if (getErrmsg() && |
144 | !mlir::isa<fir::BoxType>(fir::unwrapRefType(getErrmsg().getType()))) |
145 | return emitOpError( |
146 | "expect errmsg to be a reference to/or a box type value" ); |
147 | if (getErrmsg() && !getHasStat()) |
148 | return emitOpError("expect stat attribute when errmsg is provided" ); |
149 | return mlir::success(); |
150 | } |
151 | |
152 | //===----------------------------------------------------------------------===// |
153 | // KernelLaunchOp |
154 | //===----------------------------------------------------------------------===// |
155 | |
156 | llvm::LogicalResult cuf::KernelLaunchOp::verify() { |
157 | return checkStreamType(*this); |
158 | } |
159 | |
160 | //===----------------------------------------------------------------------===// |
161 | // KernelOp |
162 | //===----------------------------------------------------------------------===// |
163 | |
164 | llvm::SmallVector<mlir::Region *> cuf::KernelOp::getLoopRegions() { |
165 | return {&getRegion()}; |
166 | } |
167 | |
168 | mlir::ParseResult parseCUFKernelValues( |
169 | mlir::OpAsmParser &parser, |
170 | llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &values, |
171 | llvm::SmallVectorImpl<mlir::Type> &types) { |
172 | if (mlir::succeeded(Result: parser.parseOptionalStar())) |
173 | return mlir::success(); |
174 | |
175 | if (mlir::succeeded(Result: parser.parseOptionalLParen())) { |
176 | if (mlir::failed(Result: parser.parseCommaSeparatedList( |
177 | delimiter: mlir::AsmParser::Delimiter::None, parseElementFn: [&]() { |
178 | if (parser.parseOperand(result&: values.emplace_back())) |
179 | return mlir::failure(); |
180 | return mlir::success(); |
181 | }))) |
182 | return mlir::failure(); |
183 | auto builder = parser.getBuilder(); |
184 | for (size_t i = 0; i < values.size(); i++) { |
185 | types.emplace_back(Args: builder.getI32Type()); |
186 | } |
187 | if (parser.parseRParen()) |
188 | return mlir::failure(); |
189 | } else { |
190 | if (parser.parseOperand(result&: values.emplace_back())) |
191 | return mlir::failure(); |
192 | auto builder = parser.getBuilder(); |
193 | types.emplace_back(Args: builder.getI32Type()); |
194 | return mlir::success(); |
195 | } |
196 | return mlir::success(); |
197 | } |
198 | |
199 | void printCUFKernelValues(mlir::OpAsmPrinter &p, mlir::Operation *op, |
200 | mlir::ValueRange values, mlir::TypeRange types) { |
201 | if (values.empty()) |
202 | p << "*" ; |
203 | |
204 | if (values.size() > 1) |
205 | p << "(" ; |
206 | llvm::interleaveComma(c: values, os&: p, each_fn: [&p](mlir::Value v) { p << v; }); |
207 | if (values.size() > 1) |
208 | p << ")" ; |
209 | } |
210 | |
211 | mlir::ParseResult parseCUFKernelLoopControl( |
212 | mlir::OpAsmParser &parser, mlir::Region ®ion, |
213 | llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &lowerbound, |
214 | llvm::SmallVectorImpl<mlir::Type> &lowerboundType, |
215 | llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &upperbound, |
216 | llvm::SmallVectorImpl<mlir::Type> &upperboundType, |
217 | llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &step, |
218 | llvm::SmallVectorImpl<mlir::Type> &stepType) { |
219 | |
220 | llvm::SmallVector<mlir::OpAsmParser::Argument> inductionVars; |
221 | if (parser.parseLParen() || |
222 | parser.parseArgumentList(result&: inductionVars, |
223 | delimiter: mlir::OpAsmParser::Delimiter::None, |
224 | /*allowType=*/true) || |
225 | parser.parseRParen() || parser.parseEqual() || parser.parseLParen() || |
226 | parser.parseOperandList(result&: lowerbound, requiredOperandCount: inductionVars.size(), |
227 | delimiter: mlir::OpAsmParser::Delimiter::None) || |
228 | parser.parseColonTypeList(result&: lowerboundType) || parser.parseRParen() || |
229 | parser.parseKeyword(keyword: "to" ) || parser.parseLParen() || |
230 | parser.parseOperandList(result&: upperbound, requiredOperandCount: inductionVars.size(), |
231 | delimiter: mlir::OpAsmParser::Delimiter::None) || |
232 | parser.parseColonTypeList(result&: upperboundType) || parser.parseRParen() || |
233 | parser.parseKeyword(keyword: "step" ) || parser.parseLParen() || |
234 | parser.parseOperandList(result&: step, requiredOperandCount: inductionVars.size(), |
235 | delimiter: mlir::OpAsmParser::Delimiter::None) || |
236 | parser.parseColonTypeList(result&: stepType) || parser.parseRParen()) |
237 | return mlir::failure(); |
238 | return parser.parseRegion(region, arguments: inductionVars); |
239 | } |
240 | |
241 | void printCUFKernelLoopControl( |
242 | mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Region ®ion, |
243 | mlir::ValueRange lowerbound, mlir::TypeRange lowerboundType, |
244 | mlir::ValueRange upperbound, mlir::TypeRange upperboundType, |
245 | mlir::ValueRange steps, mlir::TypeRange stepType) { |
246 | mlir::ValueRange regionArgs = region.front().getArguments(); |
247 | if (!regionArgs.empty()) { |
248 | p << "(" ; |
249 | llvm::interleaveComma( |
250 | c: regionArgs, os&: p, each_fn: [&p](mlir::Value v) { p << v << " : " << v.getType(); }); |
251 | p << ") = (" << lowerbound << " : " << lowerboundType << ") to (" |
252 | << upperbound << " : " << upperboundType << ") " |
253 | << " step (" << steps << " : " << stepType << ") " ; |
254 | } |
255 | p.printRegion(blocks&: region, /*printEntryBlockArgs=*/false); |
256 | } |
257 | |
258 | llvm::LogicalResult cuf::KernelOp::verify() { |
259 | if (getLowerbound().size() != getUpperbound().size() || |
260 | getLowerbound().size() != getStep().size()) |
261 | return emitOpError( |
262 | "expect same number of values in lowerbound, upperbound and step" ); |
263 | auto reduceAttrs = getReduceAttrs(); |
264 | std::size_t reduceAttrsSize = reduceAttrs ? reduceAttrs->size() : 0; |
265 | if (getReduceOperands().size() != reduceAttrsSize) |
266 | return emitOpError("expect same number of values in reduce operands and " |
267 | "reduce attributes" ); |
268 | if (reduceAttrs) { |
269 | for (const auto &attr : reduceAttrs.value()) { |
270 | if (!mlir::isa<fir::ReduceAttr>(attr)) |
271 | return emitOpError("expect reduce attributes to be ReduceAttr" ); |
272 | } |
273 | } |
274 | return checkStreamType(*this); |
275 | } |
276 | |
277 | //===----------------------------------------------------------------------===// |
278 | // RegisterKernelOp |
279 | //===----------------------------------------------------------------------===// |
280 | |
281 | mlir::StringAttr cuf::RegisterKernelOp::getKernelModuleName() { |
282 | return getName().getRootReference(); |
283 | } |
284 | |
285 | mlir::StringAttr cuf::RegisterKernelOp::getKernelName() { |
286 | return getName().getLeafReference(); |
287 | } |
288 | |
289 | mlir::LogicalResult cuf::RegisterKernelOp::verify() { |
290 | if (getKernelName() == getKernelModuleName()) |
291 | return emitOpError("expect a module and a kernel name" ); |
292 | |
293 | auto mod = getOperation()->getParentOfType<mlir::ModuleOp>(); |
294 | if (!mod) |
295 | return emitOpError("expect to be in a module" ); |
296 | |
297 | mlir::SymbolTable symTab(mod); |
298 | auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(getKernelModuleName()); |
299 | if (!gpuMod) { |
300 | // If already a gpu.binary then stop the check here. |
301 | if (symTab.lookup<mlir::gpu::BinaryOp>(getKernelModuleName())) |
302 | return mlir::success(); |
303 | return emitOpError("gpu module not found" ); |
304 | } |
305 | |
306 | mlir::SymbolTable gpuSymTab(gpuMod); |
307 | if (auto func = gpuSymTab.lookup<mlir::gpu::GPUFuncOp>(getKernelName())) { |
308 | if (!func.isKernel()) |
309 | return emitOpError("only kernel gpu.func can be registered" ); |
310 | return mlir::success(); |
311 | } else if (auto func = |
312 | gpuSymTab.lookup<mlir::LLVM::LLVMFuncOp>(getKernelName())) { |
313 | if (!func->getAttrOfType<mlir::UnitAttr>( |
314 | mlir::gpu::GPUDialect::getKernelFuncAttrName())) |
315 | return emitOpError("only gpu.kernel llvm.func can be registered" ); |
316 | return mlir::success(); |
317 | } |
318 | return emitOpError("device function not found" ); |
319 | } |
320 | |
321 | //===----------------------------------------------------------------------===// |
322 | // SharedMemoryOp |
323 | //===----------------------------------------------------------------------===// |
324 | |
325 | void cuf::SharedMemoryOp::build( |
326 | mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Type inType, |
327 | llvm::StringRef uniqName, llvm::StringRef bindcName, |
328 | mlir::ValueRange typeparams, mlir::ValueRange shape, |
329 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
330 | mlir::StringAttr nameAttr = |
331 | uniqName.empty() ? mlir::StringAttr{} : builder.getStringAttr(uniqName); |
332 | mlir::StringAttr bindcAttr = |
333 | bindcName.empty() ? mlir::StringAttr{} : builder.getStringAttr(bindcName); |
334 | build(builder, result, wrapAllocaResultType(inType), |
335 | mlir::TypeAttr::get(inType), nameAttr, bindcAttr, typeparams, shape, |
336 | /*offset=*/mlir::Value{}); |
337 | result.addAttributes(attributes); |
338 | } |
339 | |
340 | //===----------------------------------------------------------------------===// |
341 | // StreamCastOp |
342 | //===----------------------------------------------------------------------===// |
343 | |
344 | llvm::LogicalResult cuf::StreamCastOp::verify() { |
345 | return checkStreamType(*this); |
346 | } |
347 | |
348 | // Tablegen operators |
349 | |
350 | #define GET_OP_CLASSES |
351 | #include "flang/Optimizer/Dialect/CUF/CUFOps.cpp.inc" |
352 | |