1//===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Dialect/Shape/IR/Shape.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"

#include <algorithm>
#include <utility>
30
31using namespace mlir;
32using namespace mlir::shape;
33
34#include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
35
36namespace {
37#include "ShapeCanonicalization.inc"
38} // namespace
39
40RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
41 return RankedTensorType::get(shape: {rank}, elementType: IndexType::get(context: ctx));
42}
43
44bool shape::isExtentTensorType(Type type) {
45 auto ranked = llvm::dyn_cast<RankedTensorType>(Val&: type);
46 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
47}
48
49LogicalResult shape::getShapeVec(Value input,
50 SmallVectorImpl<int64_t> &shapeValues) {
51 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
52 auto type = llvm::cast<ShapedType>(Val: inputOp.getArg().getType());
53 if (!type.hasRank())
54 return failure();
55 llvm::append_range(C&: shapeValues, R: type.getShape());
56 return success();
57 }
58 DenseIntElementsAttr attr;
59 if (matchPattern(value: input, pattern: m_Constant(bind_value: &attr))) {
60 llvm::append_range(C&: shapeValues, R: attr.getValues<int64_t>());
61 return success();
62 }
63 return failure();
64}
65
66static bool isErrorPropagationPossible(TypeRange operandTypes) {
67 return llvm::any_of(Range&: operandTypes,
68 P: llvm::IsaPred<SizeType, ShapeType, ValueShapeType>);
69}
70
71static LogicalResult verifySizeOrIndexOp(Operation *op) {
72 assert(op != nullptr && op->getNumResults() == 1);
73 Type resultTy = op->getResultTypes().front();
74 if (isErrorPropagationPossible(operandTypes: op->getOperandTypes())) {
75 if (!llvm::isa<SizeType>(Val: resultTy))
76 return op->emitOpError()
77 << "if at least one of the operands can hold error values then "
78 "the result must be of type `size` to propagate them";
79 }
80 return success();
81}
82
83static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
84 assert(op != nullptr && op->getNumResults() == 1);
85 Type resultTy = op->getResultTypes().front();
86 if (isErrorPropagationPossible(operandTypes: op->getOperandTypes())) {
87 if (!llvm::isa<ShapeType>(Val: resultTy))
88 return op->emitOpError()
89 << "if at least one of the operands can hold error values then "
90 "the result must be of type `shape` to propagate them";
91 }
92 return success();
93}
94
95template <typename... Ty>
96static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
97 return typeRange.size() == 1 && llvm::isa<Ty...>(typeRange.front());
98}
99
100template <typename... Ty, typename... ranges>
101static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
102 return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
103}
104
105//===----------------------------------------------------------------------===//
106// InlinerInterface
107//===----------------------------------------------------------------------===//
108
109namespace {
110/// This class defines the interface for inlining shape dialect ops.
111struct ShapeInlinerInterface : public DialectInlinerInterface {
112 using DialectInlinerInterface::DialectInlinerInterface;
113
114 // Returns true if the given region 'src' can be inlined into the region
115 // 'dest' that is attached to an operation registered to the current dialect.
116 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
117 IRMapping &) const final {
118 return true;
119 }
120
121 // Returns true if the given operation 'op', that is registered to this
122 // dialect, can be inlined into the region 'dest' that is attached to an
123 // operation registered to the current dialect.
124 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
125 IRMapping &) const final {
126 return true;
127 }
128};
129} // namespace
130
131void ShapeDialect::initialize() {
132 addOperations<
133#define GET_OP_LIST
134#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
135 >();
136 addTypes<
137#define GET_TYPEDEF_LIST
138#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
139 >();
140 addInterfaces<ShapeInlinerInterface>();
141 // Allow unknown operations during prototyping and testing. As the dialect is
142 // still evolving it makes it simple to start with an unregistered ops and
143 // try different variants before actually defining the op.
144 allowUnknownOperations();
145 declarePromisedInterfaces<bufferization::BufferizableOpInterface, AssumingOp,
146 AssumingYieldOp>();
147}
148
149Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
150 Attribute value, Type type,
151 Location loc) {
152 if (auto poison = dyn_cast<ub::PoisonAttr>(Val&: value))
153 return builder.create<ub::PoisonOp>(location: loc, args&: type, args&: poison);
154
155 if (llvm::isa<ShapeType>(Val: type) || isExtentTensorType(type))
156 return builder.create<ConstShapeOp>(
157 location: loc, args&: type, args: llvm::cast<DenseIntElementsAttr>(Val&: value));
158 if (llvm::isa<SizeType>(Val: type))
159 return builder.create<ConstSizeOp>(location: loc, args&: type,
160 args: llvm::cast<IntegerAttr>(Val&: value));
161 if (llvm::isa<WitnessType>(Val: type))
162 return builder.create<ConstWitnessOp>(location: loc, args&: type,
163 args: llvm::cast<BoolAttr>(Val&: value));
164
165 return arith::ConstantOp::materialize(builder, value, type, loc);
166}
167
168LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
169 NamedAttribute attribute) {
170 // Verify shape.lib attribute.
171 if (attribute.getName() == "shape.lib") {
172 if (!op->hasTrait<OpTrait::SymbolTable>())
173 return op->emitError(
174 message: "shape.lib attribute may only be on op implementing SymbolTable");
175
176 if (auto symbolRef = llvm::dyn_cast<SymbolRefAttr>(Val: attribute.getValue())) {
177 auto *symbol = SymbolTable::lookupSymbolIn(op, symbol: symbolRef);
178 if (!symbol)
179 return op->emitError(message: "shape function library ")
180 << symbolRef << " not found";
181 return isa<shape::FunctionLibraryOp>(Val: symbol)
182 ? success()
183 : op->emitError()
184 << symbolRef << " required to be shape function library";
185 }
186
187 if (auto arr = llvm::dyn_cast<ArrayAttr>(Val: attribute.getValue())) {
188 // Verify all entries are function libraries and mappings in libraries
189 // refer to unique ops.
190 DenseSet<StringAttr> key;
191 for (auto it : arr) {
192 if (!llvm::isa<SymbolRefAttr>(Val: it))
193 return op->emitError(
194 message: "only SymbolRefAttr allowed in shape.lib attribute array");
195
196 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
197 Val: SymbolTable::lookupSymbolIn(op, symbol: llvm::cast<SymbolRefAttr>(Val&: it)));
198 if (!shapeFnLib)
199 return op->emitError()
200 << it << " does not refer to FunctionLibraryOp";
201 for (auto mapping : shapeFnLib.getMapping()) {
202 if (!key.insert(V: mapping.getName()).second) {
203 return op->emitError(message: "only one op to shape mapping allowed, found "
204 "multiple for `")
205 << mapping.getName() << "`";
206 }
207 }
208 }
209 return success();
210 }
211
212 return op->emitError(message: "only SymbolRefAttr or array of SymbolRefAttrs "
213 "allowed as shape.lib attribute");
214 }
215 return success();
216}
217
218//===----------------------------------------------------------------------===//
219// AnyOp
220//===----------------------------------------------------------------------===//
221
222// TODO: Canonicalization should be implemented for shapes that can be
223// determined through mixtures of the known dimensions of the inputs.
224OpFoldResult AnyOp::fold(FoldAdaptor adaptor) {
225 // Only the last operand is checked because AnyOp is commutative.
226 if (adaptor.getInputs().back())
227 return adaptor.getInputs().back();
228
229 return nullptr;
230}
231
232//===----------------------------------------------------------------------===//
233// AssumingOp
234//===----------------------------------------------------------------------===//
235
236ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) {
237 result.regions.reserve(N: 1);
238 Region *doRegion = result.addRegion();
239
240 auto &builder = parser.getBuilder();
241 OpAsmParser::UnresolvedOperand cond;
242 if (parser.parseOperand(result&: cond) ||
243 parser.resolveOperand(operand: cond, type: builder.getType<WitnessType>(),
244 result&: result.operands))
245 return failure();
246
247 // Parse optional results type list.
248 if (parser.parseOptionalArrowTypeList(result&: result.types))
249 return failure();
250
251 // Parse the region and add a terminator if elided.
252 if (parser.parseRegion(region&: *doRegion, /*arguments=*/{}, /*argTypes=*/enableNameShadowing: {}))
253 return failure();
254 AssumingOp::ensureTerminator(region&: *doRegion, builder&: parser.getBuilder(), loc: result.location);
255
256 // Parse the optional attribute list.
257 if (parser.parseOptionalAttrDict(result&: result.attributes))
258 return failure();
259 return success();
260}
261
262void AssumingOp::print(OpAsmPrinter &p) {
263 bool yieldsResults = !getResults().empty();
264
265 p << " " << getWitness();
266 if (yieldsResults)
267 p << " -> (" << getResultTypes() << ")";
268 p << ' ';
269 p.printRegion(blocks&: getDoRegion(),
270 /*printEntryBlockArgs=*/false,
271 /*printBlockTerminators=*/yieldsResults);
272 p.printOptionalAttrDict(attrs: (*this)->getAttrs());
273}
274
275namespace {
276// Removes AssumingOp with a passing witness and inlines the region.
277struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
278 using OpRewritePattern<AssumingOp>::OpRewritePattern;
279
280 LogicalResult matchAndRewrite(AssumingOp op,
281 PatternRewriter &rewriter) const override {
282 auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
283 if (!witness || !witness.getPassingAttr())
284 return failure();
285
286 AssumingOp::inlineRegionIntoParent(op, rewriter);
287 return success();
288 }
289};
290
291struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
292 using OpRewritePattern<AssumingOp>::OpRewritePattern;
293
294 LogicalResult matchAndRewrite(AssumingOp op,
295 PatternRewriter &rewriter) const override {
296 Block *body = op.getBody();
297 auto yieldOp = llvm::cast<AssumingYieldOp>(Val: body->getTerminator());
298
299 // Find used values.
300 SmallVector<Value, 4> newYieldOperands;
301 for (auto [opResult, yieldOperand] :
302 llvm::zip(t: op.getResults(), u: yieldOp.getOperands())) {
303 if (!opResult.getUses().empty()) {
304 newYieldOperands.push_back(Elt: yieldOperand);
305 }
306 }
307
308 // Rewrite only if redundant results exist.
309 if (newYieldOperands.size() == yieldOp->getNumOperands())
310 return failure();
311
312 // Replace yield op in the old assuming op's body and move the entire region
313 // to the new assuming op.
314 rewriter.setInsertionPointToEnd(body);
315 auto newYieldOp =
316 rewriter.replaceOpWithNewOp<AssumingYieldOp>(op: yieldOp, args&: newYieldOperands);
317 rewriter.setInsertionPoint(op);
318 auto newOp = rewriter.create<AssumingOp>(
319 location: op.getLoc(), args: newYieldOp->getOperandTypes(), args: op.getWitness());
320 newOp.getDoRegion().takeBody(other&: op.getDoRegion());
321
322 // Use the new results to replace the previously used ones.
323 SmallVector<Value, 4> replacementValues;
324 auto src = newOp.getResults().begin();
325 for (auto it : op.getResults()) {
326 if (it.getUses().empty())
327 replacementValues.push_back(Elt: nullptr);
328 else
329 replacementValues.push_back(Elt: *src++);
330 }
331 rewriter.replaceOp(op, newValues: replacementValues);
332 return success();
333 }
334};
335} // namespace
336
337void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
338 MLIRContext *context) {
339 patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(arg&: context);
340}
341
342// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
343void AssumingOp::getSuccessorRegions(
344 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
345 // AssumingOp has unconditional control flow into the region and back to the
346 // parent, so return the correct RegionSuccessor purely based on the index
347 // being None or 0.
348 if (!point.isParent()) {
349 regions.push_back(Elt: RegionSuccessor(getResults()));
350 return;
351 }
352
353 regions.push_back(Elt: RegionSuccessor(&getDoRegion()));
354}
355
356void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
357 PatternRewriter &rewriter) {
358 auto *blockBeforeAssuming = rewriter.getInsertionBlock();
359 auto *assumingBlock = op.getBody();
360 auto initPosition = rewriter.getInsertionPoint();
361 auto *blockAfterAssuming =
362 rewriter.splitBlock(block: blockBeforeAssuming, before: initPosition);
363
364 // Remove the AssumingOp and AssumingYieldOp.
365 auto &yieldOp = assumingBlock->back();
366 rewriter.inlineRegionBefore(region&: op.getDoRegion(), before: blockAfterAssuming);
367 rewriter.replaceOp(op, newValues: yieldOp.getOperands());
368 rewriter.eraseOp(op: &yieldOp);
369
370 // Merge blocks together as there was no branching behavior from the
371 // AssumingOp.
372 rewriter.mergeBlocks(source: assumingBlock, dest: blockBeforeAssuming);
373 rewriter.mergeBlocks(source: blockAfterAssuming, dest: blockBeforeAssuming);
374}
375
376void AssumingOp::build(
377 OpBuilder &builder, OperationState &result, Value witness,
378 function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {
379 OpBuilder::InsertionGuard g(builder);
380
381 result.addOperands(newOperands: witness);
382 Region *bodyRegion = result.addRegion();
383 builder.createBlock(parent: bodyRegion);
384
385 // Build body.
386 SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
387 builder.create<AssumingYieldOp>(location: result.location, args&: yieldValues);
388
389 SmallVector<Type, 2> assumingTypes;
390 for (Value v : yieldValues)
391 assumingTypes.push_back(Elt: v.getType());
392 result.addTypes(newTypes: assumingTypes);
393}
394
395//===----------------------------------------------------------------------===//
396// AddOp
397//===----------------------------------------------------------------------===//
398
399LogicalResult mlir::shape::AddOp::inferReturnTypes(
400 MLIRContext *context, std::optional<Location> location,
401 AddOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
402 if (llvm::isa<SizeType>(Val: adaptor.getLhs().getType()) ||
403 llvm::isa<SizeType>(Val: adaptor.getRhs().getType()))
404 inferredReturnTypes.assign(IL: {SizeType::get(ctx: context)});
405 else
406 inferredReturnTypes.assign(IL: {IndexType::get(context)});
407 return success();
408}
409
410bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
411 // SizeType is compatible with IndexType.
412 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, rs: r);
413}
414
415OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) {
416 // add(x, 0) -> x
417 if (matchPattern(value: getRhs(), pattern: m_Zero()))
418 return getLhs();
419
420 return constFoldBinaryOp<IntegerAttr>(
421 operands: adaptor.getOperands(),
422 calculate: [](APInt a, const APInt &b) { return std::move(a) + b; });
423}
424
425LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(op: *this); }
426
427//===----------------------------------------------------------------------===//
428// AssumingAllOp
429//===----------------------------------------------------------------------===//
430
431namespace {
432
433// Merge multiple `shape.assuming_all` operations together.
434//
435// %0 = shape.assuming_all %w0, %w1
436// %1 = shape.assuming_all %w2, %0
437//
438// to:
439//
440// %0 = shape.assuming_all %w0, %w2, %w2
441struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
442 using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
443
444 LogicalResult matchAndRewrite(AssumingAllOp op,
445 PatternRewriter &rewriter) const override {
446 SmallVector<Value> operands;
447
448 for (Value operand : op.getInputs()) {
449 if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
450 operands.append(in_start: assumeAll.operand_begin(), in_end: assumeAll->operand_end());
451 else
452 operands.push_back(Elt: operand);
453 }
454
455 // We didn't find any other `assuming_all` ops to merge with.
456 if (operands.size() == op.getNumOperands())
457 return failure();
458
459 // Replace with a new `assuming_all` operation with merged constraints.
460 rewriter.replaceOpWithNewOp<AssumingAllOp>(op, args&: operands);
461 return success();
462 }
463};
464
465// Eliminate `cstr_broadcastable` operands from `assuming_all` operation that
466// are subsumed by others.
467//
468// %0 = shape.cstr_broadcastable %shape0, %shape1
469// %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
470//
471// %2 = shape.cstr_broadcastable %shape3, %shape4
472// %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
473//
474// %4 = shape.assuming_all %0, %1, %2, %3
475//
476// to:
477//
478// %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
479// %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
480// %2 = shape.assuming_all %0, %1
481//
482// In this example if shapes [0, 1, 2] are broadcastable, then it means that
483// shapes [0, 1] are broadcastable too, and can be removed from the list of
484// constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
485// matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
486struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
487 using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
488
489 LogicalResult matchAndRewrite(AssumingAllOp op,
490 PatternRewriter &rewriter) const override {
491 // Collect all `CstrBroadcastableOp` operands first.
492 SetVector<CstrBroadcastableOp> operands;
493 for (Value operand : op.getInputs()) {
494 // TODO: Apply this optimization if some of the witnesses are not
495 // produced by the `cstr_broadcastable`.
496 auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
497 if (!broadcastable)
498 return failure();
499
500 operands.insert(X: broadcastable);
501 }
502
503 // Skip trivial `assuming_all` operations.
504 if (operands.size() <= 1)
505 return failure();
506
507 // Collect shapes checked by `cstr_broadcastable` operands.
508 SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes;
509 for (auto cstr : operands) {
510 DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end());
511 shapes.emplace_back(Args&: cstr, Args: std::move(shapesSet));
512 }
513
514 // Sort by the number of shape operands (larger to smaller).
515 llvm::sort(C&: shapes, Comp: [](auto a, auto b) {
516 return a.first.getNumOperands() > b.first.getNumOperands();
517 });
518
519 // We start from the `cst_broadcastable` operations with largest number of
520 // shape operands, and remove redundant `cst_broadcastable` operations. We
521 // do this until we find a set of `cst_broadcastable` operations with
522 // non-overlapping constraints.
523 SmallVector<CstrBroadcastableOp> markedForErase;
524
525 for (unsigned i = 0; i < shapes.size(); ++i) {
526 auto isSubset = [&](auto pair) {
527 return llvm::set_is_subset(pair.second, shapes[i].second);
528 };
529
530 // Keep redundant `cstr_broadcastable` operations to be erased.
531 auto *it = std::remove_if(first: shapes.begin() + i + 1, last: shapes.end(), pred: isSubset);
532 for (auto *it0 = it; it0 < shapes.end(); ++it0)
533 markedForErase.push_back(Elt: it0->first);
534 shapes.erase(CS: it, CE: shapes.end());
535 }
536
537 // We didn't find any operands that could be removed.
538 if (markedForErase.empty())
539 return failure();
540
541 // Collect non-overlapping `cst_broadcastable` constraints.
542 SmallVector<Value> uniqueConstraints;
543 for (auto &shape : shapes)
544 uniqueConstraints.push_back(Elt: shape.first.getResult());
545
546 // Replace with a new `assuming_all` operation ...
547 rewriter.replaceOpWithNewOp<AssumingAllOp>(op, args&: uniqueConstraints);
548
549 // ... and maybe erase `cstr_broadcastable` ops without uses.
550 for (auto &op : markedForErase)
551 if (op->use_empty())
552 rewriter.eraseOp(op);
553
554 return success();
555 }
556};
557
558struct AssumingAllToCstrEqCanonicalization
559 : public OpRewritePattern<AssumingAllOp> {
560 using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
561
562 LogicalResult matchAndRewrite(AssumingAllOp op,
563 PatternRewriter &rewriter) const override {
564 SmallVector<Value, 8> shapes;
565 for (Value w : op.getInputs()) {
566 auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
567 if (!cstrEqOp)
568 return failure();
569 bool disjointShapes = llvm::none_of(Range: cstrEqOp.getShapes(), P: [&](Value s) {
570 return llvm::is_contained(Range&: shapes, Element: s);
571 });
572 if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
573 return failure();
574 shapes.append(in_start: cstrEqOp.getShapes().begin(), in_end: cstrEqOp.getShapes().end());
575 }
576 rewriter.replaceOpWithNewOp<CstrEqOp>(op, args&: shapes);
577 return success();
578 }
579};
580
581template <typename OpTy>
582struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
583 using OpRewritePattern<OpTy>::OpRewritePattern;
584
585 LogicalResult matchAndRewrite(OpTy op,
586 PatternRewriter &rewriter) const override {
587 // Find unique operands.
588 SetVector<Value> unique(op.operand_begin(), op.operand_end());
589
590 // Reduce op to equivalent with unique operands.
591 if (unique.size() < op.getNumOperands()) {
592 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
593 unique.takeVector(), op->getAttrs());
594 return success();
595 }
596
597 return failure();
598 }
599};
600} // namespace
601
602void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
603 MLIRContext *context) {
604 patterns
605 .add<MergeAssumingAllOps, AssumingAllOneOp,
606 AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
607 RemoveDuplicateOperandsPattern<AssumingAllOp>>(arg&: context);
608}
609
610OpFoldResult AssumingAllOp::fold(FoldAdaptor adaptor) {
611 // Iterate in reverse to first handle all constant operands. They are
612 // guaranteed to be the tail of the inputs because this is commutative.
613 for (int idx = adaptor.getInputs().size() - 1; idx >= 0; idx--) {
614 Attribute a = adaptor.getInputs()[idx];
615 // Cannot fold if any inputs are not constant;
616 if (!a)
617 return nullptr;
618
619 // We do not need to keep statically known values after handling them in
620 // this method.
621 getOperation()->eraseOperand(idx);
622
623 // Always false if any input is statically known false
624 if (!llvm::cast<BoolAttr>(Val&: a).getValue())
625 return a;
626 }
627 // If this is reached, all inputs were statically known passing.
628 return BoolAttr::get(context: getContext(), value: true);
629}
630
631LogicalResult AssumingAllOp::verify() {
632 // Ensure that AssumingAllOp contains at least one operand
633 if (getNumOperands() == 0)
634 return emitOpError(message: "no operands specified");
635
636 return success();
637}
638
639//===----------------------------------------------------------------------===//
640// BroadcastOp
641//===----------------------------------------------------------------------===//
642
643OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
644 if (getShapes().size() == 1) {
645 // Otherwise, we need a cast which would be a canonicalization, not folding.
646 if (getShapes().front().getType() != getType())
647 return nullptr;
648 return getShapes().front();
649 }
650
651 if (!adaptor.getShapes().front())
652 return nullptr;
653
654 SmallVector<int64_t, 6> resultShape(
655 llvm::cast<DenseIntElementsAttr>(Val: adaptor.getShapes().front())
656 .getValues<int64_t>());
657
658 for (auto next : adaptor.getShapes().drop_front()) {
659 if (!next)
660 return nullptr;
661 auto nextShape = llvm::to_vector<6>(
662 Range: llvm::cast<DenseIntElementsAttr>(Val&: next).getValues<int64_t>());
663
664 SmallVector<int64_t, 6> tmpShape;
665 // If the shapes are not compatible, we can't fold it.
666 // TODO: Fold to an "error".
667 if (!OpTrait::util::getBroadcastedShape(shape1: resultShape, shape2: nextShape, resultShape&: tmpShape))
668 return nullptr;
669
670 resultShape.clear();
671 std::copy(first: tmpShape.begin(), last: tmpShape.end(),
672 result: std::back_inserter(x&: resultShape));
673 }
674
675 Builder builder(getContext());
676 return builder.getIndexTensorAttr(values: resultShape);
677}
678
679LogicalResult BroadcastOp::verify() {
680 return verifyShapeOrExtentTensorOp(op: *this);
681}
682
683namespace {
684template <typename OpTy>
685struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
686 using OpRewritePattern<OpTy>::OpRewritePattern;
687
688 LogicalResult matchAndRewrite(OpTy op,
689 PatternRewriter &rewriter) const override {
690 auto isPotentiallyNonEmptyShape = [](Value shape) {
691 if (auto extentTensorTy =
692 llvm::dyn_cast<RankedTensorType>(Val: shape.getType())) {
693 if (extentTensorTy.getDimSize(idx: 0) == 0)
694 return false;
695 }
696 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
697 if (constShape.getShape().empty())
698 return false;
699 }
700 return true;
701 };
702 auto newOperands = llvm::filter_to_vector<8>(op->getOperands(),
703 isPotentiallyNonEmptyShape);
704
705 // Replace the op with empty shape constant if all operants are reduced to
706 // be empty.
707 if (newOperands.empty()) {
708 rewriter.replaceOpWithNewOp<ConstShapeOp>(
709 op, op->getResultTypes().front(), rewriter.getIndexTensorAttr(values: {}));
710 return success();
711 }
712
713 // Reduce op to equivalent without empty shape operands.
714 if (newOperands.size() < op.getNumOperands()) {
715 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
716 op->getAttrs());
717 return success();
718 }
719
720 return failure();
721 }
722};
723
724struct BroadcastForwardSingleOperandPattern
725 : public OpRewritePattern<BroadcastOp> {
726 using OpRewritePattern<BroadcastOp>::OpRewritePattern;
727
728 LogicalResult matchAndRewrite(BroadcastOp op,
729 PatternRewriter &rewriter) const override {
730 if (op.getNumOperands() != 1)
731 return failure();
732 Value replacement = op.getShapes().front();
733
734 // Insert cast if needed.
735 if (replacement.getType() != op.getType()) {
736 auto loc = op.getLoc();
737 if (llvm::isa<ShapeType>(Val: op.getType())) {
738 replacement = rewriter.create<FromExtentTensorOp>(location: loc, args&: replacement);
739 } else {
740 assert(!llvm::isa<ShapeType>(op.getType()) &&
741 !llvm::isa<ShapeType>(replacement.getType()) &&
742 "expect extent tensor cast");
743 replacement =
744 rewriter.create<tensor::CastOp>(location: loc, args: op.getType(), args&: replacement);
745 }
746 }
747
748 rewriter.replaceOp(op, newValues: replacement);
749 return success();
750 }
751};
752
753struct BroadcastFoldConstantOperandsPattern
754 : public OpRewritePattern<BroadcastOp> {
755 using OpRewritePattern<BroadcastOp>::OpRewritePattern;
756
757 LogicalResult matchAndRewrite(BroadcastOp op,
758 PatternRewriter &rewriter) const override {
759 SmallVector<int64_t, 8> foldedConstantShape;
760 SmallVector<Value, 8> newShapeOperands;
761 for (Value shape : op.getShapes()) {
762 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
763 SmallVector<int64_t, 8> newFoldedConstantShape;
764 if (OpTrait::util::getBroadcastedShape(
765 shape1: foldedConstantShape,
766 shape2: llvm::to_vector<8>(Range: constShape.getShape().getValues<int64_t>()),
767 resultShape&: newFoldedConstantShape)) {
768 foldedConstantShape = newFoldedConstantShape;
769 continue;
770 }
771 }
772 newShapeOperands.push_back(Elt: shape);
773 }
774
775 // Need at least two constant operands to fold anything.
776 if (op.getNumOperands() - newShapeOperands.size() < 2)
777 return failure();
778
779 auto foldedConstantOperandsTy = RankedTensorType::get(
780 shape: {static_cast<int64_t>(foldedConstantShape.size())},
781 elementType: rewriter.getIndexType());
782 newShapeOperands.push_back(Elt: rewriter.create<ConstShapeOp>(
783 location: op.getLoc(), args&: foldedConstantOperandsTy,
784 args: rewriter.getIndexTensorAttr(values: foldedConstantShape)));
785 rewriter.replaceOpWithNewOp<BroadcastOp>(op, args: op.getType(),
786 args&: newShapeOperands);
787 return success();
788 }
789};
790
791template <typename OpTy>
792struct CanonicalizeCastExtentTensorOperandsPattern
793 : public OpRewritePattern<OpTy> {
794 using OpRewritePattern<OpTy>::OpRewritePattern;
795
796 LogicalResult matchAndRewrite(OpTy op,
797 PatternRewriter &rewriter) const override {
798 // Canonicalize operands.
799 bool anyChange = false;
800 auto canonicalizeOperand = [&](Value operand) -> Value {
801 if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
802 // Only eliminate the cast if it holds no shape information.
803 bool isInformationLoosingCast =
804 llvm::cast<RankedTensorType>(Val: castOp.getType()).isDynamicDim(idx: 0);
805 if (isInformationLoosingCast) {
806 anyChange = true;
807 return castOp.getSource();
808 }
809 }
810 return operand;
811 };
812 auto newOperands = llvm::to_vector<8>(
813 llvm::map_range(op.getOperands(), canonicalizeOperand));
814
815 // Rewrite op if any change required.
816 if (!anyChange)
817 return failure();
818 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
819 return success();
820 }
821};
822
823struct BroadcastConcretizeResultTypePattern
824 : public OpRewritePattern<BroadcastOp> {
825 using OpRewritePattern<BroadcastOp>::OpRewritePattern;
826
827 LogicalResult matchAndRewrite(BroadcastOp op,
828 PatternRewriter &rewriter) const override {
829 // Only concretize dynamic extent tensor result types.
830 auto resultTy = llvm::dyn_cast<RankedTensorType>(Val: op.getType());
831 if (!resultTy || !resultTy.isDynamicDim(idx: 0))
832 return failure();
833
834 // Infer resulting shape rank if possible.
835 int64_t maxRank = 0;
836 for (Value shape : op.getShapes()) {
837 if (auto extentTensorTy =
838 llvm::dyn_cast<RankedTensorType>(Val: shape.getType())) {
839 // Cannot infer resulting shape rank if any operand is dynamically
840 // ranked.
841 if (extentTensorTy.isDynamicDim(idx: 0))
842 return failure();
843 maxRank = std::max(a: maxRank, b: extentTensorTy.getDimSize(idx: 0));
844 }
845 }
846
847 auto newOp = rewriter.create<BroadcastOp>(
848 location: op.getLoc(), args: getExtentTensorType(ctx: getContext(), rank: maxRank),
849 args: op.getShapes());
850 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, args: op.getType(), args&: newOp);
851 return success();
852 }
853};
854} // namespace
855
856void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
857 MLIRContext *context) {
858 patterns.add<BroadcastConcretizeResultTypePattern,
859 BroadcastFoldConstantOperandsPattern,
860 BroadcastForwardSingleOperandPattern,
861 CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
862 RemoveDuplicateOperandsPattern<BroadcastOp>,
863 RemoveEmptyShapeOperandsPattern<BroadcastOp>>(arg&: context);
864}
865
866//===----------------------------------------------------------------------===//
867// ConcatOp
868//===----------------------------------------------------------------------===//
869
870OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
871 if (!adaptor.getLhs() || !adaptor.getRhs())
872 return nullptr;
873 auto lhsShape = llvm::to_vector<6>(
874 Range: llvm::cast<DenseIntElementsAttr>(Val: adaptor.getLhs()).getValues<int64_t>());
875 auto rhsShape = llvm::to_vector<6>(
876 Range: llvm::cast<DenseIntElementsAttr>(Val: adaptor.getRhs()).getValues<int64_t>());
877 SmallVector<int64_t, 6> resultShape;
878 resultShape.append(in_start: lhsShape.begin(), in_end: lhsShape.end());
879 resultShape.append(in_start: rhsShape.begin(), in_end: rhsShape.end());
880 Builder builder(getContext());
881 return builder.getIndexTensorAttr(values: resultShape);
882}
883
884//===----------------------------------------------------------------------===//
885// ConstShapeOp
886//===----------------------------------------------------------------------===//
887
888void ConstShapeOp::print(OpAsmPrinter &p) {
889 p << " ";
890 p.printOptionalAttrDict(attrs: (*this)->getAttrs(), /*elidedAttrs=*/{"shape"});
891 p << "[";
892 interleaveComma(c: getShape().getValues<int64_t>(), os&: p);
893 p << "] : ";
894 p.printType(type: getType());
895}
896
897ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
898 if (parser.parseOptionalAttrDict(result&: result.attributes))
899 return failure();
900 // We piggy-back on ArrayAttr parsing, though we don't internally store the
901 // shape as an ArrayAttr.
902 // TODO: Implement custom parser and maybe make syntax a bit more concise.
903 Attribute extentsRaw;
904 NamedAttrList dummy;
905 if (parser.parseAttribute(result&: extentsRaw, attrName: "dummy", attrs&: dummy))
906 return failure();
907 auto extentsArray = llvm::dyn_cast<ArrayAttr>(Val&: extentsRaw);
908 if (!extentsArray)
909 return failure();
910 SmallVector<int64_t, 6> ints;
911 for (Attribute extent : extentsArray) {
912 IntegerAttr attr = llvm::dyn_cast<IntegerAttr>(Val&: extent);
913 if (!attr)
914 return failure();
915 ints.push_back(Elt: attr.getInt());
916 }
917 Builder &builder = parser.getBuilder();
918 result.addAttribute(name: "shape", attr: builder.getIndexTensorAttr(values: ints));
919 Type resultTy;
920 if (parser.parseColonType(result&: resultTy))
921 return failure();
922 result.types.push_back(Elt: resultTy);
923 return success();
924}
925
926OpFoldResult ConstShapeOp::fold(FoldAdaptor) { return getShapeAttr(); }
927
928void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
929 MLIRContext *context) {
930 patterns.add<TensorCastConstShape>(arg&: context);
931}
932
933LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
934 MLIRContext *context, std::optional<Location> location,
935 ConstShapeOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
936 Builder b(context);
937 const Properties prop = adaptor.getProperties();
938 inferredReturnTypes.assign(IL: {RankedTensorType::get(
939 shape: {static_cast<int64_t>(prop.shape.size())}, elementType: b.getIndexType())});
940 return success();
941}
942
943bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
944 TypeRange r) {
945 if (l.size() != 1 || r.size() != 1)
946 return false;
947
948 Type lhs = l.front();
949 Type rhs = r.front();
950
951 if (llvm::isa<ShapeType>(Val: lhs) || llvm::isa<ShapeType>(Val: rhs))
952 // Shape type is compatible with all other valid return types.
953 return true;
954 return lhs == rhs;
955}
956
957//===----------------------------------------------------------------------===//
958// CstrBroadcastableOp
959//===----------------------------------------------------------------------===//
960
961void CstrBroadcastableOp::getCanonicalizationPatterns(
962 RewritePatternSet &patterns, MLIRContext *context) {
963 // Canonicalization patterns have overlap with the considerations during
964 // folding in case additional shape information is inferred at some point that
965 // does not result in folding.
966 patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
967 CstrBroadcastableEqOps,
968 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
969 RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(arg&: context);
970}
971
972// Return true if there is exactly one attribute not representing a scalar
973// broadcast.
974static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
975 bool nonScalarSeen = false;
976 for (Attribute a : attributes) {
977 if (!a || llvm::cast<DenseIntElementsAttr>(Val&: a).getNumElements() != 0) {
978 if (nonScalarSeen)
979 return false;
980 nonScalarSeen = true;
981 }
982 }
983 return true;
984}
985
986OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
987 // No broadcasting is needed if all operands but one are scalar.
988 if (hasAtMostSingleNonScalar(attributes: adaptor.getShapes()))
989 return BoolAttr::get(context: getContext(), value: true);
990
991 if ([&] {
992 SmallVector<SmallVector<int64_t, 6>, 6> extents;
993 for (const auto &operand : adaptor.getShapes()) {
994 if (!operand)
995 return false;
996 extents.push_back(Elt: llvm::to_vector<6>(
997 Range: llvm::cast<DenseIntElementsAttr>(Val: operand).getValues<int64_t>()));
998 }
999 return OpTrait::util::staticallyKnownBroadcastable(shapes: extents);
1000 }())
1001 return BoolAttr::get(context: getContext(), value: true);
1002
1003 // Lastly, see if folding can be completed based on what constraints are known
1004 // on the input shapes.
1005 if ([&] {
1006 SmallVector<SmallVector<int64_t, 6>, 6> extents;
1007 for (auto shapeValue : getShapes()) {
1008 extents.emplace_back();
1009 if (failed(Result: getShapeVec(input: shapeValue, shapeValues&: extents.back())))
1010 return false;
1011 }
1012 return OpTrait::util::staticallyKnownBroadcastable(shapes: extents);
1013 }())
1014 return BoolAttr::get(context: getContext(), value: true);
1015
1016 // Because a failing witness result here represents an eventual assertion
1017 // failure, we do not replace it with a constant witness.
1018 return nullptr;
1019}
1020
1021LogicalResult CstrBroadcastableOp::verify() {
1022 // Ensure that CstrBroadcastableOp contains at least two operands
1023 if (getNumOperands() < 2)
1024 return emitOpError(message: "required at least 2 input shapes");
1025 return success();
1026}
1027
1028//===----------------------------------------------------------------------===//
1029// CstrEqOp
1030//===----------------------------------------------------------------------===//
1031
1032void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1033 MLIRContext *context) {
1034 // If inputs are equal, return passing witness
1035 patterns.add<CstrEqEqOps>(arg&: context);
1036}
1037
1038OpFoldResult CstrEqOp::fold(FoldAdaptor adaptor) {
1039 if (llvm::all_of(Range: adaptor.getShapes(), P: [&](Attribute a) {
1040 return a && a == adaptor.getShapes().front();
1041 }))
1042 return BoolAttr::get(context: getContext(), value: true);
1043
1044 // Because a failing witness result here represents an eventual assertion
1045 // failure, we do not try to replace it with a constant witness. Similarly, we
1046 // cannot if there are any non-const inputs.
1047 return nullptr;
1048}
1049
1050//===----------------------------------------------------------------------===//
1051// ConstSizeOp
1052//===----------------------------------------------------------------------===//
1053
1054void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
1055 int64_t value) {
1056 build(odsBuilder&: builder, odsState&: result, value: builder.getIndexAttr(value));
1057}
1058
1059OpFoldResult ConstSizeOp::fold(FoldAdaptor) { return getValueAttr(); }
1060
1061void ConstSizeOp::getAsmResultNames(
1062 llvm::function_ref<void(Value, StringRef)> setNameFn) {
1063 SmallString<4> buffer;
1064 llvm::raw_svector_ostream os(buffer);
1065 os << "c" << getValue();
1066 setNameFn(getResult(), os.str());
1067}
1068
1069//===----------------------------------------------------------------------===//
1070// ConstWitnessOp
1071//===----------------------------------------------------------------------===//
1072
1073OpFoldResult ConstWitnessOp::fold(FoldAdaptor) { return getPassingAttr(); }
1074
1075//===----------------------------------------------------------------------===//
1076// CstrRequireOp
1077//===----------------------------------------------------------------------===//
1078
1079OpFoldResult CstrRequireOp::fold(FoldAdaptor adaptor) {
1080 return adaptor.getPred();
1081}
1082
1083//===----------------------------------------------------------------------===//
1084// DimOp
1085//===----------------------------------------------------------------------===//
1086
1087std::optional<int64_t> DimOp::getConstantIndex() {
1088 if (auto constSizeOp = getIndex().getDefiningOp<ConstSizeOp>())
1089 return constSizeOp.getValue().getLimitedValue();
1090 if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
1091 return llvm::cast<IntegerAttr>(Val: constantOp.getValue()).getInt();
1092 return std::nullopt;
1093}
1094
1095OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
1096 Type valType = getValue().getType();
1097 auto valShapedType = llvm::dyn_cast<ShapedType>(Val&: valType);
1098 if (!valShapedType || !valShapedType.hasRank())
1099 return nullptr;
1100 std::optional<int64_t> index = getConstantIndex();
1101 if (!index.has_value())
1102 return nullptr;
1103 if (index.value() < 0 || index.value() >= valShapedType.getRank())
1104 return nullptr;
1105 auto extent = valShapedType.getDimSize(idx: *index);
1106 if (ShapedType::isDynamic(dValue: extent))
1107 return nullptr;
1108 return IntegerAttr::get(type: IndexType::get(context: getContext()), value: extent);
1109}
1110
1111LogicalResult mlir::shape::DimOp::inferReturnTypes(
1112 MLIRContext *context, std::optional<Location> location,
1113 DimOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1114 inferredReturnTypes.assign(IL: {adaptor.getIndex().getType()});
1115 return success();
1116}
1117
1118bool mlir::shape::DimOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1119 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, rs: r);
1120}
1121
1122//===----------------------------------------------------------------------===//
1123// DivOp
1124//===----------------------------------------------------------------------===//
1125
1126OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
1127 auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(Val: adaptor.getLhs());
1128 if (!lhs)
1129 return nullptr;
1130 auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(Val: adaptor.getRhs());
1131 if (!rhs || rhs.getValue().isZero())
1132 return nullptr;
1133
1134 // Division in APInt does not follow floor(lhs, rhs) when the result is
1135 // negative. Rather, APInt rounds toward zero.
1136 APInt quotient, remainder;
1137 APInt::sdivrem(LHS: lhs.getValue(), RHS: rhs.getValue(), Quotient&: quotient, Remainder&: remainder);
1138 if (quotient.isNegative() && !remainder.isZero()) {
1139 quotient -= 1;
1140 }
1141
1142 Type indexTy = IndexType::get(context: getContext());
1143 return IntegerAttr::get(type: indexTy, value: quotient);
1144}
1145
1146LogicalResult mlir::shape::DivOp::inferReturnTypes(
1147 MLIRContext *context, std::optional<Location> location,
1148 DivOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1149 if (llvm::isa<SizeType>(Val: adaptor.getLhs().getType()) ||
1150 llvm::isa<SizeType>(Val: adaptor.getRhs().getType()))
1151 inferredReturnTypes.assign(IL: {SizeType::get(ctx: context)});
1152 else
1153 inferredReturnTypes.assign(IL: {IndexType::get(context)});
1154 return success();
1155}
1156
1157bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1158 // SizeType is compatible with IndexType.
1159 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, rs: r);
1160}
1161
1162LogicalResult DivOp::verify() { return verifySizeOrIndexOp(op: *this); }
1163
1164//===----------------------------------------------------------------------===//
1165// ShapeEqOp
1166//===----------------------------------------------------------------------===//
1167
1168OpFoldResult ShapeEqOp::fold(FoldAdaptor adaptor) {
1169 bool allSame = true;
1170 if (!adaptor.getShapes().empty() && !adaptor.getShapes().front())
1171 return {};
1172 for (Attribute operand : adaptor.getShapes().drop_front()) {
1173 if (!operand)
1174 return {};
1175 allSame = allSame && operand == adaptor.getShapes().front();
1176 }
1177 return BoolAttr::get(context: getContext(), value: allSame);
1178}
1179
1180//===----------------------------------------------------------------------===//
1181// IndexToSizeOp
1182//===----------------------------------------------------------------------===//
1183
1184OpFoldResult IndexToSizeOp::fold(FoldAdaptor adaptor) {
1185 // Constant values of both types, `shape.size` and `index`, are represented as
1186 // `IntegerAttr`s which makes constant folding simple.
1187 if (Attribute arg = adaptor.getArg())
1188 return arg;
1189 return {};
1190}
1191
1192void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1193 MLIRContext *context) {
1194 patterns.add<SizeToIndexToSizeCanonicalization>(arg&: context);
1195}
1196
1197//===----------------------------------------------------------------------===//
1198// FromExtentsOp
1199//===----------------------------------------------------------------------===//
1200
1201OpFoldResult FromExtentsOp::fold(FoldAdaptor adaptor) {
1202 if (llvm::any_of(Range: adaptor.getExtents(), P: [](Attribute a) { return !a; }))
1203 return nullptr;
1204 SmallVector<int64_t, 6> extents;
1205 for (auto attr : adaptor.getExtents())
1206 extents.push_back(Elt: llvm::cast<IntegerAttr>(Val&: attr).getInt());
1207 Builder builder(getContext());
1208 return builder.getIndexTensorAttr(values: extents);
1209}
1210
1211//===----------------------------------------------------------------------===//
1212// FunctionLibraryOp
1213//===----------------------------------------------------------------------===//
1214
1215void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
1216 StringRef name) {
1217 result.attributes.push_back(newAttribute: builder.getNamedAttr(
1218 name: ::mlir::SymbolTable::getSymbolAttrName(), val: builder.getStringAttr(bytes: name)));
1219}
1220
1221FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
1222 auto attr = llvm::dyn_cast_or_null<FlatSymbolRefAttr>(
1223 Val: getMapping().get(name: op->getName().getIdentifier()));
1224 if (!attr)
1225 return nullptr;
1226 return lookupSymbol<FuncOp>(symbol: attr);
1227}
1228
1229ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
1230 OperationState &result) {
1231 // Parse the op name.
1232 StringAttr nameAttr;
1233 if (parser.parseSymbolName(result&: nameAttr, attrName: ::mlir::SymbolTable::getSymbolAttrName(),
1234 attrs&: result.attributes))
1235 return failure();
1236
1237 if (parser.parseOptionalAttrDictWithKeyword(result&: result.attributes))
1238 return failure();
1239
1240 auto *bodyRegion = result.addRegion();
1241 if (parser.parseRegion(region&: *bodyRegion))
1242 return failure();
1243
1244 if (parser.parseKeyword(keyword: "mapping"))
1245 return failure();
1246
1247 DictionaryAttr mappingAttr;
1248 if (parser.parseAttribute(result&: mappingAttr,
1249 type: parser.getBuilder().getType<NoneType>(), attrName: "mapping",
1250 attrs&: result.attributes))
1251 return failure();
1252 return success();
1253}
1254
1255void FunctionLibraryOp::print(OpAsmPrinter &p) {
1256 p << ' ';
1257 p.printSymbolName(symbolRef: getName());
1258 p.printOptionalAttrDictWithKeyword(
1259 attrs: (*this)->getAttrs(), elidedAttrs: {mlir::SymbolTable::getSymbolAttrName(), "mapping"});
1260 p << ' ';
1261 p.printRegion(blocks&: getRegion(), /*printEntryBlockArgs=*/false,
1262 /*printBlockTerminators=*/false);
1263 p << " mapping ";
1264 p.printAttributeWithoutType(attr: getMappingAttr());
1265}
1266
1267//===----------------------------------------------------------------------===//
1268// FuncOp
1269//===----------------------------------------------------------------------===//
1270
1271FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1272 ArrayRef<NamedAttribute> attrs) {
1273 OpBuilder builder(location->getContext());
1274 OperationState state(location, getOperationName());
1275 FuncOp::build(odsBuilder&: builder, odsState&: state, name, type, attrs);
1276 return cast<FuncOp>(Val: Operation::create(state));
1277}
1278FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1279 Operation::dialect_attr_range attrs) {
1280 SmallVector<NamedAttribute, 8> attrRef(attrs);
1281 return create(location, name, type, attrs: llvm::ArrayRef(attrRef));
1282}
1283FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1284 ArrayRef<NamedAttribute> attrs,
1285 ArrayRef<DictionaryAttr> argAttrs) {
1286 FuncOp func = create(location, name, type, attrs);
1287 func.setAllArgAttrs(argAttrs);
1288 return func;
1289}
1290
1291void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
1292 FunctionType type, ArrayRef<NamedAttribute> attrs,
1293 ArrayRef<DictionaryAttr> argAttrs) {
1294 state.addAttribute(name: FuncOp::getSymNameAttrName(name: state.name),
1295 attr: builder.getStringAttr(bytes: name));
1296 state.addAttribute(name: FuncOp::getFunctionTypeAttrName(name: state.name),
1297 attr: TypeAttr::get(type));
1298 state.attributes.append(inStart: attrs.begin(), inEnd: attrs.end());
1299 state.addRegion();
1300
1301 if (argAttrs.empty())
1302 return;
1303 assert(type.getNumInputs() == argAttrs.size());
1304 call_interface_impl::addArgAndResultAttrs(
1305 builder, result&: state, argAttrs, /*resultAttrs=*/{},
1306 argAttrsName: getArgAttrsAttrName(name: state.name), resAttrsName: getResAttrsAttrName(name: state.name));
1307}
1308
1309ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
1310 auto buildFuncType =
1311 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
1312 function_interface_impl::VariadicFlag,
1313 std::string &) { return builder.getFunctionType(inputs: argTypes, results); };
1314
1315 return function_interface_impl::parseFunctionOp(
1316 parser, result, /*allowVariadic=*/false,
1317 typeAttrName: getFunctionTypeAttrName(name: result.name), funcTypeBuilder: buildFuncType,
1318 argAttrsName: getArgAttrsAttrName(name: result.name), resAttrsName: getResAttrsAttrName(name: result.name));
1319}
1320
1321void FuncOp::print(OpAsmPrinter &p) {
1322 function_interface_impl::printFunctionOp(
1323 p, op: *this, /*isVariadic=*/false, typeAttrName: getFunctionTypeAttrName(),
1324 argAttrsName: getArgAttrsAttrName(), resAttrsName: getResAttrsAttrName());
1325}
1326
1327//===----------------------------------------------------------------------===//
1328// GetExtentOp
1329//===----------------------------------------------------------------------===//
1330
1331std::optional<int64_t> GetExtentOp::getConstantDim() {
1332 if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
1333 return constSizeOp.getValue().getLimitedValue();
1334 if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
1335 return llvm::cast<IntegerAttr>(Val: constantOp.getValue()).getInt();
1336 return std::nullopt;
1337}
1338
1339OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {
1340 auto elements = llvm::dyn_cast_if_present<DenseIntElementsAttr>(Val: adaptor.getShape());
1341 if (!elements)
1342 return nullptr;
1343 std::optional<int64_t> dim = getConstantDim();
1344 if (!dim.has_value())
1345 return nullptr;
1346 if (dim.value() >= elements.getNumElements())
1347 return nullptr;
1348 return elements.getValues<Attribute>()[(uint64_t)dim.value()];
1349}
1350
1351void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1352 int64_t dim) {
1353 auto loc = result.location;
1354 auto dimAttr = builder.getIndexAttr(value: dim);
1355 if (llvm::isa<ShapeType>(Val: shape.getType())) {
1356 Value dim = builder.create<ConstSizeOp>(location: loc, args&: dimAttr);
1357 build(odsBuilder&: builder, odsState&: result, extent: builder.getType<SizeType>(), shape, dim);
1358 } else {
1359 Value dim =
1360 builder.create<arith::ConstantOp>(location: loc, args: builder.getIndexType(), args&: dimAttr);
1361 build(odsBuilder&: builder, odsState&: result, extent: builder.getIndexType(), shape, dim);
1362 }
1363}
1364
1365LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
1366 MLIRContext *context, std::optional<Location> location,
1367 GetExtentOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1368 inferredReturnTypes.assign(IL: {IndexType::get(context)});
1369 return success();
1370}
1371
1372bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
1373 TypeRange r) {
1374 // SizeType is compatible with IndexType.
1375 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, rs: r);
1376}
1377
1378LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(op: *this); }
1379
1380//===----------------------------------------------------------------------===//
1381// IsBroadcastableOp
1382//===----------------------------------------------------------------------===//
1383
1384void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1385 MLIRContext *context) {
1386 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(arg&: context);
1387}
1388
1389OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
1390 // Can always broadcast fewer than two shapes.
1391 if (adaptor.getShapes().size() < 2) {
1392 return BoolAttr::get(context: getContext(), value: true);
1393 }
1394
1395 return nullptr;
1396}
1397
1398//===----------------------------------------------------------------------===//
1399// MeetOp
1400//===----------------------------------------------------------------------===//
1401
1402LogicalResult mlir::shape::MeetOp::inferReturnTypes(
1403 MLIRContext *context, std::optional<Location> location,
1404 MeetOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1405 if (adaptor.getOperands().empty())
1406 return failure();
1407
1408 auto isShapeType = [](Type arg) {
1409 if (llvm::isa<ShapeType>(Val: arg))
1410 return true;
1411 return isExtentTensorType(type: arg);
1412 };
1413
1414 ValueRange::type_range types = adaptor.getOperands().getTypes();
1415 Type acc = types.front();
1416 for (auto t : drop_begin(RangeOrContainer&: types)) {
1417 Type l = acc, r = t;
1418 if (!llvm::isa<ShapeType, SizeType>(Val: l))
1419 std::swap(a&: l, b&: r);
1420
1421 // Handle sizes, propagate error type if present.
1422 if (llvm::isa<SizeType>(Val: l)) {
1423 if (llvm::isa<SizeType, IndexType>(Val: r))
1424 acc = l;
1425 else
1426 return emitOptionalError(loc: location, args: "requires all sizes or shapes");
1427 } else if (llvm::isa<IndexType>(Val: l)) {
1428 if (llvm::isa<IndexType>(Val: r))
1429 acc = r;
1430 else
1431 return emitOptionalError(loc: location, args: "requires all sizes or shapes");
1432 } else if (llvm::isa<ShapeType>(Val: l)) {
1433 // Handle shapes, propagate error type if present.
1434 if (isShapeType(r))
1435 acc = l;
1436 else
1437 return emitOptionalError(loc: location, args: "requires all sizes or shapes");
1438 } else if (isExtentTensorType(type: l)) {
1439 auto rank1 = llvm::cast<RankedTensorType>(Val&: l).getShape()[0];
1440 auto rank2 = llvm::cast<RankedTensorType>(Val&: r).getShape()[0];
1441 if (ShapedType::isDynamic(dValue: rank1))
1442 acc = l;
1443 else if (ShapedType::isDynamic(dValue: rank2))
1444 acc = r;
1445 else if (rank1 != rank2)
1446 return emitOptionalError(loc: location, args: "unequal shape cardinality");
1447 else
1448 acc = l;
1449 }
1450 }
1451 inferredReturnTypes.assign(IL: {acc});
1452 return success();
1453}
1454
1455bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1456 if (l.size() != 1 || r.size() != 1)
1457 return false;
1458 if (l == r)
1459 return true;
1460
1461 Type lhs = l.front();
1462 Type rhs = r.front();
1463
1464 if (!llvm::isa<ShapeType, SizeType>(Val: lhs))
1465 std::swap(a&: lhs, b&: rhs);
1466
1467 if (llvm::isa<SizeType>(Val: lhs))
1468 return llvm::isa<SizeType, IndexType>(Val: rhs);
1469 if (llvm::isa<ShapeType>(Val: lhs))
1470 return llvm::isa<ShapeType, TensorType>(Val: rhs);
1471
1472 if (succeeded(Result: verifyCompatibleShapes(types: {lhs, rhs})))
1473 return true;
1474 return false;
1475}
1476
1477//===----------------------------------------------------------------------===//
1478// RankOp
1479//===----------------------------------------------------------------------===//
1480
1481OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
1482 auto shape = llvm::dyn_cast_if_present<DenseIntElementsAttr>(Val: adaptor.getShape());
1483 if (!shape)
1484 return {};
1485 int64_t rank = shape.getNumElements();
1486 Builder builder(getContext());
1487 return builder.getIndexAttr(value: rank);
1488}
1489
1490/// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
1491/// Constant folding fails in cases where only the rank is constant, not the
1492/// shape itself.
1493/// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
1494///
1495/// Example:
1496///
1497/// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
1498/// %rank = shape.rank %shape
1499///
1500/// becomes
1501///
1502/// %rank = shape.const_size 3
1503
1504namespace {
1505struct RankShapeOfCanonicalizationPattern
1506 : public OpRewritePattern<shape::RankOp> {
1507 using OpRewritePattern<shape::RankOp>::OpRewritePattern;
1508
1509 LogicalResult matchAndRewrite(shape::RankOp op,
1510 PatternRewriter &rewriter) const override {
1511 auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
1512 if (!shapeOfOp)
1513 return failure();
1514 auto rankedTensorType =
1515 llvm::dyn_cast<RankedTensorType>(Val: shapeOfOp.getArg().getType());
1516 if (!rankedTensorType)
1517 return failure();
1518 int64_t rank = rankedTensorType.getRank();
1519 if (llvm::isa<IndexType>(Val: op.getType())) {
1520 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op: op.getOperation(),
1521 args&: rank);
1522 } else if (llvm::isa<shape::SizeType>(Val: op.getType())) {
1523 rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op: op.getOperation(), args&: rank);
1524 } else {
1525 return failure();
1526 }
1527 return success();
1528 }
1529};
1530} // namespace
1531
1532void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1533 MLIRContext *context) {
1534 patterns.add<RankShapeOfCanonicalizationPattern>(arg&: context);
1535}
1536
1537LogicalResult mlir::shape::RankOp::inferReturnTypes(
1538 MLIRContext *context, std::optional<Location> location,
1539 RankOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1540 if (llvm::isa<ShapeType>(Val: adaptor.getShape().getType()))
1541 inferredReturnTypes.assign(IL: {SizeType::get(ctx: context)});
1542 else
1543 inferredReturnTypes.assign(IL: {IndexType::get(context)});
1544 return success();
1545}
1546
1547bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1548 // SizeType is compatible with IndexType.
1549 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, rs: r);
1550}
1551
1552LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(op: *this); }
1553
1554//===----------------------------------------------------------------------===//
1555// NumElementsOp
1556//===----------------------------------------------------------------------===//
1557
1558OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
1559
1560 // Fold only when argument constant.
1561 Attribute shape = adaptor.getShape();
1562 if (!shape)
1563 return {};
1564
1565 APInt product(64, 1);
1566 for (auto value : llvm::cast<DenseIntElementsAttr>(Val&: shape))
1567 product *= value;
1568 Builder builder(getContext());
1569 return builder.getIndexAttr(value: product.getLimitedValue());
1570}
1571
1572LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
1573 MLIRContext *context, std::optional<Location> location,
1574 NumElementsOp::Adaptor adaptor,
1575 SmallVectorImpl<Type> &inferredReturnTypes) {
1576 if (llvm::isa<ShapeType>(Val: adaptor.getShape().getType()))
1577 inferredReturnTypes.assign(IL: {SizeType::get(ctx: context)});
1578 else
1579 inferredReturnTypes.assign(IL: {IndexType::get(context)});
1580 return success();
1581}
1582
1583bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
1584 TypeRange r) {
1585 // SizeType is compatible with IndexType.
1586 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, rs: r);
1587}
1588
1589LogicalResult shape::NumElementsOp::verify() {
1590 return verifySizeOrIndexOp(op: *this);
1591}
1592
1593//===----------------------------------------------------------------------===//
1594// MaxOp
1595//===----------------------------------------------------------------------===//
1596
1597OpFoldResult MaxOp::fold(FoldAdaptor adaptor) {
1598 // If operands are equal, just propagate one.
1599 if (getLhs() == getRhs())
1600 return getLhs();
1601 return nullptr;
1602}
1603
1604LogicalResult mlir::shape::MaxOp::inferReturnTypes(
1605 MLIRContext *context, std::optional<Location> location,
1606 MaxOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1607 if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1608 inferredReturnTypes.assign(IL: {adaptor.getLhs().getType()});
1609 else
1610 inferredReturnTypes.assign(IL: {SizeType::get(ctx: context)});
1611 return success();
1612}
1613
1614bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1615 if (l.size() != 1 || r.size() != 1)
1616 return false;
1617 if (llvm::isa<ShapeType>(Val: l.front()) && llvm::isa<ShapeType>(Val: r.front()))
1618 return true;
1619 if (llvm::isa<SizeType>(Val: l.front()) && llvm::isa<SizeType>(Val: r.front()))
1620 return true;
1621 return false;
1622}
1623
1624//===----------------------------------------------------------------------===//
1625// MinOp
1626//===----------------------------------------------------------------------===//
1627
1628OpFoldResult MinOp::fold(FoldAdaptor adaptor) {
1629 // If operands are equal, just propagate one.
1630 if (getLhs() == getRhs())
1631 return getLhs();
1632 return nullptr;
1633}
1634
1635LogicalResult mlir::shape::MinOp::inferReturnTypes(
1636 MLIRContext *context, std::optional<Location> location,
1637 MinOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1638 if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1639 inferredReturnTypes.assign(IL: {adaptor.getLhs().getType()});
1640 else
1641 inferredReturnTypes.assign(IL: {SizeType::get(ctx: context)});
1642 return success();
1643}
1644
1645bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1646 if (l.size() != 1 || r.size() != 1)
1647 return false;
1648 if (llvm::isa<ShapeType>(Val: l.front()) && llvm::isa<ShapeType>(Val: r.front()))
1649 return true;
1650 if (llvm::isa<SizeType>(Val: l.front()) && llvm::isa<SizeType>(Val: r.front()))
1651 return true;
1652 return false;
1653}
1654
1655//===----------------------------------------------------------------------===//
1656// MulOp
1657//===----------------------------------------------------------------------===//
1658
1659OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1660 auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(Val: adaptor.getLhs());
1661 if (!lhs)
1662 return nullptr;
1663 auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(Val: adaptor.getRhs());
1664 if (!rhs)
1665 return nullptr;
1666 APInt folded = lhs.getValue() * rhs.getValue();
1667 Type indexTy = IndexType::get(context: getContext());
1668 return IntegerAttr::get(type: indexTy, value: folded);
1669}
1670
1671LogicalResult mlir::shape::MulOp::inferReturnTypes(
1672 MLIRContext *context, std::optional<Location> location,
1673 MulOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1674 if (llvm::isa<SizeType>(Val: adaptor.getLhs().getType()) ||
1675 llvm::isa<SizeType>(Val: adaptor.getRhs().getType()))
1676 inferredReturnTypes.assign(IL: {SizeType::get(ctx: context)});
1677 else
1678 inferredReturnTypes.assign(IL: {IndexType::get(context)});
1679 return success();
1680}
1681
1682bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1683 // SizeType is compatible with IndexType.
1684 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, rs: r);
1685}
1686
1687LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(op: *this); }
1688
1689//===----------------------------------------------------------------------===//
1690// ShapeOfOp
1691//===----------------------------------------------------------------------===//
1692
1693namespace {
1694/// Replace shape_of(x) where x has a constant shape with a const_shape op.
1695struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
1696 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
1697
1698 LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1699 PatternRewriter &rewriter) const override {
1700 auto type = llvm::dyn_cast<ShapedType>(Val: op.getArg().getType());
1701 if (!type || !type.hasStaticShape())
1702 return failure();
1703 Location loc = op.getLoc();
1704 Value constShape =
1705 rewriter
1706 .create<ConstShapeOp>(location: loc,
1707 args: rewriter.getIndexTensorAttr(values: type.getShape()))
1708 .getResult();
1709 if (constShape.getType() != op.getResult().getType())
1710 constShape = rewriter.create<tensor::CastOp>(
1711 location: loc, args: op.getResult().getType(), args&: constShape);
1712 rewriter.replaceOp(op, newValues: constShape);
1713 return success();
1714 }
1715};
1716
1717// Canonicalize
1718//
1719// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
1720// %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
1721//
1722// to
1723//
1724// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
1725// %1 = %shape
1726//
1727struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
1728 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
1729
1730 LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1731 PatternRewriter &rewriter) const override {
1732 auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
1733 if (!tensorReshapeOp)
1734 return rewriter.notifyMatchFailure(arg&: op, msg: "producer is not tensor.reshape");
1735 if (!isa<TensorType>(Val: op.getType()))
1736 return rewriter.notifyMatchFailure(arg&: op, msg: "result is not a tensor");
1737
1738 // Operand 'shape' of 'tensor.reshape' may now be used as the result of
1739 // 'shape.shape_of'. While its type is guaranteed to be compatible in well-
1740 // formed IR, it may not be identical (dynamically vs statically shaped),
1741 // in which case it needs to be cast first using 'tensor.cast'.
1742 // Additionally, it may not have identical element type (i32 vs index)
1743 // while it has identical shaped type (dynamic vs static), in which case it
1744 // needs to be cast first using 'arith.index_cast'. Note: 'shape.shape_of'
1745 // op result must be shape or extent tensor.
1746 Value shape = tensorReshapeOp.getShape();
1747
1748 auto opTensorTy = cast<RankedTensorType>(Val: op.getType());
1749 auto shapeTensorTy = cast<RankedTensorType>(Val: shape.getType());
1750
1751 if (opTensorTy != shapeTensorTy) {
1752 if (opTensorTy.getElementType() == shapeTensorTy.getElementType())
1753 shape = rewriter.create<tensor::CastOp>(location: op.getLoc(), args&: opTensorTy, args&: shape);
1754 else if (!isExtentTensorType(type: shapeTensorTy))
1755 shape =
1756 rewriter.create<arith::IndexCastOp>(location: op.getLoc(), args&: opTensorTy, args&: shape);
1757 }
1758
1759 rewriter.replaceOp(op, newValues: shape);
1760 return success();
1761 }
1762};
1763
1764// Canonicalize
1765// ```
1766// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
1767// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
1768// ```
1769// to
1770// ```
1771// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
1772// ```
1773struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
1774 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
1775
1776 LogicalResult matchAndRewrite(tensor::CastOp op,
1777 PatternRewriter &rewriter) const override {
1778 auto ty = llvm::dyn_cast<RankedTensorType>(Val: op.getType());
1779 if (!ty || ty.getRank() != 1)
1780 return failure();
1781
1782 auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();
1783 if (!shapeOfOp)
1784 return failure();
1785
1786 // Argument type must be ranked and must not conflict.
1787 auto argTy = llvm::dyn_cast<RankedTensorType>(Val: shapeOfOp.getArg().getType());
1788 if (!argTy || (!ty.isDynamicDim(idx: 0) && ty.getDimSize(idx: 0) != argTy.getRank()))
1789 return failure();
1790
1791 rewriter.replaceOpWithNewOp<ShapeOfOp>(op, args&: ty, args: shapeOfOp.getArg());
1792 return success();
1793 }
1794};
1795} // namespace
1796
1797void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1798 MLIRContext *context) {
1799 patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
1800 ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
1801 arg&: context);
1802}
1803
1804LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1805 MLIRContext *context, std::optional<Location> location,
1806 ShapeOfOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1807 if (llvm::isa<ValueShapeType>(Val: adaptor.getArg().getType()))
1808 inferredReturnTypes.assign(IL: {ShapeType::get(ctx: context)});
1809 else {
1810 auto shapedTy = llvm::cast<ShapedType>(Val: adaptor.getArg().getType());
1811 int64_t rank =
1812 shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
1813 Type indexTy = IndexType::get(context);
1814 Type extentTensorTy = RankedTensorType::get(shape: {rank}, elementType: indexTy);
1815 inferredReturnTypes.assign(IL: {extentTensorTy});
1816 }
1817 return success();
1818}
1819
1820bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1821 if (l.size() != 1 || r.size() != 1)
1822 return false;
1823 if (l == r)
1824 return true;
1825
1826 Type lhs = l.front();
1827 Type rhs = r.front();
1828
1829 if (!llvm::isa<ShapeType, ShapedType>(Val: lhs) ||
1830 !llvm::isa<ShapeType, ShapedType>(Val: rhs))
1831 return false;
1832
1833 if (llvm::isa<ShapeType>(Val: lhs) || llvm::isa<ShapeType>(Val: rhs))
1834 // Shape type is compatible with all other valid return types.
1835 return true;
1836
1837 if (succeeded(Result: verifyCompatibleShapes(types: {lhs, rhs})))
1838 return true;
1839 return false;
1840}
1841
1842LogicalResult shape::ShapeOfOp::verify() {
1843 return verifyShapeOrExtentTensorOp(op: *this);
1844}
1845
1846//===----------------------------------------------------------------------===//
1847// SizeToIndexOp
1848//===----------------------------------------------------------------------===//
1849
1850OpFoldResult SizeToIndexOp::fold(FoldAdaptor adaptor) {
1851 // Constant values of both types, `shape.size` and `index`, are represented as
1852 // `IntegerAttr`s which makes constant folding simple.
1853 if (Attribute arg = adaptor.getArg())
1854 return arg;
1855 return OpFoldResult();
1856}
1857
1858void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1859 MLIRContext *context) {
1860 patterns.add<IndexToSizeToIndexCanonicalization>(arg&: context);
1861}
1862
1863bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1864 if (inputs.size() != 1 || outputs.size() != 1)
1865 return false;
1866 return llvm::isa<IndexType, SizeType>(Val: inputs[0]) &&
1867 llvm::isa<IndexType>(Val: outputs[0]);
1868}
1869
1870//===----------------------------------------------------------------------===//
1871// YieldOp
1872//===----------------------------------------------------------------------===//
1873
1874LogicalResult shape::YieldOp::verify() {
1875 auto *parentOp = (*this)->getParentOp();
1876 auto results = parentOp->getResults();
1877 auto operands = getOperands();
1878
1879 if (parentOp->getNumResults() != getNumOperands())
1880 return emitOpError() << "number of operands does not match number of "
1881 "results of its parent";
1882 for (auto e : llvm::zip(t&: results, u&: operands))
1883 if (std::get<0>(t&: e).getType() != std::get<1>(t&: e).getType())
1884 return emitOpError() << "types mismatch between yield op and its parent";
1885
1886 return success();
1887}
1888
1889//===----------------------------------------------------------------------===//
1890// SplitAtOp
1891//===----------------------------------------------------------------------===//
1892
1893LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
1894 SmallVectorImpl<OpFoldResult> &results) {
1895 if (!adaptor.getOperand() || !adaptor.getIndex())
1896 return failure();
1897 auto shapeVec = llvm::to_vector<6>(
1898 Range: llvm::cast<DenseIntElementsAttr>(Val: adaptor.getOperand()).getValues<int64_t>());
1899 auto shape = llvm::ArrayRef(shapeVec);
1900 auto splitPoint = llvm::cast<IntegerAttr>(Val: adaptor.getIndex()).getInt();
1901 // Verify that the split point is in the correct range.
1902 // TODO: Constant fold to an "error".
1903 int64_t rank = shape.size();
1904 if (-rank > splitPoint || splitPoint > rank)
1905 return failure();
1906 if (splitPoint < 0)
1907 splitPoint += shape.size();
1908 Builder builder(adaptor.getOperand().getContext());
1909 results.push_back(Elt: builder.getIndexTensorAttr(values: shape.take_front(N: splitPoint)));
1910 results.push_back(Elt: builder.getIndexTensorAttr(values: shape.drop_front(N: splitPoint)));
1911 return success();
1912}
1913
1914//===----------------------------------------------------------------------===//
1915// ToExtentTensorOp
1916//===----------------------------------------------------------------------===//
1917
1918OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
1919 if (!adaptor.getInput())
1920 return OpFoldResult();
1921 Builder builder(getContext());
1922 auto shape = llvm::to_vector<6>(
1923 Range: llvm::cast<DenseIntElementsAttr>(Val: adaptor.getInput()).getValues<int64_t>());
1924 auto type = RankedTensorType::get(shape: {static_cast<int64_t>(shape.size())},
1925 elementType: builder.getIndexType());
1926 return DenseIntElementsAttr::get(type, arg&: shape);
1927}
1928
1929bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1930 if (inputs.size() != 1 || outputs.size() != 1)
1931 return false;
1932 if (auto inputTensor = llvm::dyn_cast<RankedTensorType>(Val: inputs[0])) {
1933 if (!llvm::isa<IndexType>(Val: inputTensor.getElementType()) ||
1934 inputTensor.getRank() != 1)
1935 return false;
1936 } else if (!llvm::isa<ShapeType>(Val: inputs[0])) {
1937 return false;
1938 }
1939
1940 TensorType outputTensor = llvm::dyn_cast<TensorType>(Val: outputs[0]);
1941 return outputTensor && llvm::isa<IndexType>(Val: outputTensor.getElementType());
1942}
1943
1944//===----------------------------------------------------------------------===//
1945// ReduceOp
1946//===----------------------------------------------------------------------===//
1947
1948void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1949 ValueRange initVals) {
1950 OpBuilder::InsertionGuard g(builder);
1951 result.addOperands(newOperands: shape);
1952 result.addOperands(newOperands: initVals);
1953
1954 Region *bodyRegion = result.addRegion();
1955 Block *bodyBlock = builder.createBlock(
1956 parent: bodyRegion, /*insertPt=*/{}, argTypes: builder.getIndexType(), locs: result.location);
1957
1958 Type elementType;
1959 if (auto tensorType = llvm::dyn_cast<TensorType>(Val: shape.getType()))
1960 elementType = tensorType.getElementType();
1961 else
1962 elementType = SizeType::get(ctx: builder.getContext());
1963 bodyBlock->addArgument(type: elementType, loc: shape.getLoc());
1964
1965 for (Value initVal : initVals) {
1966 bodyBlock->addArgument(type: initVal.getType(), loc: initVal.getLoc());
1967 result.addTypes(newTypes: initVal.getType());
1968 }
1969}
1970
1971LogicalResult ReduceOp::verify() {
1972 // Verify block arg types.
1973 Block &block = getRegion().front();
1974
1975 // The block takes index, extent, and aggregated values as arguments.
1976 auto blockArgsCount = getInitVals().size() + 2;
1977 if (block.getNumArguments() != blockArgsCount)
1978 return emitOpError() << "ReduceOp body is expected to have "
1979 << blockArgsCount << " arguments";
1980
1981 // The first block argument is the index and must always be of type `index`.
1982 if (!llvm::isa<IndexType>(Val: block.getArgument(i: 0).getType()))
1983 return emitOpError(
1984 message: "argument 0 of ReduceOp body is expected to be of IndexType");
1985
1986 // The second block argument is the extent and must be of type `size` or
1987 // `index`, depending on whether the reduce operation is applied to a shape or
1988 // to an extent tensor.
1989 Type extentTy = block.getArgument(i: 1).getType();
1990 if (llvm::isa<ShapeType>(Val: getShape().getType())) {
1991 if (!llvm::isa<SizeType>(Val: extentTy))
1992 return emitOpError(message: "argument 1 of ReduceOp body is expected to be of "
1993 "SizeType if the ReduceOp operates on a ShapeType");
1994 } else {
1995 if (!llvm::isa<IndexType>(Val: extentTy))
1996 return emitOpError(
1997 message: "argument 1 of ReduceOp body is expected to be of IndexType if the "
1998 "ReduceOp operates on an extent tensor");
1999 }
2000
2001 for (const auto &type : llvm::enumerate(First: getInitVals()))
2002 if (block.getArgument(i: type.index() + 2).getType() != type.value().getType())
2003 return emitOpError() << "type mismatch between argument "
2004 << type.index() + 2
2005 << " of ReduceOp body and initial value "
2006 << type.index();
2007 return success();
2008}
2009
2010ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
2011 // Parse operands.
2012 SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
2013 Type shapeOrExtentTensorType;
2014 if (parser.parseOperandList(result&: operands, /*requiredOperandCount=*/-1,
2015 delimiter: OpAsmParser::Delimiter::Paren) ||
2016 parser.parseColonType(result&: shapeOrExtentTensorType) ||
2017 parser.parseOptionalArrowTypeList(result&: result.types))
2018 return failure();
2019
2020 // Resolve operands.
2021 auto initVals = llvm::ArrayRef(operands).drop_front();
2022 if (parser.resolveOperand(operand: operands.front(), type: shapeOrExtentTensorType,
2023 result&: result.operands) ||
2024 parser.resolveOperands(operands&: initVals, types&: result.types, loc: parser.getNameLoc(),
2025 result&: result.operands))
2026 return failure();
2027
2028 // Parse the body.
2029 Region *body = result.addRegion();
2030 if (parser.parseRegion(region&: *body, /*args=*/arguments: {}, /*argTypes=*/enableNameShadowing: {}))
2031 return failure();
2032
2033 // Parse attributes.
2034 if (parser.parseOptionalAttrDict(result&: result.attributes))
2035 return failure();
2036
2037 return success();
2038}
2039
2040void ReduceOp::print(OpAsmPrinter &p) {
2041 p << '(' << getShape() << ", " << getInitVals()
2042 << ") : " << getShape().getType();
2043 p.printOptionalArrowTypeList(types: getResultTypes());
2044 p << ' ';
2045 p.printRegion(blocks&: getRegion());
2046 p.printOptionalAttrDict(attrs: (*this)->getAttrs());
2047}
2048
2049#define GET_OP_CLASSES
2050#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
2051
2052#define GET_TYPEDEF_CLASSES
2053#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
2054

// Source: mlir/lib/Dialect/Shape/IR/Shape.cpp