1 | //===--- TosaProfileCompliance.cpp - Tosa Profile Compliance Validation ---===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h" |
10 | #include "llvm/ADT/StringExtras.h" |
11 | |
12 | using namespace mlir; |
13 | using namespace mlir::tosa; |
14 | |
// Builds the profile and extension compliance tables.
//
// The TypeInfo constants below pair an MLIR type ID with a bit width. No
// hand-written statement in this constructor uses them; the generated data
// included below is expanded in this scope and presumably refers to them by
// name, so they must keep exactly these names. The generated code presumably
// fills profileComplianceMap and extensionComplianceMap (the tables returned
// by the getProfileComplianceMap specializations in this file).
TosaProfileCompliance::TosaProfileCompliance() {
  // All integer types share a single TypeID, so the bit width is what tells
  // i1 (bool), i4, i8, i16, i32 and i48 apart.
  const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1};
  const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4};
  const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8};
  const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
  const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
  const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
  // Each floating-point type is identified by its own TypeID.
  const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
  const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
  const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
  const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
  const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};

  // The profile-based compliance content below is auto-generated by a script
  // in https://git.mlplatform.org/tosa/specification.git
#include "mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc"
  // End of auto-generated metadata
}
33 | |
34 | template <> |
35 | OperationProfileComplianceMap TosaProfileCompliance::getProfileComplianceMap() { |
36 | return profileComplianceMap; |
37 | } |
38 | |
39 | template <> |
40 | OperationExtensionComplianceMap |
41 | TosaProfileCompliance::getProfileComplianceMap() { |
42 | return extensionComplianceMap; |
43 | } |
44 | |
45 | // Base populating function |
46 | LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands, |
47 | Value output) { |
48 | for (auto operand : operands) |
49 | addValue(v: operand); |
50 | addValue(v: output); |
51 | return success(); |
52 | } |
53 | |
54 | template <> |
55 | LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) { |
56 | addValue(v: op.getInput1().front()); |
57 | addValue(v: op.getOutput()); |
58 | return success(); |
59 | } |
60 | |
61 | template <> |
62 | LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) { |
63 | addValue(v: op.getInput()); |
64 | addValue(v: op.getInputZp()); |
65 | addValue(v: op.getOutputZp()); |
66 | addType(t: op.getAccType()); |
67 | addValue(v: op.getOutput()); |
68 | return success(); |
69 | } |
70 | |
71 | template <typename T> |
72 | LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) { |
73 | addValue(v: op.getInput()); |
74 | addValue(v: op.getWeight()); |
75 | addValue(v: op.getBias()); |
76 | addValue(v: op.getInputZp()); |
77 | addValue(v: op.getWeightZp()); |
78 | addType(t: op.getAccType()); |
79 | addValue(v: op.getOutput()); |
80 | return success(); |
81 | } |
82 | |
83 | template <> |
84 | LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) { |
85 | return populateProfileInfoConv(op); |
86 | } |
87 | |
88 | template <> |
89 | LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) { |
90 | return populateProfileInfoConv(op); |
91 | } |
92 | |
93 | template <> |
94 | LogicalResult |
95 | ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) { |
96 | return populateProfileInfoConv(op); |
97 | } |
98 | |
99 | template <> |
100 | LogicalResult |
101 | ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) { |
102 | return populateProfileInfoConv(op); |
103 | } |
104 | |
105 | template <> |
106 | LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) { |
107 | addValue(v: op.getInput1()); |
108 | addValue(v: op.getPadConst()); |
109 | addValue(v: op.getOutput()); |
110 | return success(); |
111 | } |
112 | |
113 | template <typename T> |
114 | LogicalResult ProfileInfoDepot::populateProfileInfoDataLayout(T op) { |
115 | addValue(v: op.getInput1()); |
116 | addValue(v: op.getOutput()); |
117 | return success(); |
118 | } |
119 | |
120 | template <> |
121 | LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) { |
122 | return populateProfileInfoDataLayout(op); |
123 | } |
124 | |
125 | template <> |
126 | LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) { |
127 | return populateProfileInfoDataLayout(op); |
128 | } |
129 | |
130 | template <> |
131 | LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) { |
132 | return populateProfileInfoDataLayout(op); |
133 | } |
134 | |
135 | template <> |
136 | LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) { |
137 | return populateProfileInfoDataLayout(op); |
138 | } |
139 | |
140 | template <> |
141 | LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) { |
142 | addValue(v: op.getValues()); |
143 | addValue(v: op.getIndices()); |
144 | addValue(v: op.getOutput()); |
145 | return success(); |
146 | } |
147 | |
148 | template <> |
149 | LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) { |
150 | addValue(v: op.getValuesIn()); |
151 | addValue(v: op.getIndices()); |
152 | addValue(v: op.getInput()); |
153 | addValue(v: op.getValuesOut()); |
154 | return success(); |
155 | } |
156 | |
157 | template <> |
158 | LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) { |
159 | addValue(v: op.getInput1()); |
160 | addValue(v: op.getInput2()); |
161 | addValue(v: op.getOutput()); |
162 | return success(); |
163 | } |
164 | |
165 | template <> |
166 | LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) { |
167 | addValue(v: op.getInput()); |
168 | addValue(v: op.getOutput()); |
169 | return success(); |
170 | } |
171 | |
172 | template <> |
173 | LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) { |
174 | addValue(v: op.getInputReal()); |
175 | addValue(v: op.getInputImag()); |
176 | addValue(v: op.getOutputReal()); |
177 | addValue(v: op.getOutputImag()); |
178 | return success(); |
179 | } |
180 | |
181 | template <> |
182 | LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) { |
183 | addValue(v: op.getInputReal()); |
184 | addValue(v: op.getOutputReal()); |
185 | addValue(v: op.getOutputImag()); |
186 | return success(); |
187 | } |
188 | |
189 | template <> |
190 | LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) { |
191 | addValue(v: op.getInput2()); |
192 | addValue(v: op.getInput3()); |
193 | addValue(v: op.getOutput()); |
194 | return success(); |
195 | } |
196 | |
197 | template <> |
198 | LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) { |
199 | addValue(v: op.getInput()); |
200 | addValue(v: op.getInputZp()); |
201 | addValue(v: op.getOutputZp()); |
202 | addValue(v: op.getOutput()); |
203 | return success(); |
204 | } |
205 | |
206 | template <> |
207 | LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) { |
208 | addValue(v: op.getA()); |
209 | addValue(v: op.getB()); |
210 | addValue(v: op.getAZp()); |
211 | addValue(v: op.getBZp()); |
212 | addValue(v: op.getOutput()); |
213 | return success(); |
214 | } |
215 | |
216 | template <> |
217 | LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) { |
218 | addType(t: op.getType()); |
219 | return success(); |
220 | } |
221 | |
222 | template <> |
223 | LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) { |
224 | addValue(v: op.getInput1()); |
225 | return success(); |
226 | } |
227 | |
228 | template <> |
229 | LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::IfOp op) { |
230 | addValue(v: op.getCondition()); |
231 | return success(); |
232 | } |
233 | |
234 | template <> |
235 | LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::WhileOp op) { |
236 | Block *block = &op.getCondGraph().front(); |
237 | Operation *terminator = block->getTerminator(); |
238 | addValue(v: terminator->getOperands().front()); |
239 | return success(); |
240 | } |
241 | |
242 | LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) { |
243 | // This helper function only populates the info for the customised operands. |
244 | #define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \ |
245 | if (isa<tosa::tosaOp##Op>(op)) { \ |
246 | return populateProfileInfo(cast<tosa::tosaOp##Op>(op)); \ |
247 | } |
248 | |
249 | #define POPULATE_PROFILE_INFO_SKIP(tosaOp) \ |
250 | if (isa<tosa::tosaOp##Op>(op)) \ |
251 | return success(); |
252 | |
253 | // This helper function populates the info for all operands. |
254 | #define POPULATE_PROFILE_INFO_COMMON(tosaOp) \ |
255 | if (isa<tosa::tosaOp##Op>(op)) { \ |
256 | return populateProfileInfo(op->getOperands(), op->getResult(0)); \ |
257 | } |
258 | |
259 | // Skip irrelevant operands when they are independent and not tied to any |
260 | // specific profile/extension. |
261 | POPULATE_PROFILE_INFO_CUSTOM(AvgPool2d) |
262 | POPULATE_PROFILE_INFO_CUSTOM(TransposeConv2D) |
263 | POPULATE_PROFILE_INFO_CUSTOM(Conv2D) |
264 | POPULATE_PROFILE_INFO_CUSTOM(Conv3D) |
265 | POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D) |
266 | POPULATE_PROFILE_INFO_CUSTOM(Mul) |
267 | POPULATE_PROFILE_INFO_CUSTOM(FFT2d) |
268 | POPULATE_PROFILE_INFO_CUSTOM(RFFT2d) |
269 | POPULATE_PROFILE_INFO_CUSTOM(Concat) |
270 | POPULATE_PROFILE_INFO_CUSTOM(Pad) |
271 | POPULATE_PROFILE_INFO_CUSTOM(Reshape) |
272 | POPULATE_PROFILE_INFO_CUSTOM(Slice) |
273 | POPULATE_PROFILE_INFO_CUSTOM(Tile) |
274 | POPULATE_PROFILE_INFO_CUSTOM(Transpose) |
275 | POPULATE_PROFILE_INFO_CUSTOM(Gather) |
276 | POPULATE_PROFILE_INFO_CUSTOM(Scatter) |
277 | POPULATE_PROFILE_INFO_CUSTOM(Resize) |
278 | POPULATE_PROFILE_INFO_CUSTOM(Select) |
279 | POPULATE_PROFILE_INFO_CUSTOM(Rescale) |
280 | POPULATE_PROFILE_INFO_CUSTOM(MatMul) |
281 | POPULATE_PROFILE_INFO_CUSTOM(Variable) |
282 | POPULATE_PROFILE_INFO_CUSTOM(VariableWrite) |
283 | POPULATE_PROFILE_INFO_CUSTOM(If) |
284 | POPULATE_PROFILE_INFO_CUSTOM(While) |
285 | |
286 | // For the most of tosa operators, all operands are profile/extension related |
287 | // and hence are all considered in this profile-based compilance check. |
288 | POPULATE_PROFILE_INFO_COMMON(Cast) |
289 | POPULATE_PROFILE_INFO_COMMON(Const) |
290 | POPULATE_PROFILE_INFO_COMMON(ArgMax) |
291 | POPULATE_PROFILE_INFO_COMMON(Sub) |
292 | POPULATE_PROFILE_INFO_COMMON(Maximum) |
293 | POPULATE_PROFILE_INFO_COMMON(Minimum) |
294 | POPULATE_PROFILE_INFO_COMMON(MaxPool2d) |
295 | POPULATE_PROFILE_INFO_COMMON(Clamp) |
296 | POPULATE_PROFILE_INFO_COMMON(Erf) |
297 | POPULATE_PROFILE_INFO_COMMON(Sigmoid) |
298 | POPULATE_PROFILE_INFO_COMMON(Tanh) |
299 | POPULATE_PROFILE_INFO_COMMON(Add) |
300 | POPULATE_PROFILE_INFO_COMMON(ArithmeticRightShift) |
301 | POPULATE_PROFILE_INFO_COMMON(BitwiseAnd) |
302 | POPULATE_PROFILE_INFO_COMMON(BitwiseNot) |
303 | POPULATE_PROFILE_INFO_COMMON(BitwiseOr) |
304 | POPULATE_PROFILE_INFO_COMMON(BitwiseXor) |
305 | POPULATE_PROFILE_INFO_COMMON(LogicalLeftShift) |
306 | POPULATE_PROFILE_INFO_COMMON(LogicalRightShift) |
307 | POPULATE_PROFILE_INFO_COMMON(LogicalAnd) |
308 | POPULATE_PROFILE_INFO_COMMON(LogicalNot) |
309 | POPULATE_PROFILE_INFO_COMMON(LogicalOr) |
310 | POPULATE_PROFILE_INFO_COMMON(LogicalXor) |
311 | POPULATE_PROFILE_INFO_COMMON(IntDiv) |
312 | POPULATE_PROFILE_INFO_COMMON(Pow) |
313 | POPULATE_PROFILE_INFO_COMMON(Table) |
314 | POPULATE_PROFILE_INFO_COMMON(Abs) |
315 | POPULATE_PROFILE_INFO_COMMON(Ceil) |
316 | POPULATE_PROFILE_INFO_COMMON(Clz) |
317 | POPULATE_PROFILE_INFO_COMMON(Sin) |
318 | POPULATE_PROFILE_INFO_COMMON(Cos) |
319 | POPULATE_PROFILE_INFO_COMMON(Exp) |
320 | POPULATE_PROFILE_INFO_COMMON(Floor) |
321 | POPULATE_PROFILE_INFO_COMMON(Log) |
322 | POPULATE_PROFILE_INFO_COMMON(Negate) |
323 | POPULATE_PROFILE_INFO_COMMON(Reciprocal) |
324 | POPULATE_PROFILE_INFO_COMMON(Rsqrt) |
325 | POPULATE_PROFILE_INFO_COMMON(ReduceAll) |
326 | POPULATE_PROFILE_INFO_COMMON(ReduceAny) |
327 | POPULATE_PROFILE_INFO_COMMON(ReduceMax) |
328 | POPULATE_PROFILE_INFO_COMMON(ReduceMin) |
329 | POPULATE_PROFILE_INFO_COMMON(ReduceProduct) |
330 | POPULATE_PROFILE_INFO_COMMON(ReduceSum) |
331 | POPULATE_PROFILE_INFO_COMMON(Equal) |
332 | POPULATE_PROFILE_INFO_COMMON(GreaterEqual) |
333 | POPULATE_PROFILE_INFO_COMMON(Greater) |
334 | POPULATE_PROFILE_INFO_COMMON(Reverse) |
335 | POPULATE_PROFILE_INFO_COMMON(Identity) |
336 | POPULATE_PROFILE_INFO_COMMON(VariableRead) |
337 | |
338 | // Type Invariant Extension, a capability extension that is independent |
339 | // of the data type, meaning any compatible type can be used. No type |
340 | // constraint for those operations. |
341 | POPULATE_PROFILE_INFO_SKIP(ConstShape) |
342 | POPULATE_PROFILE_INFO_SKIP(Yield) |
343 | |
344 | return failure(); |
345 | } |
346 | |
347 | //===----------------------------------------------------------------------===// |
348 | // Tosa Profile And Extension Compliance Checker |
349 | //===----------------------------------------------------------------------===// |
350 | |
351 | template <typename T> |
352 | FailureOr<SmallVector<T>> |
353 | TosaProfileCompliance::getOperatorDefinition(Operation *op, |
354 | CheckCondition &condition) { |
355 | const std::string opName = op->getName().getStringRef().str(); |
356 | const auto complianceMap = getProfileComplianceMap<T>(); |
357 | const auto it = complianceMap.find(opName); |
358 | if (it == complianceMap.end()) |
359 | return {}; |
360 | |
361 | return findMatchedProfile<T>(op, it->second, condition); |
362 | } |
363 | |
364 | template <typename T> |
365 | LogicalResult TosaProfileCompliance::checkProfileOrExtension( |
366 | Operation *op, const tosa::TargetEnv &targetEnv, |
367 | const SmallVector<ArrayRef<T>> &specRequiredModeSet) { |
368 | |
369 | // None of profile requirement is set in the specification. |
370 | if (specRequiredModeSet.size() == 0) |
371 | return success(); |
372 | |
373 | CheckCondition condition = CheckCondition::invalid; |
374 | const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition); |
375 | if (failed(maybeOpRequiredMode)) { |
376 | // Operators such as control-flow and shape ops do not have an operand type |
377 | // restriction. When the profile compliance information of operation is not |
378 | // found, confirm if the target have enabled the profile required from the |
379 | // specification. |
380 | int mode_count = 0; |
381 | for (const auto &cands : specRequiredModeSet) { |
382 | if (targetEnv.allowsAnyOf(cands)) |
383 | return success(); |
384 | mode_count += cands.size(); |
385 | } |
386 | |
387 | op->emitOpError() << "illegal: requires" |
388 | << (mode_count > 1 ? " any of " : " " ) << "[" |
389 | << llvm::join(stringifyProfile<T>(specRequiredModeSet), |
390 | ", " ) |
391 | << "] but not enabled in target\n" ; |
392 | |
393 | return failure(); |
394 | } |
395 | |
396 | // Find the required profiles or extensions according to the operand type |
397 | // combination. |
398 | const auto opRequiredMode = maybeOpRequiredMode.value(); |
399 | if (opRequiredMode.size() == 0) { |
400 | // No matched restriction found. |
401 | return success(); |
402 | } |
403 | |
404 | if (condition == CheckCondition::allOf && |
405 | !targetEnv.allowsAllOf(opRequiredMode)) { |
406 | op->emitOpError() << "illegal: requires" |
407 | << (opRequiredMode.size() > 1 ? " all of " : " " ) << "[" |
408 | << llvm::join(stringifyProfile<T>(opRequiredMode), ", " ) |
409 | << "] but not enabled in target\n" ; |
410 | return failure(); |
411 | } |
412 | |
413 | if (condition == CheckCondition::anyOf && |
414 | !targetEnv.allowsAnyOf(opRequiredMode)) { |
415 | op->emitOpError() << "illegal: requires" |
416 | << (opRequiredMode.size() > 1 ? " any of " : " " ) << "[" |
417 | << llvm::join(stringifyProfile<T>(opRequiredMode), ", " ) |
418 | << "] but not enabled in target\n" ; |
419 | return failure(); |
420 | } |
421 | |
422 | // Each extension can contain a list of profiles that it works with, usually |
423 | // have the same data type. |
424 | if constexpr (std::is_same_v<T, Extension>) { |
425 | for (const auto &mode : opRequiredMode) { |
426 | SmallVector<Profile> coProfs = getCooperativeProfiles(mode); |
427 | if (!targetEnv.allowsAnyOf(coProfs)) { |
428 | op->emitOpError() << "illegal: requires [" |
429 | << llvm::join(stringifyProfile<Profile>(coProfs), |
430 | ", " ) |
431 | << "] to work with but not enabled in target\n" ; |
432 | return failure(); |
433 | } |
434 | } |
435 | } |
436 | |
437 | // Ensure the profile inference match the profile knowledge of the |
438 | // specification. |
439 | for (const auto &cands : specRequiredModeSet) { |
440 | for (const auto &mode : opRequiredMode) { |
441 | if (!llvm::is_contained(cands, mode)) { |
442 | op->emitOpError() << "illegal: requires [" |
443 | << llvm::join(stringifyProfile<T>(opRequiredMode), |
444 | ", " ) |
445 | << "] but not included in the profile compliance [" |
446 | << llvm::join( |
447 | stringifyProfile<T>(specRequiredModeSet), ", " ) |
448 | << "]\n" ; |
449 | return failure(); |
450 | } |
451 | } |
452 | } |
453 | |
454 | return success(); |
455 | } |
456 | |
457 | LogicalResult |
458 | TosaProfileCompliance::checkProfile(Operation *op, |
459 | const tosa::TargetEnv &targetEnv) { |
460 | if (auto interface = dyn_cast<tosa::QueryProfileInterface>(op)) |
461 | return checkProfileOrExtension<Profile>(op, targetEnv, |
462 | interface.getProfiles()); |
463 | |
464 | return success(); |
465 | } |
466 | |
467 | LogicalResult |
468 | TosaProfileCompliance::checkExtension(Operation *op, |
469 | const tosa::TargetEnv &targetEnv) { |
470 | if (auto interface = dyn_cast<tosa::QueryExtensionInterface>(op)) |
471 | return checkProfileOrExtension<Extension>(op, targetEnv, |
472 | interface.getExtensions()); |
473 | |
474 | return success(); |
475 | } |
476 | |
477 | LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) { |
478 | CheckCondition condition = CheckCondition::invalid; |
479 | const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition); |
480 | const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition); |
481 | |
482 | if (!failed(maybeProfDef) && !failed(maybeExtDef) && |
483 | !maybeProfDef.value().size() && !maybeExtDef.value().size()) { |
484 | std::string message; |
485 | llvm::raw_string_ostream os(message); |
486 | os << "illegal: operation operand/result data types did not align with any " |
487 | "profile or extension, got (" ; |
488 | |
489 | ProfileInfoDepot depot(op); |
490 | SmallVector<TypeInfo> current = depot.getInfo(); |
491 | for (const auto &typeInfo : llvm::drop_end(RangeOrContainer&: current)) |
492 | os << stringifyTypeInfo(typeInfo) << "," ; |
493 | os << stringifyTypeInfo(typeInfo: current.back()) << ")" ; |
494 | |
495 | // avoid polluting the error message output by outputting only |
496 | // the best match |
497 | const std::string opName = op->getName().getStringRef().str(); |
498 | int maxMatches = -1; |
499 | SmallVector<TypeInfo> bestTypeInfo; |
500 | const auto searchBestMatch = [&](auto map) { |
501 | for (const auto &complianceInfos : map[opName]) { |
502 | for (const auto &typeInfos : complianceInfos.operandTypeInfoSet) { |
503 | const int matches = llvm::count_if( |
504 | llvm::zip_equal(current, typeInfos), [&](const auto zipType) { |
505 | return isSameTypeInfo(a: std::get<0>(zipType), |
506 | b: std::get<1>(zipType)); |
507 | }); |
508 | if (matches > maxMatches) { |
509 | maxMatches = matches; |
510 | bestTypeInfo = typeInfos; |
511 | } |
512 | } |
513 | } |
514 | }; |
515 | searchBestMatch(getProfileComplianceMap<Profile>()); |
516 | searchBestMatch(getProfileComplianceMap<Extension>()); |
517 | |
518 | os << ", did you mean (" ; |
519 | for (const auto &typeInfo : llvm::drop_end(RangeOrContainer&: bestTypeInfo)) |
520 | os << stringifyTypeInfo(typeInfo) << "," ; |
521 | os << stringifyTypeInfo(typeInfo: bestTypeInfo.back()) << ")? " ; |
522 | os << "Otherwise, please refer to the 'supported data types' for '" |
523 | << opName << "' in the specification." ; |
524 | op->emitOpError(message); |
525 | return failure(); |
526 | } |
527 | |
528 | return success(); |
529 | } |
530 | |
531 | // Find the profiles or extensions requirement according to the signature of |
532 | // type of the operand list. |
533 | template <typename T> |
534 | SmallVector<T> TosaProfileCompliance::findMatchedProfile( |
535 | Operation *op, SmallVector<OpComplianceInfo<T>> compInfo, |
536 | CheckCondition &condition) { |
537 | assert(compInfo.size() != 0 && |
538 | "profile-based compliance information is empty" ); |
539 | |
540 | // Populate the type of profile/extension relevant operands. |
541 | ProfileInfoDepot depot(op); |
542 | SmallVector<TypeInfo> present = depot.getInfo(); |
543 | if (present.size() == 0) |
544 | return {}; |
545 | |
546 | for (size_t i = 0; i < compInfo.size(); i++) { |
547 | SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet; |
548 | for (SmallVector<TypeInfo> expected : sets) { |
549 | assert(present.size() == expected.size() && |
550 | "the entries for profile-based compliance do not match between " |
551 | "the generated metadata and the type definition retrieved from " |
552 | " the operation" ); |
553 | |
554 | bool is_found = true; |
555 | // Compare the type signature between the given operation and the |
556 | // compliance metadata. |
557 | for (size_t j = 0; j < expected.size(); j++) { |
558 | if (!isSameTypeInfo(a: present[j], b: expected[j])) { |
559 | // Verify the next mode set from the list. |
560 | is_found = false; |
561 | break; |
562 | } |
563 | } |
564 | |
565 | if (is_found == true) { |
566 | condition = compInfo[i].condition; |
567 | return compInfo[i].mode; |
568 | } |
569 | } |
570 | } |
571 | |
572 | return {}; |
573 | } |
574 | |
575 | // Debug utilites. |
576 | template <typename T> |
577 | SmallVector<StringRef> |
578 | TosaProfileCompliance::stringifyProfile(ArrayRef<T> profiles) { |
579 | SmallVector<StringRef> debugStrings; |
580 | for (const auto &profile : profiles) { |
581 | if constexpr (std::is_same_v<T, Profile>) |
582 | debugStrings.push_back(tosa::stringifyProfile(profile)); |
583 | else |
584 | debugStrings.push_back(tosa::stringifyExtension(profile)); |
585 | } |
586 | return debugStrings; |
587 | } |
588 | |
589 | template <typename T> |
590 | SmallVector<StringRef> TosaProfileCompliance::stringifyProfile( |
591 | const SmallVector<ArrayRef<T>> &profileSet) { |
592 | SmallVector<StringRef> debugStrings; |
593 | |
594 | for (const auto &profiles : profileSet) { |
595 | auto tempStrings = stringifyProfile<T>(profiles); |
596 | llvm::append_range(debugStrings, tempStrings); |
597 | } |
598 | |
599 | return debugStrings; |
600 | } |
601 | |
602 | llvm::SmallString<7> |
603 | TosaProfileCompliance::stringifyTypeInfo(const TypeInfo &typeInfo) { |
604 | if (typeInfo.typeID == mlir::IntegerType::getTypeID()) { |
605 | return {"i" + llvm::utostr(X: typeInfo.bitWidth)}; |
606 | } else if (typeInfo.typeID == mlir::Float16Type::getTypeID()) { |
607 | return {"f16" }; |
608 | } else if (typeInfo.typeID == mlir::Float32Type::getTypeID()) { |
609 | return {"f32" }; |
610 | } else if (typeInfo.typeID == mlir::BFloat16Type::getTypeID()) { |
611 | return {"bf16" }; |
612 | } else if (typeInfo.typeID == mlir::Float8E4M3FNType::getTypeID()) { |
613 | return {"fp8e4m3" }; |
614 | } else if (typeInfo.typeID == mlir::Float8E5M2Type::getTypeID()) { |
615 | return {"fp8e5m2" }; |
616 | } |
617 | llvm_unreachable("unknown type" ); |
618 | } |
619 | |