//===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "mlir/Dialect/Shape/IR/Shape.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::shape;

#include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"

namespace {
#include "ShapeCanonicalization.inc"
} // namespace

RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
  return RankedTensorType::get({rank}, IndexType::get(ctx));
}

bool shape::isExtentTensorType(Type type) {
  auto ranked = llvm::dyn_cast<RankedTensorType>(type);
  return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
}
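// For example (illustrative): getExtentTensorType(ctx, 3) yields
// tensor<3xindex>, and isExtentTensorType returns true for tensor<3xindex>
// and tensor<?xindex>, but false for tensor<3xi64> or tensor<2x3xindex>.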

LogicalResult shape::getShapeVec(Value input,
                                 SmallVectorImpl<int64_t> &shapeValues) {
  if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
    auto type = llvm::cast<ShapedType>(inputOp.getArg().getType());
    if (!type.hasRank())
      return failure();
    llvm::append_range(shapeValues, type.getShape());
    return success();
  }
  DenseIntElementsAttr attr;
  if (matchPattern(input, m_Constant(&attr))) {
    llvm::append_range(shapeValues, attr.getValues<int64_t>());
    return success();
  }
  return failure();
}
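// For example (illustrative): given %s = shape.shape_of %t with
// %t : tensor<2x3xf32>, getShapeVec(%s, out) fills out with {2, 3}; it fails
// for unranked or otherwise non-constant inputs.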

static bool isErrorPropagationPossible(TypeRange operandTypes) {
  return llvm::any_of(operandTypes,
                      llvm::IsaPred<SizeType, ShapeType, ValueShapeType>);
}

static LogicalResult verifySizeOrIndexOp(Operation *op) {
  assert(op != nullptr && op->getNumResults() == 1);
  Type resultTy = op->getResultTypes().front();
  if (isErrorPropagationPossible(op->getOperandTypes())) {
    if (!llvm::isa<SizeType>(resultTy))
      return op->emitOpError()
             << "if at least one of the operands can hold error values then "
                "the result must be of type `size` to propagate them";
  }
  return success();
}

static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
  assert(op != nullptr && op->getNumResults() == 1);
  Type resultTy = op->getResultTypes().front();
  if (isErrorPropagationPossible(op->getOperandTypes())) {
    if (!llvm::isa<ShapeType>(resultTy))
      return op->emitOpError()
             << "if at least one of the operands can hold error values then "
                "the result must be of type `shape` to propagate them";
  }
  return success();
}

template <typename... Ty>
static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
  return typeRange.size() == 1 && llvm::isa<Ty...>(typeRange.front());
}

template <typename... Ty, typename... Ranges>
static bool eachHasOnlyOneOfTypes(TypeRange l, Ranges... rs) {
  return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
}

//===----------------------------------------------------------------------===//
// InlinerInterface
//===----------------------------------------------------------------------===//

namespace {
/// This class defines the interface for inlining shape dialect ops.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       IRMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       IRMapping &) const final {
    return true;
  }
};
} // namespace

void ShapeDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
      >();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect
  // is still evolving, this makes it simple to start with unregistered ops
  // and try different variants before actually defining an op.
  allowUnknownOperations();
  declarePromisedInterfaces<bufferization::BufferizableOpInterface, AssumingOp,
                            AssumingYieldOp>();
}

Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
    return builder.create<ub::PoisonOp>(loc, type, poison);

  if (llvm::isa<ShapeType>(type) || isExtentTensorType(type))
    return builder.create<ConstShapeOp>(
        loc, type, llvm::cast<DenseIntElementsAttr>(value));
  if (llvm::isa<SizeType>(type))
    return builder.create<ConstSizeOp>(loc, type,
                                       llvm::cast<IntegerAttr>(value));
  if (llvm::isa<WitnessType>(type))
    return builder.create<ConstWitnessOp>(loc, type,
                                          llvm::cast<BoolAttr>(value));

  return arith::ConstantOp::materialize(builder, value, type, loc);
}

LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
                                                     NamedAttribute attribute) {
  // Verify shape.lib attribute.
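  // For example (illustrative IR), a valid use attaches the attribute to a
  // symbol-table op and references a `shape.function_library` defined in it:
  //
  //   module attributes { shape.lib = @shplib } {
  //     shape.function_library @shplib { ... } mapping { ... }
  //   }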
  if (attribute.getName() == "shape.lib") {
    if (!op->hasTrait<OpTrait::SymbolTable>())
      return op->emitError(
          "shape.lib attribute may only be on op implementing SymbolTable");

    if (auto symbolRef = llvm::dyn_cast<SymbolRefAttr>(attribute.getValue())) {
      auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
      if (!symbol)
        return op->emitError("shape function library ")
               << symbolRef << " not found";
      return isa<shape::FunctionLibraryOp>(symbol)
                 ? success()
                 : op->emitError()
                       << symbolRef << " required to be shape function library";
    }

    if (auto arr = llvm::dyn_cast<ArrayAttr>(attribute.getValue())) {
      // Verify all entries are function libraries and mappings in libraries
      // refer to unique ops.
      DenseSet<StringAttr> key;
      for (auto it : arr) {
        if (!llvm::isa<SymbolRefAttr>(it))
          return op->emitError(
              "only SymbolRefAttr allowed in shape.lib attribute array");

        auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
            SymbolTable::lookupSymbolIn(op, llvm::cast<SymbolRefAttr>(it)));
        if (!shapeFnLib)
          return op->emitError()
                 << it << " does not refer to FunctionLibraryOp";
        for (auto mapping : shapeFnLib.getMapping()) {
          if (!key.insert(mapping.getName()).second) {
            return op->emitError("only one op to shape mapping allowed, found "
                                 "multiple for `")
                   << mapping.getName() << "`";
          }
        }
      }
      return success();
    }

    return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
                         "allowed as shape.lib attribute");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AnyOp
//===----------------------------------------------------------------------===//

// TODO: Canonicalization should be implemented for shapes that can be
// determined through mixtures of the known dimensions of the inputs.
OpFoldResult AnyOp::fold(FoldAdaptor adaptor) {
  // Only the last operand is checked because AnyOp is commutative.
  if (adaptor.getInputs().back())
    return adaptor.getInputs().back();

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AssumingOp
//===----------------------------------------------------------------------===//

ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand cond;
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void AssumingOp::print(OpAsmPrinter &p) {
  bool yieldsResults = !getResults().empty();

  p << " " << getWitness();
  if (yieldsResults)
    p << " -> (" << getResultTypes() << ")";
  p << ' ';
  p.printRegion(getDoRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/yieldsResults);
  p.printOptionalAttrDict((*this)->getAttrs());
}

namespace {
// Removes AssumingOp with a passing witness and inlines the region.
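// Example (illustrative IR):
//
//   %w = shape.const_witness true
//   %0 = shape.assuming %w -> (!shape.shape) {
//     ...
//     shape.assuming_yield %s : !shape.shape
//   }
//
// becomes the region body inlined into the parent block, with %0 replaced by
// %s.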
struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
    if (!witness || !witness.getPassingAttr())
      return failure();

    AssumingOp::inlineRegionIntoParent(op, rewriter);
    return success();
  }
};

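// Removes results of an AssumingOp that have no uses by rebuilding the op and
// its yield with only the used values. For example (illustrative), if only
// the first of two results is used, the op is rebuilt to yield just that one.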
struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    Block *body = op.getBody();
    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());

    // Find used values.
    SmallVector<Value, 4> newYieldOperands;
    for (auto [opResult, yieldOperand] :
         llvm::zip(op.getResults(), yieldOp.getOperands())) {
      if (!opResult.getUses().empty()) {
        newYieldOperands.push_back(yieldOperand);
      }
    }

    // Rewrite only if redundant results exist.
    if (newYieldOperands.size() == yieldOp->getNumOperands())
      return failure();

    // Replace yield op in the old assuming op's body and move the entire
    // region to the new assuming op.
    rewriter.setInsertionPointToEnd(body);
    auto newYieldOp =
        rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
    rewriter.setInsertionPoint(op);
    auto newOp = rewriter.create<AssumingOp>(
        op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
    newOp.getDoRegion().takeBody(op.getDoRegion());

    // Use the new results to replace the previously used ones.
    SmallVector<Value, 4> replacementValues;
    auto src = newOp.getResults().begin();
    for (auto it : op.getResults()) {
      if (it.getUses().empty())
        replacementValues.push_back(nullptr);
      else
        replacementValues.push_back(*src++);
    }
    rewriter.replaceOp(op, replacementValues);
    return success();
  }
};
} // namespace

void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
}

// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
void AssumingOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  // AssumingOp has unconditional control flow into the region and back to the
  // parent, so return the correct RegionSuccessor based purely on whether the
  // branch point is the parent op or the region itself.
  if (!point.isParent()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&getDoRegion()));
}

void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}

void AssumingOp::build(
    OpBuilder &builder, OperationState &result, Value witness,
    function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {
  OpBuilder::InsertionGuard g(builder);

  result.addOperands(witness);
  Region *bodyRegion = result.addRegion();
  builder.createBlock(bodyRegion);

  // Build body.
  SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
  builder.create<AssumingYieldOp>(result.location, yieldValues);

  SmallVector<Type, 2> assumingTypes;
  for (Value v : yieldValues)
    assumingTypes.push_back(v.getType());
  result.addTypes(assumingTypes);
}

//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//

LogicalResult mlir::shape::AddOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    AddOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
      llvm::isa<SizeType>(adaptor.getRhs().getType()))
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) {
  // add(x, 0) -> x
  if (matchPattern(getRhs(), m_Zero()))
    return getLhs();

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) + b; });
}

LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// AssumingAllOp
//===----------------------------------------------------------------------===//

namespace {

// Merge multiple `shape.assuming_all` operations together.
//
//   %0 = shape.assuming_all %w0, %w1
//   %1 = shape.assuming_all %w2, %0
//
// to:
//
//   %1 = shape.assuming_all %w2, %w0, %w1
struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> operands;

    for (Value operand : op.getInputs()) {
      if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
        operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
      else
        operands.push_back(operand);
    }

    // We didn't find any other `assuming_all` ops to merge with.
    if (operands.size() == op.getNumOperands())
      return failure();

    // Replace with a new `assuming_all` operation with merged constraints.
    rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
    return success();
  }
};

// Eliminate `cstr_broadcastable` operands from `assuming_all` operation that
// are subsumed by others.
//
//   %0 = shape.cstr_broadcastable %shape0, %shape1
//   %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
//
//   %2 = shape.cstr_broadcastable %shape3, %shape4
//   %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
//
//   %4 = shape.assuming_all %0, %1, %2, %3
//
// to:
//
//   %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
//   %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
//   %2 = shape.assuming_all %0, %1
//
// In this example if shapes [0, 1, 2] are broadcastable, then it means that
// shapes [0, 1] are broadcastable too, and can be removed from the list of
// constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
// matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    // Collect all `CstrBroadcastableOp` operands first.
    SetVector<CstrBroadcastableOp> operands;
    for (Value operand : op.getInputs()) {
      // TODO: Apply this optimization if some of the witnesses are not
      // produced by the `cstr_broadcastable`.
      auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
      if (!broadcastable)
        return failure();

      operands.insert(broadcastable);
    }

    // Skip trivial `assuming_all` operations.
    if (operands.size() <= 1)
      return failure();

    // Collect shapes checked by `cstr_broadcastable` operands.
    SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes;
    for (auto cstr : operands) {
      DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end());
      shapes.emplace_back(cstr, std::move(shapesSet));
    }

    // Sort by the number of shape operands (larger to smaller).
    llvm::sort(shapes, [](auto a, auto b) {
      return a.first.getNumOperands() > b.first.getNumOperands();
    });

    // We start from the `cstr_broadcastable` operations with the largest
    // number of shape operands, and remove redundant `cstr_broadcastable`
    // operations. We do this until we find a set of `cstr_broadcastable`
    // operations with non-overlapping constraints.
    SmallVector<CstrBroadcastableOp> markedForErase;

    for (unsigned i = 0; i < shapes.size(); ++i) {
      auto isSubset = [&](auto pair) {
        return llvm::set_is_subset(pair.second, shapes[i].second);
      };

      // Collect redundant `cstr_broadcastable` operations to be erased.
      auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
      for (auto *it0 = it; it0 < shapes.end(); ++it0)
        markedForErase.push_back(it0->first);
      shapes.erase(it, shapes.end());
    }

    // We didn't find any operands that could be removed.
    if (markedForErase.empty())
      return failure();

    // Collect non-overlapping `cstr_broadcastable` constraints.
    SmallVector<Value> uniqueConstraints;
    for (auto &shape : shapes)
      uniqueConstraints.push_back(shape.first.getResult());

    // Replace with a new `assuming_all` operation ...
    rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints);

    // ... and maybe erase `cstr_broadcastable` ops without uses.
    for (auto &op : markedForErase)
      if (op->use_empty())
        rewriter.eraseOp(op);

    return success();
  }
};

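// Replaces an `assuming_all` op whose witnesses all stem from `cstr_eq` ops
// with overlapping shape operands by a single `cstr_eq` over all shapes.
// Example (illustrative IR):
//
//   %0 = shape.cstr_eq %a, %b
//   %1 = shape.cstr_eq %b, %c
//   %2 = shape.assuming_all %0, %1
//
// to:
//
//   %2 = shape.cstr_eq %a, %b, %b, %c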
struct AssumingAllToCstrEqCanonicalization
    : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 8> shapes;
    for (Value w : op.getInputs()) {
      auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
      if (!cstrEqOp)
        return failure();
      bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
        return llvm::is_contained(shapes, s);
      });
      if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
        return failure();
      shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
    }
    rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
    return success();
  }
};

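// Removes duplicate operands of a variadic op. For example (illustrative IR),
// `shape.cstr_broadcastable %a, %b, %a` becomes
// `shape.cstr_broadcastable %a, %b`.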
template <typename OpTy>
struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Find unique operands.
    SetVector<Value> unique(op.operand_begin(), op.operand_end());

    // Reduce op to equivalent with unique operands.
    if (unique.size() < op.getNumOperands()) {
      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
                                        unique.takeVector(), op->getAttrs());
      return success();
    }

    return failure();
  }
};
} // namespace

void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns
      .add<MergeAssumingAllOps, AssumingAllOneOp,
           AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
           RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
}

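// Folds `assuming_all` over its statically known witnesses. For example
// (illustrative): with inputs (%w, true) the known-true operand is dropped
// and folding is deferred to the remaining dynamic witness; any statically
// false input folds the op to false; if all inputs are true, it folds to a
// passing witness.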
OpFoldResult AssumingAllOp::fold(FoldAdaptor adaptor) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = adaptor.getInputs().size() - 1; idx >= 0; idx--) {
    Attribute a = adaptor.getInputs()[idx];
    // Cannot fold if any inputs are not constant.
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false.
    if (!llvm::cast<BoolAttr>(a).getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(getContext(), true);
}

LogicalResult AssumingAllOp::verify() {
  // Ensure that AssumingAllOp contains at least one operand
  if (getNumOperands() == 0)
    return emitOpError("no operands specified");

  return success();
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
  if (getShapes().size() == 1) {
    // Otherwise, we need a cast which would be a canonicalization, not folding.
    if (getShapes().front().getType() != getType())
      return nullptr;
    return getShapes().front();
  }

  if (!adaptor.getShapes().front())
    return nullptr;

  SmallVector<int64_t, 6> resultShape(
      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
          .getValues<int64_t>());

  for (auto next : adaptor.getShapes().drop_front()) {
    if (!next)
      return nullptr;
    auto nextShape = llvm::to_vector<6>(
        llvm::cast<DenseIntElementsAttr>(next).getValues<int64_t>());

    SmallVector<int64_t, 6> tmpShape;
    // If the shapes are not compatible, we can't fold it.
    // TODO: Fold to an "error".
    if (!OpTrait::util::getBroadcastedShape(resultShape, nextShape, tmpShape))
      return nullptr;

    resultShape.clear();
    std::copy(tmpShape.begin(), tmpShape.end(),
              std::back_inserter(resultShape));
  }

  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}

LogicalResult BroadcastOp::verify() {
  return verifyShapeOrExtentTensorOp(*this);
}

namespace {
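// Removes operands that are statically known to be empty shapes, since a
// rank-0 shape does not affect broadcasting. For example (illustrative IR),
// `shape.broadcast %a, %e : ...` with `%e = shape.const_shape []` reduces to
// a broadcast of just %a.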
template <typename OpTy>
struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    auto isPotentiallyNonEmptyShape = [](Value shape) {
      if (auto extentTensorTy =
              llvm::dyn_cast<RankedTensorType>(shape.getType())) {
        if (extentTensorTy.getDimSize(0) == 0)
          return false;
      }
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        if (constShape.getShape().empty())
          return false;
      }
      return true;
    };
    auto newOperands = llvm::filter_to_vector<8>(op->getOperands(),
                                                 isPotentiallyNonEmptyShape);

    // Replace the op with an empty shape constant if all operands are reduced
    // to be empty.
    if (newOperands.empty()) {
      rewriter.replaceOpWithNewOp<ConstShapeOp>(
          op, op->getResultTypes().front(), rewriter.getIndexTensorAttr({}));
      return success();
    }

    // Reduce op to equivalent without empty shape operands.
    if (newOperands.size() < op.getNumOperands()) {
      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
                                        op->getAttrs());
      return success();
    }

    return failure();
  }
};

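// Forwards the operand of a single-operand `shape.broadcast`, inserting a
// cast when the operand and result types differ.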
struct BroadcastForwardSingleOperandPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 1)
      return failure();
    Value replacement = op.getShapes().front();

    // Insert cast if needed.
    if (replacement.getType() != op.getType()) {
      auto loc = op.getLoc();
      if (llvm::isa<ShapeType>(op.getType())) {
        replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
      } else {
        assert(!llvm::isa<ShapeType>(op.getType()) &&
               !llvm::isa<ShapeType>(replacement.getType()) &&
               "expect extent tensor cast");
        replacement =
            rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
      }
    }

    rewriter.replaceOp(op, replacement);
    return success();
  }
};

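// Folds all constant operands of a `shape.broadcast` into one `const_shape`
// operand. Example (illustrative IR):
//
//   %0 = shape.const_shape [2, 1] : tensor<2xindex>
//   %1 = shape.const_shape [1, 3] : tensor<2xindex>
//   %2 = shape.broadcast %arg, %0, %1 : ...
//
// to:
//
//   %c = shape.const_shape [2, 3] : tensor<2xindex>
//   %2 = shape.broadcast %arg, %c : ...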
struct BroadcastFoldConstantOperandsPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t, 8> foldedConstantShape;
    SmallVector<Value, 8> newShapeOperands;
    for (Value shape : op.getShapes()) {
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        SmallVector<int64_t, 8> newFoldedConstantShape;
        if (OpTrait::util::getBroadcastedShape(
                foldedConstantShape,
                llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
                newFoldedConstantShape)) {
          foldedConstantShape = newFoldedConstantShape;
          continue;
        }
      }
      newShapeOperands.push_back(shape);
    }

    // Need at least two constant operands to fold anything.
    if (op.getNumOperands() - newShapeOperands.size() < 2)
      return failure();

    auto foldedConstantOperandsTy = RankedTensorType::get(
        {static_cast<int64_t>(foldedConstantShape.size())},
        rewriter.getIndexType());
    newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
        op.getLoc(), foldedConstantOperandsTy,
        rewriter.getIndexTensorAttr(foldedConstantShape)));
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
                                             newShapeOperands);
    return success();
  }
};

template <typename OpTy>
struct CanonicalizeCastExtentTensorOperandsPattern
    : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Canonicalize operands.
    bool anyChange = false;
    auto canonicalizeOperand = [&](Value operand) -> Value {
      if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
        // Only eliminate the cast if it holds no shape information.
        bool isInformationLosingCast =
            llvm::cast<RankedTensorType>(castOp.getType()).isDynamicDim(0);
        if (isInformationLosingCast) {
          anyChange = true;
          return castOp.getSource();
        }
      }
      return operand;
    };
    auto newOperands = llvm::to_vector<8>(
        llvm::map_range(op.getOperands(), canonicalizeOperand));

    // Rewrite op if any change required.
    if (!anyChange)
      return failure();
    rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
    return success();
  }
};

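// Concretizes a dynamic extent tensor result type when the maximal rank of
// the operands is statically known. For example (illustrative IR), a
// broadcast of tensor<2xindex> and tensor<3xindex> operands with result type
// tensor<?xindex> is rebuilt with result type tensor<3xindex>, followed by a
// tensor.cast back to tensor<?xindex>.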
struct BroadcastConcretizeResultTypePattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // Only concretize dynamic extent tensor result types.
    auto resultTy = llvm::dyn_cast<RankedTensorType>(op.getType());
    if (!resultTy || !resultTy.isDynamicDim(0))
      return failure();

    // Infer resulting shape rank if possible.
    int64_t maxRank = 0;
    for (Value shape : op.getShapes()) {
      if (auto extentTensorTy =
              llvm::dyn_cast<RankedTensorType>(shape.getType())) {
        // Cannot infer resulting shape rank if any operand is dynamically
        // ranked.
        if (extentTensorTy.isDynamicDim(0))
          return failure();
        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
      }
    }

    auto newOp = rewriter.create<BroadcastOp>(
        op.getLoc(), getExtentTensorType(getContext(), maxRank),
        op.getShapes());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
} // namespace

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<BroadcastConcretizeResultTypePattern,
               BroadcastFoldConstantOperandsPattern,
               BroadcastForwardSingleOperandPattern,
               CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
               RemoveDuplicateOperandsPattern<BroadcastOp>,
               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
}

//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//

OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
  if (!adaptor.getLhs() || !adaptor.getRhs())
    return nullptr;
  auto lhsShape = llvm::to_vector<6>(
      llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<int64_t>());
  auto rhsShape = llvm::to_vector<6>(
      llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<int64_t>());
  SmallVector<int64_t, 6> resultShape;
  resultShape.append(lhsShape.begin(), lhsShape.end());
  resultShape.append(rhsShape.begin(), rhsShape.end());
  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}

//===----------------------------------------------------------------------===//
// ConstShapeOp
//===----------------------------------------------------------------------===//

void ConstShapeOp::print(OpAsmPrinter &p) {
  p << " ";
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"});
  p << "[";
  interleaveComma(getShape().getValues<int64_t>(), p);
  p << "] : ";
  p.printType(getType());
}

ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = llvm::dyn_cast<ArrayAttr>(extentsRaw);
  if (!extentsArray)
    return failure();
  SmallVector<int64_t, 6> ints;
  for (Attribute extent : extentsArray) {
    IntegerAttr attr = llvm::dyn_cast<IntegerAttr>(extent);
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}

OpFoldResult ConstShapeOp::fold(FoldAdaptor) { return getShapeAttr(); }

void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  patterns.add<TensorCastConstShape>(context);
}

LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ConstShapeOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  Builder b(context);
  const Properties prop = adaptor.getProperties();
  inferredReturnTypes.assign({RankedTensorType::get(
      {static_cast<int64_t>(prop.shape.size())}, b.getIndexType())});
  return success();
}

bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
                                                        TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;

  Type lhs = l.front();
  Type rhs = r.front();

  if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
    // Shape type is compatible with all other valid return types.
    return true;
  return lhs == rhs;
}

//===----------------------------------------------------------------------===//
// CstrBroadcastableOp
//===----------------------------------------------------------------------===//

void CstrBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Canonicalization patterns have overlap with the considerations during
  // folding in case additional shape information is inferred at some point
  // that does not result in folding.
  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
               CstrBroadcastableEqOps,
               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
               RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
}

// Return true if there is at most one attribute not representing a scalar
// broadcast.
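// For example (illustrative): attributes for shapes {[], [], [2, 3]} satisfy
// this, while attributes for {[2], [2, 3]} do not.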
static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
  bool nonScalarSeen = false;
  for (Attribute a : attributes) {
    if (!a || llvm::cast<DenseIntElementsAttr>(a).getNumElements() != 0) {
      if (nonScalarSeen)
        return false;
      nonScalarSeen = true;
    }
  }
  return true;
}

OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
  // No broadcasting is needed if all operands but one are scalar.
  if (hasAtMostSingleNonScalar(adaptor.getShapes()))
    return BoolAttr::get(getContext(), true);

  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (const auto &operand : adaptor.getShapes()) {
          if (!operand)
            return false;
          extents.push_back(llvm::to_vector<6>(
              llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>()));
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Lastly, see if folding can be completed based on what constraints are
  // known on the input shapes.
  if ([&] {
        SmallVector<SmallVector<int64_t, 6>, 6> extents;
        for (auto shapeValue : getShapes()) {
          extents.emplace_back();
          if (failed(getShapeVec(shapeValue, extents.back())))
            return false;
        }
        return OpTrait::util::staticallyKnownBroadcastable(extents);
      }())
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not replace it with a constant witness.
  return nullptr;
}

LogicalResult CstrBroadcastableOp::verify() {
  // Ensure that CstrBroadcastableOp contains at least two operands
  if (getNumOperands() < 2)
    return emitOpError("required at least 2 input shapes");
  return success();
}

//===----------------------------------------------------------------------===//
// CstrEqOp
//===----------------------------------------------------------------------===//

void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness
  patterns.add<CstrEqEqOps>(context);
}

OpFoldResult CstrEqOp::fold(FoldAdaptor adaptor) {
  if (llvm::all_of(adaptor.getShapes(), [&](Attribute a) {
        return a && a == adaptor.getShapes().front();
      }))
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not try to replace it with a constant witness. Similarly,
  // we cannot if there are any non-const inputs.
  return nullptr;
}

//===----------------------------------------------------------------------===//
// ConstSizeOp
//===----------------------------------------------------------------------===//

void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}

OpFoldResult ConstSizeOp::fold(FoldAdaptor) { return getValueAttr(); }

void ConstSizeOp::getAsmResultNames(
    llvm::function_ref<void(Value, StringRef)> setNameFn) {
  SmallString<4> buffer;
  llvm::raw_svector_ostream os(buffer);
  os << "c" << getValue();
  setNameFn(getResult(), os.str());
}

//===----------------------------------------------------------------------===//
// ConstWitnessOp
//===----------------------------------------------------------------------===//

OpFoldResult ConstWitnessOp::fold(FoldAdaptor) { return getPassingAttr(); }

//===----------------------------------------------------------------------===//
// CstrRequireOp
//===----------------------------------------------------------------------===//

OpFoldResult CstrRequireOp::fold(FoldAdaptor adaptor) {
  return adaptor.getPred();
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

std::optional<int64_t> DimOp::getConstantIndex() {
  if (auto constSizeOp = getIndex().getDefiningOp<ConstSizeOp>())
    return constSizeOp.getValue().getLimitedValue();
  if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
    return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
  return std::nullopt;
}

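// Folds `shape.dim` when both the operand's ranked type and the constant
// index are known. For example (illustrative IR), `shape.dim` of a
// tensor<4x?xf32> value at index 0 folds to 4, while index 1 does not fold.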
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  Type valType = getValue().getType();
  auto valShapedType = llvm::dyn_cast<ShapedType>(valType);
  if (!valShapedType || !valShapedType.hasRank())
    return nullptr;
  std::optional<int64_t> index = getConstantIndex();
  if (!index.has_value())
    return nullptr;
  if (index.value() < 0 || index.value() >= valShapedType.getRank())
    return nullptr;
  auto extent = valShapedType.getDimSize(*index);
  if (ShapedType::isDynamic(extent))
    return nullptr;
  return IntegerAttr::get(IndexType::get(getContext()), extent);
}

LogicalResult mlir::shape::DimOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    DimOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.assign({adaptor.getIndex().getType()});
  return success();
}

bool mlir::shape::DimOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

//===----------------------------------------------------------------------===//
// DivOp
//===----------------------------------------------------------------------===//

OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
  auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
  if (!lhs)
    return nullptr;
  auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
  if (!rhs || rhs.getValue().isZero())
    return nullptr;

  // Division in APInt does not follow floor(lhs, rhs) when the result is
  // negative. Rather, APInt rounds toward zero.
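  // For example (illustrative): -7 floordiv 2 should be -4, but
  // APInt::sdivrem yields quotient -3 and remainder -1, so the quotient is
  // decremented below.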
  APInt quotient, remainder;
  APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
  if (quotient.isNegative() && !remainder.isZero()) {
    quotient -= 1;
  }

  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, quotient);
}

LogicalResult mlir::shape::DivOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    DivOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
      llvm::isa<SizeType>(adaptor.getRhs().getType()))
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// ShapeEqOp
//===----------------------------------------------------------------------===//

OpFoldResult ShapeEqOp::fold(FoldAdaptor adaptor) {
  bool allSame = true;
  if (!adaptor.getShapes().empty() && !adaptor.getShapes().front())
    return {};
  for (Attribute operand : adaptor.getShapes().drop_front()) {
    if (!operand)
      return {};
    allSame = allSame && operand == adaptor.getShapes().front();
  }
  return BoolAttr::get(getContext(), allSame);
}

//===----------------------------------------------------------------------===//
// IndexToSizeOp
//===----------------------------------------------------------------------===//

OpFoldResult IndexToSizeOp::fold(FoldAdaptor adaptor) {
  // Constant values of both types, `shape.size` and `index`, are represented
  // as `IntegerAttr`s which makes constant folding simple.
  if (Attribute arg = adaptor.getArg())
    return arg;
  return {};
}

void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<SizeToIndexToSizeCanonicalization>(context);
}

//===----------------------------------------------------------------------===//
// FromExtentsOp
//===----------------------------------------------------------------------===//

OpFoldResult FromExtentsOp::fold(FoldAdaptor adaptor) {
  if (llvm::any_of(adaptor.getExtents(), [](Attribute a) { return !a; }))
    return nullptr;
  SmallVector<int64_t, 6> extents;
  for (auto attr : adaptor.getExtents())
    extents.push_back(llvm::cast<IntegerAttr>(attr).getInt());
  Builder builder(getContext());
  return builder.getIndexTensorAttr(extents);
}

//===----------------------------------------------------------------------===//
// FunctionLibraryOp
//===----------------------------------------------------------------------===//

void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
                              StringRef name) {
  result.attributes.push_back(builder.getNamedAttr(
      ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}

FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
  auto attr = llvm::dyn_cast_or_null<FlatSymbolRefAttr>(
      getMapping().get(op->getName().getIdentifier()));
  if (!attr)
    return nullptr;
  return lookupSymbol<FuncOp>(attr);
}

ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  // Parse the op name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  auto *bodyRegion = result.addRegion();
  if (parser.parseRegion(*bodyRegion))
    return failure();

  if (parser.parseKeyword("mapping"))
    return failure();

  DictionaryAttr mappingAttr;
  if (parser.parseAttribute(mappingAttr,
                            parser.getBuilder().getType<NoneType>(), "mapping",
                            result.attributes))
    return failure();
  return success();
}

void FunctionLibraryOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());
  p.printOptionalAttrDictWithKeyword(
      (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"});
  p << ' ';
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  p << " mapping ";
  p.printAttributeWithoutType(getMappingAttr());
}

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//

FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder builder(location->getContext());
  OperationState state(location, getOperationName());
  FuncOp::build(builder, state, name, type, attrs);
  return cast<FuncOp>(Operation::create(state));
}
FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      Operation::dialect_attr_range attrs) {
  SmallVector<NamedAttribute, 8> attrRef(attrs);
  return create(location, name, type, llvm::ArrayRef(attrRef));
}
FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      ArrayRef<NamedAttribute> attrs,
                      ArrayRef<DictionaryAttr> argAttrs) {
  FuncOp func = create(location, name, type, attrs);
  func.setAllArgAttrs(argAttrs);
  return func;
}

void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
                   FunctionType type, ArrayRef<NamedAttribute> attrs,
                   ArrayRef<DictionaryAttr> argAttrs) {
  state.addAttribute(FuncOp::getSymNameAttrName(state.name),
                     builder.getStringAttr(name));
  state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
                     TypeAttr::get(type));
  state.attributes.append(attrs.begin(), attrs.end());
  state.addRegion();

  if (argAttrs.empty())
    return;
  assert(type.getNumInputs() == argAttrs.size());
  call_interface_impl::addArgAndResultAttrs(
      builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
      getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
}

ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  auto buildFuncType =
      [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
         function_interface_impl::VariadicFlag,
         std::string &) { return builder.getFunctionType(argTypes, results); };

  return function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false,
      getFunctionTypeAttrName(result.name), buildFuncType,
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}

void FuncOp::print(OpAsmPrinter &p) {
  function_interface_impl::printFunctionOp(
      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
      getArgAttrsAttrName(), getResAttrsAttrName());
}

//===----------------------------------------------------------------------===//
// GetExtentOp
//===----------------------------------------------------------------------===//

std::optional<int64_t> GetExtentOp::getConstantDim() {
  if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
    return constSizeOp.getValue().getLimitedValue();
  if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
    return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
  return std::nullopt;
}

OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {
  auto elements =
      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
  if (!elements)
    return nullptr;
  std::optional<int64_t> dim = getConstantDim();
  if (!dim.has_value())
    return nullptr;
  if (dim.value() >= elements.getNumElements())
    return nullptr;
  return elements.getValues<Attribute>()[(uint64_t)dim.value()];
}
1351 | |
void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
                        int64_t dim) {
  auto loc = result.location;
  auto dimAttr = builder.getIndexAttr(dim);
  if (llvm::isa<ShapeType>(shape.getType())) {
    Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
    build(builder, result, builder.getType<SizeType>(), shape, dim);
  } else {
    Value dim =
        builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
    build(builder, result, builder.getIndexType(), shape, dim);
  }
}

LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    GetExtentOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
                                                       TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// IsBroadcastableOp
//===----------------------------------------------------------------------===//

void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
}

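// Folds only the trivial case: with fewer than two shapes, broadcasting is
// always possible, so e.g. a single-operand `shape.is_broadcastable` folds
// to a constant `true`. Conflicts between two or more constant shapes are
// not detected here.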
OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
  // Can always broadcast fewer than two shapes.
  if (adaptor.getShapes().size() < 2) {
    return BoolAttr::get(getContext(), true);
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// MeetOp
//===----------------------------------------------------------------------===//

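// The type inference below folds the operand types pairwise, keeping the
// type that can represent the met value. Illustrative pairs:
//
//   (!shape.size, index)               -> !shape.size
//   (!shape.shape, tensor<?xindex>)    -> !shape.shape
//   (tensor<3xindex>, tensor<3xindex>) -> tensor<3xindex>
//   (tensor<2xindex>, tensor<3xindex>) -> error: unequal shape cardinality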
LogicalResult mlir::shape::MeetOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    MeetOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (adaptor.getOperands().empty())
    return failure();

  auto isShapeType = [](Type arg) {
    if (llvm::isa<ShapeType>(arg))
      return true;
    return isExtentTensorType(arg);
  };

  ValueRange::type_range types = adaptor.getOperands().getTypes();
  Type acc = types.front();
  for (auto t : drop_begin(types)) {
    Type l = acc, r = t;
    if (!llvm::isa<ShapeType, SizeType>(l))
      std::swap(l, r);

    // Handle sizes, propagate error type if present.
    if (llvm::isa<SizeType>(l)) {
      if (llvm::isa<SizeType, IndexType>(r))
        acc = l;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (llvm::isa<IndexType>(l)) {
      if (llvm::isa<IndexType>(r))
        acc = r;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (llvm::isa<ShapeType>(l)) {
      // Handle shapes, propagate error type if present.
      if (isShapeType(r))
        acc = l;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (isExtentTensorType(l)) {
      auto rank1 = llvm::cast<RankedTensorType>(l).getShape()[0];
      auto rank2 = llvm::cast<RankedTensorType>(r).getShape()[0];
      if (ShapedType::isDynamic(rank1))
        acc = l;
      else if (ShapedType::isDynamic(rank2))
        acc = r;
      else if (rank1 != rank2)
        return emitOptionalError(location, "unequal shape cardinality");
      else
        acc = l;
    }
  }
  inferredReturnTypes.assign({acc});
  return success();
}

bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;

  Type lhs = l.front();
  Type rhs = r.front();

  if (!llvm::isa<ShapeType, SizeType>(lhs))
    std::swap(lhs, rhs);

  if (llvm::isa<SizeType>(lhs))
    return llvm::isa<SizeType, IndexType>(rhs);
  if (llvm::isa<ShapeType>(lhs))
    return llvm::isa<ShapeType, TensorType>(rhs);

  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
    return true;
  return false;
}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

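// Folds the rank of a constant shape to a constant, e.g. (in the loose IR
// style of the comments in this file):
//
//   %shape = shape.const_shape [4, 5, 6] : tensor<3xindex>
//   %rank = shape.rank %shape
//
// folds to the equivalent of `arith.constant 3 : index`.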
OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
  auto shape =
      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
  if (!shape)
    return {};
  int64_t rank = shape.getNumElements();
  Builder builder(getContext());
  return builder.getIndexAttr(rank);
}

/// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
/// Constant folding fails in cases where only the rank is constant, not the
/// shape itself.
/// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
///
/// Example:
///
/// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
/// %rank = shape.rank %shape
///
/// becomes
///
/// %rank = shape.const_size 3

namespace {
struct RankShapeOfCanonicalizationPattern
    : public OpRewritePattern<shape::RankOp> {
  using OpRewritePattern<shape::RankOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::RankOp op,
                                PatternRewriter &rewriter) const override {
    auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    auto rankedTensorType =
        llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
    if (!rankedTensorType)
      return failure();
    int64_t rank = rankedTensorType.getRank();
    if (llvm::isa<IndexType>(op.getType())) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
                                                          rank);
    } else if (llvm::isa<shape::SizeType>(op.getType())) {
      rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
    } else {
      return failure();
    }
    return success();
  }
};
} // namespace

void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<RankShapeOfCanonicalizationPattern>(context);
}

LogicalResult mlir::shape::RankOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    RankOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// NumElementsOp
//===----------------------------------------------------------------------===//

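// Folds the number of elements of a constant shape to the product of its
// extents, e.g. a constant [2, 3, 4] shape folds to the equivalent of
// `arith.constant 24 : index`; a rank-0 shape folds to 1, the empty product.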
OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
  // Fold only when the argument is constant.
  Attribute shape = adaptor.getShape();
  if (!shape)
    return {};

  APInt product(64, 1);
  for (auto value : llvm::cast<DenseIntElementsAttr>(shape))
    product *= value;
  Builder builder(getContext());
  return builder.getIndexAttr(product.getLimitedValue());
}

LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    NumElementsOp::Adaptor adaptor,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
                                                         TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult shape::NumElementsOp::verify() {
  return verifySizeOrIndexOp(*this);
}

//===----------------------------------------------------------------------===//
// MaxOp
//===----------------------------------------------------------------------===//

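// Folds only when both operands are the same SSA value, e.g.
// `shape.max %a, %a` propagates `%a` unchanged; constant operands are not
// folded here.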
OpFoldResult MaxOp::fold(FoldAdaptor adaptor) {
  // If operands are equal, just propagate one.
  if (getLhs() == getRhs())
    return getLhs();
  return nullptr;
}

LogicalResult mlir::shape::MaxOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    MaxOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
    inferredReturnTypes.assign({adaptor.getLhs().getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});
  return success();
}

bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
    return true;
  if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
    return true;
  return false;
}

//===----------------------------------------------------------------------===//
// MinOp
//===----------------------------------------------------------------------===//

OpFoldResult MinOp::fold(FoldAdaptor adaptor) {
  // If operands are equal, just propagate one.
  if (getLhs() == getRhs())
    return getLhs();
  return nullptr;
}

LogicalResult mlir::shape::MinOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    MinOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
    inferredReturnTypes.assign({adaptor.getLhs().getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});
  return success();
}

bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
    return true;
  if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
    return true;
  return false;
}

//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

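// Folds the product of two constant operands. Illustrative IR (loosely
// following the examples elsewhere in this file):
//
//   %two = shape.const_size 2
//   %three = shape.const_size 3
//   %product = shape.mul %two, %three
//
// folds to the equivalent of `shape.const_size 6`. Note that the folded
// attribute is always built over the `index` type, for both the size and
// the index variants of the op.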
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
  if (!lhs)
    return nullptr;
  auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
  if (!rhs)
    return nullptr;
  APInt folded = lhs.getValue() * rhs.getValue();
  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, folded);
}

LogicalResult mlir::shape::MulOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    MulOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
      llvm::isa<SizeType>(adaptor.getRhs().getType()))
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }

//===----------------------------------------------------------------------===//
// ShapeOfOp
//===----------------------------------------------------------------------===//

namespace {
/// Replace shape_of(x) where x has a constant shape with a const_shape op.
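/// Illustrative example:
///
///   %0 = shape.shape_of %arg : tensor<2x3xf32> -> tensor<2xindex>
///
/// becomes
///
///   %0 = shape.const_shape [2, 3] : tensor<2xindex>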
struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
    if (!type || !type.hasStaticShape())
      return failure();
    Location loc = op.getLoc();
    Value constShape =
        rewriter
            .create<ConstShapeOp>(loc,
                                  rewriter.getIndexTensorAttr(type.getShape()))
            .getResult();
    if (constShape.getType() != op.getResult().getType())
      constShape = rewriter.create<tensor::CastOp>(
          loc, op.getResult().getType(), constShape);
    rewriter.replaceOp(op, constShape);
    return success();
  }
};

// Canonicalize
//
// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
//
// to
//
// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// %1 = %shape
//
struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
    if (!tensorReshapeOp)
      return rewriter.notifyMatchFailure(op, "producer is not tensor.reshape");
    if (!isa<TensorType>(op.getType()))
      return rewriter.notifyMatchFailure(op, "result is not a tensor");

    // Operand 'shape' of 'tensor.reshape' may now be used as the result of
    // 'shape.shape_of'. While its type is guaranteed to be compatible in
    // well-formed IR, it may not be identical: if it differs only in its
    // shape (dynamically vs statically shaped), it is adapted with a
    // 'tensor.cast'; if it differs in its element type (i32 vs index), it is
    // adapted with an 'arith.index_cast'. Note: the result of
    // 'shape.shape_of' is always a shape or an extent tensor.
    Value shape = tensorReshapeOp.getShape();

    auto opTensorTy = cast<RankedTensorType>(op.getType());
    auto shapeTensorTy = cast<RankedTensorType>(shape.getType());

    if (opTensorTy != shapeTensorTy) {
      if (opTensorTy.getElementType() == shapeTensorTy.getElementType())
        shape = rewriter.create<tensor::CastOp>(op.getLoc(), opTensorTy, shape);
      else if (!isExtentTensorType(shapeTensorTy))
        shape =
            rewriter.create<arith::IndexCastOp>(op.getLoc(), opTensorTy, shape);
    }

    rewriter.replaceOp(op, shape);
    return success();
  }
};

// Canonicalize
// ```
// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
// ```
// to
// ```
// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
// ```
struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    auto ty = llvm::dyn_cast<RankedTensorType>(op.getType());
    if (!ty || ty.getRank() != 1)
      return failure();

    auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();

    // Argument type must be ranked and must not conflict.
    auto argTy = llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
    if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
      return failure();

    rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
    return success();
  }
};
} // namespace

void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
               ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
      context);
}

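// Illustrative mapping of the type inference below:
//
//   !shape.value_shape -> !shape.shape
//   tensor<2x?xf32>    -> tensor<2xindex>
//   tensor<*xf32>      -> tensor<?xindex>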
LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ShapeOfOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (llvm::isa<ValueShapeType>(adaptor.getArg().getType()))
    inferredReturnTypes.assign({ShapeType::get(context)});
  else {
    auto shapedTy = llvm::cast<ShapedType>(adaptor.getArg().getType());
    int64_t rank =
        shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
    Type indexTy = IndexType::get(context);
    Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
    inferredReturnTypes.assign({extentTensorTy});
  }
  return success();
}

bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;

  Type lhs = l.front();
  Type rhs = r.front();

  if (!llvm::isa<ShapeType, ShapedType>(lhs) ||
      !llvm::isa<ShapeType, ShapedType>(rhs))
    return false;

  if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
    // Shape type is compatible with all other valid return types.
    return true;

  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
    return true;
  return false;
}

LogicalResult shape::ShapeOfOp::verify() {
  return verifyShapeOrExtentTensorOp(*this);
}

//===----------------------------------------------------------------------===//
// SizeToIndexOp
//===----------------------------------------------------------------------===//

OpFoldResult SizeToIndexOp::fold(FoldAdaptor adaptor) {
  // Constant values of both types, `shape.size` and `index`, are represented
  // as `IntegerAttr`s which makes constant folding simple.
  if (Attribute arg = adaptor.getArg())
    return arg;
  return OpFoldResult();
}

void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<IndexToSizeToIndexCanonicalization>(context);
}

bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  return llvm::isa<IndexType, SizeType>(inputs[0]) &&
         llvm::isa<IndexType>(outputs[0]);
}

//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//

LogicalResult shape::YieldOp::verify() {
  auto *parentOp = (*this)->getParentOp();
  auto results = parentOp->getResults();
  auto operands = getOperands();

  if (parentOp->getNumResults() != getNumOperands())
    return emitOpError() << "number of operands does not match number of "
                            "results of its parent";
  for (auto e : llvm::zip(results, operands))
    if (std::get<0>(e).getType() != std::get<1>(e).getType())
      return emitOpError() << "types mismatch between yield op and its parent";

  return success();
}

//===----------------------------------------------------------------------===//
// SplitAtOp
//===----------------------------------------------------------------------===//

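// Folds a split of a constant shape at a constant index. A negative split
// point counts from the back, so splitting [2, 3, 4, 5] at index -1 yields
// the head [2, 3, 4] and the tail [5]. Split points outside the range
// [-rank, rank] do not fold.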
LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
  if (!adaptor.getOperand() || !adaptor.getIndex())
    return failure();
  auto shapeVec = llvm::to_vector<6>(
      llvm::cast<DenseIntElementsAttr>(adaptor.getOperand())
          .getValues<int64_t>());
  auto shape = llvm::ArrayRef(shapeVec);
  auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt();
  // Verify that the split point is in the correct range.
  // TODO: Constant fold to an "error".
  int64_t rank = shape.size();
  if (-rank > splitPoint || splitPoint > rank)
    return failure();
  if (splitPoint < 0)
    splitPoint += shape.size();
  Builder builder(adaptor.getOperand().getContext());
  results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
  results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
  return success();
}

//===----------------------------------------------------------------------===//
// ToExtentTensorOp
//===----------------------------------------------------------------------===//

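// Folds a constant shape to the corresponding dense extent tensor, e.g. a
// constant [2, 3] shape becomes `dense<[2, 3]> : tensor<2xindex>`.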
OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
  if (!adaptor.getInput())
    return OpFoldResult();
  Builder builder(getContext());
  auto shape = llvm::to_vector<6>(
      llvm::cast<DenseIntElementsAttr>(adaptor.getInput())
          .getValues<int64_t>());
  auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
                                    builder.getIndexType());
  return DenseIntElementsAttr::get(type, shape);
}

bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  if (auto inputTensor = llvm::dyn_cast<RankedTensorType>(inputs[0])) {
    if (!llvm::isa<IndexType>(inputTensor.getElementType()) ||
        inputTensor.getRank() != 1)
      return false;
  } else if (!llvm::isa<ShapeType>(inputs[0])) {
    return false;
  }

  TensorType outputTensor = llvm::dyn_cast<TensorType>(outputs[0]);
  return outputTensor && llvm::isa<IndexType>(outputTensor.getElementType());
}

//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//

void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
                     ValueRange initVals) {
  OpBuilder::InsertionGuard g(builder);
  result.addOperands(shape);
  result.addOperands(initVals);

  Region *bodyRegion = result.addRegion();
  Block *bodyBlock = builder.createBlock(
      bodyRegion, /*insertPt=*/{}, builder.getIndexType(), result.location);

  Type elementType;
  if (auto tensorType = llvm::dyn_cast<TensorType>(shape.getType()))
    elementType = tensorType.getElementType();
  else
    elementType = SizeType::get(builder.getContext());
  bodyBlock->addArgument(elementType, shape.getLoc());

  for (Value initVal : initVals) {
    bodyBlock->addArgument(initVal.getType(), initVal.getLoc());
    result.addTypes(initVal.getType());
  }
}

LogicalResult ReduceOp::verify() {
  // Verify block arg types.
  Block &block = getRegion().front();

  // The block takes index, extent, and aggregated values as arguments.
  auto blockArgsCount = getInitVals().size() + 2;
  if (block.getNumArguments() != blockArgsCount)
    return emitOpError() << "ReduceOp body is expected to have "
                         << blockArgsCount << " arguments";

  // The first block argument is the index and must always be of type `index`.
  if (!llvm::isa<IndexType>(block.getArgument(0).getType()))
    return emitOpError(
        "argument 0 of ReduceOp body is expected to be of IndexType");

  // The second block argument is the extent and must be of type `size` or
  // `index`, depending on whether the reduce operation is applied to a shape
  // or to an extent tensor.
  Type extentTy = block.getArgument(1).getType();
  if (llvm::isa<ShapeType>(getShape().getType())) {
    if (!llvm::isa<SizeType>(extentTy))
      return emitOpError("argument 1 of ReduceOp body is expected to be of "
                         "SizeType if the ReduceOp operates on a ShapeType");
  } else {
    if (!llvm::isa<IndexType>(extentTy))
      return emitOpError(
          "argument 1 of ReduceOp body is expected to be of IndexType if the "
          "ReduceOp operates on an extent tensor");
  }

  for (const auto &type : llvm::enumerate(getInitVals()))
    if (block.getArgument(type.index() + 2).getType() != type.value().getType())
      return emitOpError() << "type mismatch between argument "
                           << type.index() + 2
                           << " of ReduceOp body and initial value "
                           << type.index();
  return success();
}

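// The custom assembly parsed and printed below has the general form
//
//   shape.reduce(%shape, %init-values) : <shape-or-extent-tensor-type>
//       -> <result-types> { <body-region> } <attr-dict>
//
// For example (an illustrative sketch reconstructed from the parser, not
// copied from a test; the body computes the number of elements):
//
//   %num_elements = shape.reduce(%shape, %one) : tensor<3xindex> -> index {
//     ^bb0(%i: index, %extent: index, %acc: index):
//       %product = arith.muli %acc, %extent : index
//       shape.yield %product : index
//   }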
ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse operands.
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
  Type shapeOrExtentTensorType;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
                              OpAsmParser::Delimiter::Paren) ||
      parser.parseColonType(shapeOrExtentTensorType) ||
      parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Resolve operands.
  auto initVals = llvm::ArrayRef(operands).drop_front();
  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
                            result.operands) ||
      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
                             result.operands))
    return failure();

  // Parse the body.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
    return failure();

  // Parse attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

void ReduceOp::print(OpAsmPrinter &p) {
  p << '(' << getShape() << ", " << getInitVals()
    << ") : " << getShape().getType();
  p.printOptionalArrowTypeList(getResultTypes());
  p << ' ';
  p.printRegion(getRegion());
  p.printOptionalAttrDict((*this)->getAttrs());
}

#define GET_OP_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
